about summary refs log tree commit diff
path: root/corgi/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'corgi/src/main.rs')
-rw-r--r--corgi/src/main.rs396
1 files changed, 396 insertions, 0 deletions
diff --git a/corgi/src/main.rs b/corgi/src/main.rs
new file mode 100644
index 0000000..9831331
--- /dev/null
+++ b/corgi/src/main.rs
@@ -0,0 +1,396 @@
+use std::{
+	net::{IpAddr, SocketAddr},
+	pin::Pin,
+	process::Stdio,
+	time::Instant,
+};
+
+use confindent::{Confindent, Value, ValueParseError};
+use http_body_util::{BodyExt, Full};
+use hyper::{
+	HeaderMap, Request, Response, StatusCode,
+	body::{Bytes, Incoming},
+	header::{self, HeaderValue},
+	server::conn::http1,
+	service::Service,
+};
+use hyper_util::rt::TokioIo;
+use regex_lite::Regex;
+use tokio::{
+	io::{AsyncWriteExt, BufWriter},
+	net::TcpListener,
+	process::Command,
+	runtime::Runtime,
+};
+
+/// Runtime configuration assembled in `main` from the Confindent file.
+#[derive(Clone, Debug)]
+pub struct Settings {
+	/// TCP port to listen on; `main` uses 26744 unless `Server` → `Port`
+	/// overrides it.
+	port: u16,
+	/// One entry per `Script` section of the configuration.
+	scripts: Vec<Script>,
+}
+
+/// A CGI script that requests can be dispatched to.
+#[derive(Clone, Debug)]
+pub struct Script {
+	/// Value of the `Script` key; shown in the request log and compared
+	/// against "git-backend" to enable the debug dump.
+	name: String,
+	/// Optional `Match/Regex` pattern tested against the request path.
+	regex: Option<Regex>,
+	/// Value of the `Path` key: the program that gets spawned.
+	filename: String,
+	/// Extra environment variables taken from the `Environment` section.
+	env: Vec<(String, String)>,
+}
+
+/// Configuration file used when no path is passed as the first argument.
+const CONF_DEFAULT: &str = "/etc/corgi.conf";
+
+/// Entry point: loads the configuration, then serves requests until killed.
+///
+/// The configuration path is the first command-line argument, falling back
+/// to [`CONF_DEFAULT`].
+fn main() {
+	let conf_path = std::env::args().nth(1).unwrap_or_else(|| CONF_DEFAULT.to_owned());
+	let conf = Confindent::from_file(conf_path).expect("failed to open conf");
+
+	let scripts = conf.children("Script").into_iter().map(parse_script_conf).collect();
+	let mut settings = Settings { port: 26744, scripts };
+
+	// `Server` → `Port` is optional: without a value the default stays, a
+	// value that does not parse is fatal.
+	if let Some(server) = conf.child("Server") {
+		match server.child_parse("Port") {
+			Ok(port) => settings.port = port,
+			Err(ValueParseError::NoValue) => {}
+			Err(err) => {
+				eprintln!("Server.Port is malformed: {err}");
+				std::process::exit(1);
+			}
+		}
+	}
+
+	let runtime = Runtime::new().unwrap();
+	runtime.block_on(run(settings));
+}
+
+/// Builds a [`Script`] from one `Script` section of the configuration.
+///
+/// Panics when the section has no value, has no `Path` child, or when an
+/// `Environment` entry has no value. Exits the process with status 1 when
+/// `Match/Regex` does not compile.
+fn parse_script_conf(conf: &Value) -> Script {
+	let name = conf.value_owned().expect("Missing value for 'Script' key");
+	let filename = conf.child_owned("Path").expect("Missing 'Path' key");
+
+	// Each entry yielded by the `Environment` section becomes a (key, value)
+	// pair; no section means no extra variables.
+	let env: Vec<(String, String)> = match conf.child("Environment") {
+		Some(section) => section
+			.values()
+			.map(|entry| (entry.key_owned(), entry.value_owned().unwrap()))
+			.collect(),
+		None => Vec::new(),
+	};
+
+	// A pattern that fails to compile is a configuration error: report and quit.
+	let regex = conf.get("Match/Regex").map(|restr| match Regex::new(restr) {
+		Ok(compiled) => compiled,
+		Err(err) => {
+			eprintln!("Failed to compile regex: {restr}\nerror: {err}");
+			std::process::exit(1);
+		}
+	});
+
+	Script {
+		name,
+		regex,
+		filename,
+		env,
+	}
+}
+
+// We have tokio::main at home :)
+/// Accept loop: binds `0.0.0.0:<port>` and serves every connection on its
+/// own task. Never returns.
+///
+/// A failure to bind is fatal (message on stderr, exit status 1). A failed
+/// `accept` or a connection-level error is logged and the loop keeps going,
+/// so one bad connection cannot take the whole server down.
+async fn run(settings: Settings) {
+	let addr = SocketAddr::from(([0, 0, 0, 0], settings.port));
+	let listen = match TcpListener::bind(addr).await {
+		Ok(listen) => listen,
+		Err(err) => {
+			eprintln!("failed to bind {addr}: {err}");
+			std::process::exit(1);
+		}
+	};
+
+	// Template service; `client_addr` is replaced for every connection.
+	let svc = Svc {
+		settings,
+		client_addr: addr,
+	};
+
+	loop {
+		let (stream, caddr) = match listen.accept().await {
+			Ok(conn) => conn,
+			Err(err) => {
+				// Typically transient (aborted handshake, fd exhaustion).
+				eprintln!("failed to accept connection: {err}");
+				continue;
+			}
+		};
+		let io = TokioIo::new(stream);
+		let mut svc_clone = svc.clone();
+		svc_clone.client_addr = caddr;
+		tokio::task::spawn(async move {
+			if let Err(err) = http1::Builder::new().serve_connection(io, svc_clone).await {
+				eprintln!("error serving connection from {caddr}: {err}");
+			}
+		});
+	}
+}
+
+/// Hyper service handed to each accepted connection.
+#[derive(Clone, Debug)]
+struct Svc {
+	/// Server configuration; cloned again for every request in `call`.
+	settings: Settings,
+	/// Peer address of the connection; `run` overwrites it after each accept.
+	client_addr: SocketAddr,
+}
+
+/// Hyper adapter: every request is answered by [`Svc::handle`].
+impl Service<Request<Incoming>> for Svc {
+	type Response = Response<Full<Bytes>>;
+	type Error = hyper::Error;
+	type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+	/// Returns a boxed future that resolves to the response for `req`.
+	/// The future always yields `Ok`.
+	fn call(&self, req: Request<Incoming>) -> Self::Future {
+		// The future may not borrow from `self`, so take owned copies first.
+		let (settings, caddr) = (self.settings.clone(), self.client_addr);
+
+		Box::pin(async move {
+			let response = Self::handle(settings, caddr, req).await;
+			Ok(response)
+		})
+	}
+}
+
+impl Svc {
+	/// Runs the matching CGI script for one HTTP request and turns its output
+	/// into the HTTP response.
+	///
+	/// `caddr` is the peer address of the TCP connection. It is reported to
+	/// the script as `REMOTE_ADDR` unless a forwarding header carries a
+	/// parsable IP address.
+	///
+	/// Bad input never panics the connection task: an unreadable body or a
+	/// missing `Host` header is answered with 400, an empty script list with
+	/// 500, and an unusable status or header from the script with 502.
+	async fn handle(
+		settings: Settings,
+		caddr: SocketAddr,
+		req: Request<Incoming>,
+	) -> Response<Full<Bytes>> {
+		let start = Instant::now();
+
+		// Collect things we need from the request before we eat its body
+		let method = req.method().as_str().to_ascii_uppercase();
+		let version = req.version();
+		let path = req.uri().path().to_owned();
+		let query = req.uri().query().unwrap_or_default().to_owned();
+		let headers = req.headers().clone();
+		let body = match req.into_body().collect().await {
+			Ok(collected) => collected.to_bytes(),
+			Err(err) => {
+				eprintln!("failed to read request body from {caddr}: {err}");
+				return Self::error_response(
+					StatusCode::BAD_REQUEST,
+					"failed to read request body\n",
+				);
+			}
+		};
+		let content_length = body.len();
+
+		// Script selection: the first script whose regex matches the path
+		// wins. Otherwise the last script without a regex is used, and if
+		// there is none, the first configured script.
+		let Some(mut script) = settings.scripts.first().cloned() else {
+			eprintln!("no scripts configured, cannot serve {path}");
+			return Self::error_response(
+				StatusCode::INTERNAL_SERVER_ERROR,
+				"no scripts configured\n",
+			);
+		};
+
+		for set_script in settings.scripts {
+			let matched = set_script.regex.as_ref().map(|regex| regex.is_match(&path));
+			match matched {
+				Some(true) => {
+					script = set_script;
+					break;
+				}
+				Some(false) => (),
+				None => script = set_script,
+			}
+		}
+
+		let content_type = headers
+			.get("content-type")
+			.and_then(|s| s.to_str().ok())
+			.unwrap_or_default()
+			.to_owned();
+
+		println!("!!! new request. type {content_type} // {method}");
+		println!("!!! {path}?{query}");
+
+		let uagent = headers
+			.get("user-agent")
+			.and_then(|s| s.to_str().ok())
+			.unwrap_or_default()
+			.to_owned();
+
+		// Find the client address: `forwarded-for` beats `x-forwarded-for`,
+		// which beats the socket's peer address.
+		let client_addr = {
+			let x_forward = Self::parse_addr_from_header(headers.get("x-forwarded-for"));
+			let forward = Self::parse_addr_from_header(headers.get("forwarded-for"));
+
+			forward.or(x_forward).unwrap_or_else(|| caddr.ip())
+		};
+
+		// SERVER_NAME comes from the Host header; without it the request
+		// cannot be described to the script.
+		let Some(server_name) = headers.get("Host").and_then(|host| host.to_str().ok()) else {
+			eprintln!("request from {caddr} has no usable Host header");
+			return Self::error_response(
+				StatusCode::BAD_REQUEST,
+				"missing or invalid Host header\n",
+			);
+		};
+
+		let mut cmd = Command::new(&script.filename);
+		cmd.env("CONTENT_TYPE", content_type)
+			.env("GATEWAY_INTERFACE", "CGI/1.1")
+			.env("PATH_INFO", &path)
+			.env("QUERY_STRING", query)
+			.env("REMOTE_ADDR", client_addr.to_string())
+			.env("REQUEST_METHOD", method)
+			.env("SCRIPT_NAME", script.filename)
+			.env("SERVER_NAME", server_name)
+			.env("SERVER_PORT", settings.port.to_string())
+			.env("SERVER_PROTOCOL", format!("{:?}", version))
+			.env("SERVER_SOFTWARE", Self::SERVER_SOFTWARE);
+
+		if content_length > 0 {
+			cmd.env("CONTENT_LENGTH", content_length.to_string());
+		}
+
+		// Set env associated with the HTTP request headers
+		Self::set_http_env(headers, &mut cmd);
+
+		// Set env specified in the conf. Be sure we do this after we
+		// set the HTTP headers as to overwrite any we might want
+		for (key, value) in &script.env {
+			cmd.env(key.to_ascii_uppercase(), value);
+		}
+
+		let debugcgi = script.name == "git-backend";
+
+		// Only hand the script a body when there is something to read.
+		let cgibody = (content_length > 0).then_some(&body);
+
+		let start_cgi = Instant::now();
+		let cgi_response =
+			Self::call_and_parse_cgi(cmd, cgibody, caddr.ip(), debugcgi, &path).await;
+		let cgi_time = start_cgi.elapsed();
+
+		// The status is chosen by the script, so it may be out of range.
+		let Ok(status) = StatusCode::from_u16(cgi_response.status) else {
+			eprintln!(
+				"script {} returned invalid status {} for {path}",
+				&script.name, cgi_response.status
+			);
+			return Self::error_response(
+				StatusCode::BAD_GATEWAY,
+				"CGI script returned an invalid status\n",
+			);
+		};
+		let mut response = Response::builder().status(status);
+
+		for (key, value) in cgi_response.headers {
+			response = response.header(key, value);
+		}
+
+		println!(
+			"served to [{client_addr}]\n\tscript: {}\n\tpath: {path}\n\tcgi took {}ms. total time {}ms\n\tUA: {uagent}",
+			&script.name,
+			cgi_time.as_millis(),
+			start.elapsed().as_millis()
+		);
+
+		let response_body = cgi_response.body.map(Bytes::from).unwrap_or_default();
+		// The builder reports invalid header names/values only here.
+		match response.body(Full::new(response_body)) {
+			Ok(response) => response,
+			Err(err) => {
+				eprintln!("script {} returned an unusable header: {err}", &script.name);
+				Self::error_response(
+					StatusCode::BAD_GATEWAY,
+					"CGI script returned an invalid header\n",
+				)
+			}
+		}
+	}
+
+	/// Builds a plain-text response with the given status and body.
+	fn error_response(status: StatusCode, message: &'static str) -> Response<Full<Bytes>> {
+		let mut response = Response::new(Full::new(Bytes::from_static(message.as_bytes())));
+		*response.status_mut() = status;
+		response
+			.headers_mut()
+			.insert(header::CONTENT_TYPE, HeaderValue::from_static("text/plain"));
+		response
+	}
+
+	/// Interprets a header value as a single IP address.
+	///
+	/// Returns `None` when the header is absent, cannot be read as a string,
+	/// or does not parse as an [`IpAddr`].
+	fn parse_addr_from_header(maybe_hval: Option<&HeaderValue>) -> Option<IpAddr> {
+		let text = maybe_hval?.to_str().ok()?;
+		text.parse().ok()
+	}
+
+	/// Value passed to scripts as `SERVER_SOFTWARE`: `<package name>/<package
+	/// version>`, filled in from Cargo's environment at compile time.
+	const SERVER_SOFTWARE: &'static str =
+		concat!(env!("CARGO_PKG_NAME"), '/', env!("CARGO_PKG_VERSION"));
+
+	/// Exposes the request headers to the script as `HTTP_*` variables.
+	///
+	/// The variable name is the header name upper-cased, with `-` replaced by
+	/// `_` and `HTTP_` prepended. Repeated header fields are merged into one
+	/// comma-separated value, so that no value is lost. Values that are not
+	/// valid strings are skipped with a message on stderr.
+	///
+	/// The `Proxy` header is never forwarded: it would become `HTTP_PROXY`,
+	/// which many HTTP clients treat as their outgoing proxy setting
+	/// ("httpoxy").
+	fn set_http_env(headers: HeaderMap, cmd: &mut Command) {
+		for key in headers.keys() {
+			let key_str = key.as_str();
+
+			if key_str.eq_ignore_ascii_case("proxy") {
+				eprintln!("ignoring Proxy request header");
+				continue;
+			}
+
+			let mut key_upper = String::with_capacity(key_str.len() + 5);
+			key_upper.push_str("HTTP_");
+			key_upper.extend(key_str.chars().map(|ch| match ch {
+				'-' => '_',
+				ch => ch.to_ascii_uppercase(),
+			}));
+
+			// Merge every occurrence of this header into a single value.
+			let mut merged: Option<String> = None;
+			for value in headers.get_all(key).iter() {
+				match value.to_str() {
+					Ok(val_str) => match merged.as_mut() {
+						Some(so_far) => {
+							so_far.push_str(", ");
+							so_far.push_str(val_str);
+						}
+						None => merged = Some(val_str.to_owned()),
+					},
+					Err(err) => {
+						eprintln!("value for header {key_str} is not a string: {err}")
+					}
+				}
+			}
+
+			if let Some(value) = merged {
+				cmd.env(key_upper, value);
+			}
+		}
+	}
+
+	/// Runs `cmd` as a CGI script and parses what it prints on stdout.
+	///
+	/// * `body` - request body fed to the script's stdin, if any.
+	/// * `caddr`, `path` - only used for log lines and the debug dump name.
+	/// * `debug` - when set, the raw script output is also written to
+	///   `/tmp/<caddr>-gitbackend-<path>`.
+	///
+	/// Script misbehaviour does not panic. If the script cannot be started
+	/// or awaited the result is a 500 response, and if its output is not a
+	/// valid CGI response the result is a 502 response.
+	async fn call_and_parse_cgi(
+		mut cmd: Command,
+		body: Option<&Bytes>,
+		caddr: IpAddr,
+		debug: bool,
+		path: &str,
+	) -> CgiResponse {
+		println!("!!! before spawn: {path}");
+		cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
+		if let Some(bytes) = body {
+			println!("!!! has body len={}", bytes.len());
+			cmd.stdin(Stdio::piped());
+		}
+
+		let mut child = match cmd.spawn() {
+			Ok(child) => child,
+			Err(err) => {
+				eprintln!("failed to spawn CGI script for {path}: {err}");
+				return Self::cgi_failure(500, "failed to start CGI script\n");
+			}
+		};
+
+		// Feed stdin from its own task while stdout is drained below. Writing
+		// everything first can deadlock once the script fills its stdout pipe
+		// before it has consumed its input. The body is written straight to
+		// the pipe: a tokio `BufWriter` discards buffered data when dropped
+		// without a flush, which would lose small bodies entirely.
+		let writer = match (body, child.stdin.take()) {
+			(Some(bytes), Some(mut cmd_stdin)) => {
+				let bytes = bytes.clone();
+				let path = path.to_owned();
+				Some(tokio::task::spawn(async move {
+					let written = cmd_stdin.write_all(&bytes).await;
+					// Closing the pipe is what signals EOF to the script.
+					drop(cmd_stdin);
+					println!("!!! after drop ({path})");
+					written
+				}))
+			}
+			_ => None,
+		};
+
+		let output = child.wait_with_output().await;
+
+		if let Some(writer) = writer {
+			match writer.await {
+				Ok(Ok(())) => (),
+				// A script that exits without reading its input is not fatal.
+				Ok(Err(err)) => {
+					eprintln!("failed to send request body to CGI script for {path}: {err}")
+				}
+				Err(err) => eprintln!("CGI stdin writer for {path} did not finish: {err}"),
+			}
+		}
+
+		let output = match output {
+			Ok(output) => output,
+			Err(err) => {
+				eprintln!("failed to wait for CGI script for {path}: {err}");
+				return Self::cgi_failure(500, "failed to run CGI script\n");
+			}
+		};
+		println!("!!! after spawn ({path})");
+
+		let response_raw = output.stdout;
+
+		if debug {
+			let dump = format!("/tmp/{caddr}-gitbackend-{}", path_to_name(path));
+			if let Err(err) = std::fs::write(&dump, &response_raw) {
+				eprintln!("failed to write debug dump {dump}: {err}");
+			}
+		}
+
+		match Self::parse_cgi_output(&response_raw) {
+			Ok(response) => {
+				println!("!!! before call_and_parse return ({path})");
+				response
+			}
+			Err(reason) => {
+				eprintln!("malformed CGI response for {path}: {reason}");
+				// The script's own diagnostics usually explain the failure.
+				let stderr = String::from_utf8_lossy(&output.stderr);
+				if !stderr.trim().is_empty() {
+					eprintln!("CGI stderr for {path}:\n{stderr}");
+				}
+				Self::cgi_failure(502, "CGI script returned a malformed response\n")
+			}
+		}
+	}
+
+	/// Splits raw CGI output into status, header fields and body.
+	///
+	/// Lines may end in `\n` or `\r\n`; the header section ends at the first
+	/// empty line. A `Status` header (any case) sets the status code and is
+	/// not kept as a header field. Returns a description of the problem when
+	/// the output has no header, no terminating blank line, a header line
+	/// without a colon, or a `Status` value that does not start with a number.
+	fn parse_cgi_output(raw: &[u8]) -> Result<CgiResponse, &'static str> {
+		let mut response = CgiResponse {
+			// Default status code is 200 per RFC
+			status: 200,
+			headers: vec![],
+			body: None,
+		};
+
+		let mut seen_header = false;
+		let mut rest = raw;
+		loop {
+			// Find the newline to know where this header ends
+			let Some(nl) = rest.iter().position(|b| *b == b'\n') else {
+				return Err("header section is not terminated by an empty line");
+			};
+			let line = &rest[..nl];
+			let line = line.strip_suffix(b"\r").unwrap_or(line);
+			// Move past the newline
+			rest = &rest[nl + 1..];
+
+			// An empty line separates the headers from the body.
+			if line.is_empty() {
+				if !seen_header {
+					return Err("response has no header fields");
+				}
+				if !rest.is_empty() {
+					response.body = Some(rest.to_vec());
+				}
+				return Ok(response);
+			}
+
+			// Find the colon to separate the key from the value
+			let Some(colon) = line.iter().position(|b| *b == b':') else {
+				return Err("header line has no colon");
+			};
+			let key = &line[..colon];
+			let value = line[colon + 1..].trim_ascii();
+			seen_header = true;
+
+			// Is this header a status line? `Status: <code> [reason]`
+			if key.eq_ignore_ascii_case(b"Status") {
+				let code = std::str::from_utf8(value)
+					.ok()
+					.and_then(|text| text.split_whitespace().next())
+					.and_then(|raw_code| raw_code.parse::<u16>().ok());
+				let Some(code) = code else {
+					return Err("Status header does not start with a numeric code");
+				};
+				response.status = code;
+				continue;
+			}
+
+			response.headers.push((key.to_vec(), value.to_vec()));
+		}
+	}
+
+	/// Builds the plain-text response reported when the script itself failed.
+	fn cgi_failure(status: u16, message: &str) -> CgiResponse {
+		CgiResponse {
+			status,
+			headers: vec![(b"Content-Type".to_vec(), b"text/plain".to_vec())],
+			body: Some(message.as_bytes().to_vec()),
+		}
+	}
+}
+
+/// A CGI script's response as produced by `Svc::call_and_parse_cgi`.
+struct CgiResponse {
+	/// HTTP status code; starts at 200 and is replaced by the code from a
+	/// `Status` header when the script sends one
+	status: u16,
+	/// Header fields of the CGI response as raw (name, value) bytes
+	headers: Vec<(Vec<u8>, Vec<u8>)>,
+	/// CGI response body; `None` when nothing follows the header section
+	body: Option<Vec<u8>>,
+}
+
+/// Flattens a request path into a single file-name component by replacing
+/// every `/` with `-`; all other characters are kept as they are.
+fn path_to_name(path: &str) -> String {
+	path.replace('/', "-")
+}