about summary refs log tree commit diff
path: root/corgi/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'corgi/src/main.rs')
-rw-r--r--corgi/src/main.rs249
1 files changed, 89 insertions, 160 deletions
diff --git a/corgi/src/main.rs b/corgi/src/main.rs
index 6a3c528..1772d68 100644
--- a/corgi/src/main.rs
+++ b/corgi/src/main.rs
@@ -1,119 +1,34 @@
-use std::{
-	net::{IpAddr, SocketAddr},
-	path::PathBuf,
-	pin::Pin,
-	process::Stdio,
-	sync::Arc,
-	time::Instant,
-};
+use core::fmt;
+use std::{net::SocketAddr, pin::Pin, sync::Arc};
 
 use caller::HttpRequest;
-use confindent::{Confindent, Value, ValueParseError};
 use http_body_util::{BodyExt, Full};
 use hyper::{
 	HeaderMap, Request, Response, StatusCode,
 	body::{Bytes, Incoming},
-	header::HeaderValue,
 	server::conn::http1,
 	service::Service,
 };
 use hyper_util::rt::TokioIo;
-use regex_lite::Regex;
+use settings::{Script, Settings};
 use stats::Stats;
-use tokio::{io::AsyncWriteExt, net::TcpListener, process::Command, runtime::Runtime};
+use tokio::{net::TcpListener, runtime::Runtime};
+use util::owned_header;
 
 mod caller;
+mod settings;
 mod stats;
-
-#[derive(Clone, Debug)]
-pub struct Settings {
-	port: u16,
-	scripts: Vec<Script>,
-}
-
-#[derive(Clone, Debug, PartialEq)]
-pub enum ScriptKind {
-	Executable,
-	Object,
-}
-
-#[derive(Clone, Debug)]
-pub struct Script {
-	name: String,
-	kind: ScriptKind,
-	regex: Option<Regex>,
-	filename: String,
-	env: Vec<(String, String)>,
-}
-
-const CONF_DEFAULT: &str = "/etc/corgi.conf";
+mod util;
 
 fn main() {
-	let conf_path = std::env::args().nth(1).unwrap_or(String::from(CONF_DEFAULT));
-	let conf = Confindent::from_file(conf_path).expect("failed to open conf");
-
-	let mut settings = Settings {
-		port: 26744,
-		scripts: conf.children("Script").into_iter().map(parse_script_conf).collect(),
-	};
-
-	if let Some(server) = conf.child("Server") {
-		match server.child_parse("Port") {
-			Err(ValueParseError::NoValue) => (),
-			Err(err) => {
-				eprintln!("Server.Port is malformed: {err}");
-				std::process::exit(1);
-			}
-			Ok(port) => settings.port = port,
-		}
-	}
-
-	let stats = Stats::new(PathBuf::from(
-		conf.get("Server/StatsDb").unwrap().to_owned(),
-	));
+	let settings = Settings::get();
+	let stats = Stats::new(&settings.stats_path);
 	stats.create_tables();
 
 	let rt = Runtime::new().unwrap();
 	rt.block_on(async { run(settings, stats).await });
 }
 
-fn parse_script_conf(conf: &Value) -> Script {
-	let name = conf.value_owned().expect("Missing value for 'Script' key");
-	let filename = conf.child_owned("Path").expect("Missing 'Path' key");
-	let environment = conf.child("Environment");
-	let env = environment
-		.map(|e| e.values().map(|v| (v.key_owned(), v.value_owned().unwrap())).collect());
-
-	let regex = match conf.get("Match/Regex") {
-		None => None,
-		Some(restr) => match Regex::new(restr) {
-			Err(err) => {
-				eprintln!("Failed to compile regex: {restr}\nerror: {err}");
-				std::process::exit(1);
-			}
-			Ok(re) => Some(re),
-		},
-	};
-
-	let kind = match conf.get("Type") {
-		None => ScriptKind::Executable,
-		Some("executable") => ScriptKind::Executable,
-		Some("object") => ScriptKind::Object,
-		Some(kind) => {
-			eprintln!("'{kind}' is not a valid script type");
-			std::process::exit(1)
-		}
-	};
-
-	Script {
-		name,
-		kind,
-		regex,
-		filename,
-		env: env.unwrap_or_default(),
-	}
-}
-
 // We have tokio::main at home :)
 async fn run(settings: Settings, stats: Stats) {
 	let addr = SocketAddr::from(([0, 0, 0, 0], settings.port));
@@ -164,61 +79,48 @@ impl Svc {
 		caddr: SocketAddr,
 		req: Request<Incoming>,
 	) -> Response<Full<Bytes>> {
-		let start = Instant::now();
+		match Self::handle_fallible(settings, stats, caddr, req).await {
+			Err(re) => re.into_response(),
+			Ok(response) => response,
+		}
+	}
 
+	async fn handle_fallible(
+		settings: Settings,
+		stats: Arc<Stats>,
+		caddr: SocketAddr,
+		req: Request<Incoming>,
+	) -> Result<Response<Full<Bytes>>, RuntimeError> {
 		// Collect things we need from the request before we eat it's body
 		let method = req.method().as_str().to_ascii_uppercase();
 		let version = req.version();
-		let path = req.uri().path().to_owned();
-		let query = req.uri().query().unwrap_or_default().to_owned();
-		let headers = req.headers().clone();
-
-		let body = req.into_body().collect().await.unwrap().to_bytes().to_vec();
-		let content_length = body.len();
-
-		let mut maybe_script = None;
-		for set_script in settings.scripts {
-			if let Some(regex) = set_script.regex.as_ref() {
-				if regex.is_match(&path) {
-					maybe_script = Some(set_script);
-					break;
-				}
-			} else {
-				maybe_script = Some(set_script);
-				break;
-			}
-		}
-
-		let script = match maybe_script {
-			Some(script) => script,
-			None => {
-				eprintln!("path didn't match any script");
-				panic!("TODO recover?");
-			}
-		};
+		let path = util::url_decode(req.uri().path(), false)?;
+		let query = req
+			.uri()
+			.query()
+			.map(|s| util::url_decode(s, false))
+			.transpose()?
+			.unwrap_or_default();
 
-		let content_type = headers
-			.get("content-type")
-			.map(|s| s.to_str().ok())
-			.flatten()
-			.unwrap_or_default()
-			.to_owned();
+		let script = Self::select_script(&settings, &path).ok_or(RuntimeError::NoScript)?;
 
-		let uagent = headers
-			.get("user-agent")
-			.map(|s| s.to_str().ok())
-			.flatten()
-			.unwrap_or_default()
-			.to_owned();
+		// Clone the headers and extract what we need
+		let headers = req.headers().clone();
+		let content_type = owned_header(headers.get("content-type")).unwrap_or_default();
+		let uagent = owned_header(headers.get("user-agent")).unwrap_or_default();
 
 		// Find the client address
 		let client_addr = {
-			let x_forward = Self::parse_addr_from_header(headers.get("x-forwarded-for"));
-			let forward = Self::parse_addr_from_header(headers.get("forwarded-for"));
+			let x_forward = util::parse_from_header(headers.get("x-forwarded-for"));
+			let forward = util::parse_from_header(headers.get("forwarded-for"));
 
 			forward.unwrap_or(x_forward.unwrap_or(caddr.ip()))
 		};
 
+		// Finally, get the body which consumes the request
+		let body = req.into_body().collect().await.unwrap().to_bytes().to_vec();
+		let content_length = body.len();
+
 		let server_name = headers
 			.get("Host")
 			.expect("no http host header set")
@@ -239,16 +141,9 @@ impl Svc {
 			body: if content_length > 0 { Some(body) } else { None },
 		};
 
-		let start_cgi = Instant::now();
-		let cgi_response = match script.kind {
-			ScriptKind::Executable => {
-				caller::call_and_parse_cgi(script.clone(), http_request).await
-			}
-			ScriptKind::Object => caller::call_and_parse_module(script.clone(), http_request).await,
-		};
-		let cgi_time = start_cgi.elapsed();
-
+		let cgi_response = caller::call_and_parse_cgi(script.clone(), http_request).await;
 		let status = StatusCode::from_u16(cgi_response.status).unwrap();
+
 		let mut response = Response::builder().status(status);
 
 		for (key, value) in cgi_response.headers {
@@ -263,20 +158,32 @@ impl Svc {
 		};
 
 		println!(
-			"served to [{client_addr}]\n\tscript: {}\n\tpath: {path}\n\tcgi took {}ms. total time {}ms\n\tUA: {uagent}",
-			&script.name,
-			cgi_time.as_millis(),
-			start.elapsed().as_millis()
+			"served to [{client_addr}]\n\tscript: {}\n\tpath: {path}\n\tUA: {uagent}",
+			&script.name
 		);
 
 		stats.log_request(db_req);
 
-		let response_body = cgi_response.body.map(|v| Bytes::from(v)).unwrap_or(Bytes::new());
-		response.body(Full::new(response_body)).unwrap()
+		let response_body = cgi_response
+			.body
+			.map(|v| Bytes::from(v))
+			.unwrap_or(Bytes::new());
+
+		Ok(response.body(Full::new(response_body)).unwrap())
 	}
 
-	fn parse_addr_from_header(maybe_hval: Option<&HeaderValue>) -> Option<IpAddr> {
-		maybe_hval.map(|h| h.to_str().ok()).flatten().map(|s| s.parse().ok()).flatten()
+	fn select_script<'s>(settings: &'s Settings, path: &str) -> Option<&'s Script> {
+		for script in &settings.scripts {
+			if let Some(regex) = script.regex.as_ref() {
+				if regex.is_match(path) {
+					return Some(script);
+				}
+			} else {
+				return Some(script);
+			}
+		}
+
+		None
 	}
 
 	fn build_http_vec(headers: HeaderMap) -> Vec<(String, String)> {
@@ -312,13 +219,35 @@ impl Svc {
 	}
 }
 
-fn path_to_name(path: &str) -> String {
-	let mut ret = String::with_capacity(path.len());
-	for ch in path.chars() {
-		match ch {
-			'/' => ret.push('-'),
-			ch => ret.push(ch),
+fn status_page<D: fmt::Display>(status: u16, msg: D) -> Response<Full<Bytes>> {
+	let body_str = format!(
+		"<html>\n\
+			\t<head><title>{status}</title></head>\n\
+			\t<body style='width: 20rem; padding: 0px; margin: 2rem;'>\n\
+			\t\t<h1>{status}</h1>\n\
+			\t\t<hr/>\n\
+			\t\t<p>{msg}</p>\n\
+			\t</body>\n\
+		</html>"
+	);
+
+	Response::builder()
+		.status(status)
+		.header("Content-Type", "text/html")
+		.body(Full::new(body_str.into()))
+		.unwrap()
+}
+
+enum RuntimeError {
+	MalformedRequest,
+	NoScript,
+}
+
+impl RuntimeError {
+	pub fn into_response(&self) -> Response<Full<Bytes>> {
+		match self {
+			Self::MalformedRequest => status_page(400, "bad request"),
+			Self::NoScript => status_page(404, "failed to route request"),
 		}
 	}
-	ret
 }