diff options
Diffstat (limited to 'corgi/src/main.rs')
-rw-r--r-- | corgi/src/main.rs | 396 |
1 files changed, 396 insertions, 0 deletions
diff --git a/corgi/src/main.rs b/corgi/src/main.rs new file mode 100644 index 0000000..9831331 --- /dev/null +++ b/corgi/src/main.rs @@ -0,0 +1,396 @@ +use std::{ + net::{IpAddr, SocketAddr}, + pin::Pin, + process::Stdio, + time::Instant, +}; + +use confindent::{Confindent, Value, ValueParseError}; +use http_body_util::{BodyExt, Full}; +use hyper::{ + HeaderMap, Request, Response, StatusCode, + body::{Bytes, Incoming}, + header::{self, HeaderValue}, + server::conn::http1, + service::Service, +}; +use hyper_util::rt::TokioIo; +use regex_lite::Regex; +use tokio::{ + io::{AsyncWriteExt, BufWriter}, + net::TcpListener, + process::Command, + runtime::Runtime, +}; + +#[derive(Clone, Debug)] +pub struct Settings { + port: u16, + scripts: Vec<Script>, +} + +#[derive(Clone, Debug)] +pub struct Script { + name: String, + regex: Option<Regex>, + filename: String, + env: Vec<(String, String)>, +} + +const CONF_DEFAULT: &str = "/etc/corgi.conf"; + +fn main() { + let conf_path = std::env::args().nth(1).unwrap_or(String::from(CONF_DEFAULT)); + let conf = Confindent::from_file(conf_path).expect("failed to open conf"); + + let mut settings = Settings { + port: 26744, + scripts: conf.children("Script").into_iter().map(parse_script_conf).collect(), + }; + + if let Some(server) = conf.child("Server") { + match server.child_parse("Port") { + Err(ValueParseError::NoValue) => (), + Err(err) => { + eprintln!("Server.Port is malformed: {err}"); + std::process::exit(1); + } + Ok(port) => settings.port = port, + } + } + + let rt = Runtime::new().unwrap(); + rt.block_on(async { run(settings).await }); +} + +fn parse_script_conf(conf: &Value) -> Script { + let name = conf.value_owned().expect("Missing value for 'Script' key"); + let filename = conf.child_owned("Path").expect("Missing 'Path' key"); + let environment = conf.child("Environment"); + let env = environment + .map(|e| e.values().map(|v| (v.key_owned(), v.value_owned().unwrap())).collect()); + + let regex = match conf.get("Match/Regex") { + None => None, + Some(restr) => match Regex::new(restr) { + Err(err) => { + eprintln!("Failed to compile regex: {restr}\nerror: {err}"); + std::process::exit(1); + } + Ok(re) => Some(re), + }, + }; + + Script { + name, + regex, + filename, + env: env.unwrap_or_default(), + } +} + +// We have tokio::main at home :) +async fn run(settings: Settings) { + let addr = SocketAddr::from(([0, 0, 0, 0], settings.port)); + let listen = TcpListener::bind(addr).await.unwrap(); + + let svc = Svc { + settings, + client_addr: addr, + }; + + loop { + let (stream, caddr) = listen.accept().await.unwrap(); + let io = TokioIo::new(stream); + let mut svc_clone = svc.clone(); + svc_clone.client_addr = caddr; + tokio::task::spawn( + async move { http1::Builder::new().serve_connection(io, svc_clone).await }, + ); + } +} + +#[derive(Clone, Debug)] +struct Svc { + settings: Settings, + client_addr: SocketAddr, +} + +impl Service<Request<Incoming>> for Svc { + type Response = Response<Full<Bytes>>; + type Error = hyper::Error; + type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; + + fn call(&self, req: Request<Incoming>) -> Self::Future { + let settings = self.settings.clone(); + let caddr = self.client_addr; + + Box::pin(async move { Ok(Self::handle(settings, caddr, req).await) }) + } +} + +impl Svc { + async fn handle( + settings: Settings, + caddr: SocketAddr, + req: Request<Incoming>, + ) -> Response<Full<Bytes>> { + let start = Instant::now(); + + // Collect things we need from the request before we eat it's body + let method = req.method().as_str().to_ascii_uppercase(); + let version = req.version(); + let path = req.uri().path().to_owned(); + let query = req.uri().query().unwrap_or_default().to_owned(); + let headers = req.headers().clone(); + let body = req.into_body().collect().await.unwrap().to_bytes(); + let content_length = body.len(); + + let mut script = settings.scripts[0].clone(); + + for set_script in settings.scripts { + if let Some(regex) = set_script.regex.as_ref() { + if regex.is_match(&path) { + script = set_script; + break; + } + } else { + script = set_script; + } + } + + let content_type = headers + .get("content-type") + .map(|s| s.to_str().ok()) + .flatten() + .unwrap_or_default() + .to_owned(); + + println!("!!! new request. type {content_type} // {method}"); + println!("!!! {path}?{query}"); + + let uagent = headers + .get("user-agent") + .map(|s| s.to_str().ok()) + .flatten() + .unwrap_or_default() + .to_owned(); + + // Find the client address + let client_addr = { + let x_forward = Self::parse_addr_from_header(headers.get("x-forwarded-for")); + let forward = Self::parse_addr_from_header(headers.get("forwarded-for")); + + forward.unwrap_or(x_forward.unwrap_or(caddr.ip())) + }; + + let server_name = headers + .get("Host") + .expect("no http host header set") + .to_str() + .expect("failed to decode http host as string"); + + let mut cmd = Command::new(&script.filename); + cmd.env("CONTENT_TYPE", content_type) + .env("GATEWAY_INTERFACE", "CGI/1.1") + .env("PATH_INFO", &path) + .env("QUERY_STRING", query) + .env("REMOTE_ADDR", client_addr.to_string()) + .env("REQUEST_METHOD", method) + .env("SCRIPT_NAME", script.filename) + .env("SERVER_NAME", server_name) + .env("SERVER_PORT", settings.port.to_string()) + .env("SERVER_PROTOCOL", format!("{:?}", version)) + .env("SERVER_SOFTWARE", Self::SERVER_SOFTWARE); + + if content_length > 0 { + cmd.env("CONTENT_LENGTH", content_length.to_string()); + } + + // Set env associated with the HTTP request headers + Self::set_http_env(headers, &mut cmd); + + // Set env specified in the conf. Be sure we do this after we + // set the HTTP headers as to overwrite any we might want + for (key, value) in &script.env { + cmd.env(key.to_ascii_uppercase(), value); + } + + let debugcgi = script.name == "git-backend"; + + let cgibody = if content_length > 0 { + Some(&body) + } else { + None + }; + + let start_cgi = Instant::now(); + let cgi_response = + Self::call_and_parse_cgi(cmd, cgibody, caddr.ip(), debugcgi, &path).await; + let cgi_time = start_cgi.elapsed(); + + let status = StatusCode::from_u16(cgi_response.status).unwrap(); + let mut response = Response::builder().status(status); + + for (key, value) in cgi_response.headers { + response = response.header(key, value); + } + + println!( + "served to [{client_addr}]\n\tscript: {}\n\tpath: {path}\n\tcgi took {}ms. total time {}ms\n\tUA: {uagent}", + &script.name, + cgi_time.as_millis(), + start.elapsed().as_millis() + ); + + let response_body = cgi_response.body.map(|v| Bytes::from(v)).unwrap_or(Bytes::new()); + response.body(Full::new(response_body)).unwrap() + } + + fn parse_addr_from_header(maybe_hval: Option<&HeaderValue>) -> Option<IpAddr> { + maybe_hval.map(|h| h.to_str().ok()).flatten().map(|s| s.parse().ok()).flatten() + } + + const SERVER_SOFTWARE: &'static str = + concat!(env!("CARGO_PKG_NAME"), '/', env!("CARGO_PKG_VERSION")); + + fn set_http_env(headers: HeaderMap, cmd: &mut Command) { + for (key, value) in headers.iter() { + let key_str = key.as_str(); + + let mut key_upper = String::with_capacity(key_str.len() + 5); + key_upper.push_str("HTTP_"); + + for ch in key_str.chars() { + match ch { + _ if ch as u8 > 0x60 && ch as u8 <= 0x7A => { + key_upper.push((ch as u8 - 0x20) as char); + } + '-' => key_upper.push('_'), + ch => key_upper.push(ch), + } + } + + match value.to_str() { + Ok(val_str) => { + cmd.env(key_upper, val_str); + } + Err(err) => { + eprintln!("value for header {key_str} is not a string: {err}") + } + } + } + } + + async fn call_and_parse_cgi( + mut cmd: Command, + body: Option<&Bytes>, + caddr: IpAddr, + debug: bool, + path: &str, + ) -> CgiResponse { + let mut response = CgiResponse { + // Default status code is 200 per RFC + status: 200, + headers: vec![], + body: None, + }; + + println!("!!! before spawn: {path}"); + let cmd = cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); + let output = if let Some(bytes) = body { + println!("!!! has body len={}", bytes.len()); + let mut child = cmd.stdin(Stdio::piped()).spawn().unwrap(); + + let cmd_stdin = child.stdin.take().unwrap(); + let mut bufwrite = BufWriter::new(cmd_stdin); + bufwrite.write_all(bytes).await.unwrap(); + + drop(bufwrite); + println!("!!! after drop ({path})"); + + child.wait_with_output().await.unwrap() + } else { + cmd.spawn().unwrap().wait_with_output().await.unwrap() + }; + println!("!!! after spawn ({path})"); + + let response_raw = output.stdout; + + if debug { + std::fs::write( + format!("/tmp/{caddr}-gitbackend-{}", path_to_name(path)), + &response_raw, + ) + .unwrap(); + } + + let mut curr = response_raw.as_slice(); + loop { + // Find the newline to know where this header ends + let nl = curr.iter().position(|b| *b == b'\n').expect("no nl in header"); + let line = &curr[..nl]; + + // Find the colon to separate the key from the value + let colon = line.iter().position(|b| *b == b':').expect("no colon in header"); + let key = &line[..colon]; + let mut value = &line[colon + 1..]; + + if value[0] == b' ' { + value = &value[1..]; + } + if value[value.len().saturating_sub(1)] == b'\r' { + value = &value[..value.len().saturating_sub(1)]; + } + + response.headers.push((key.to_vec(), value.to_vec())); + + // Is this header a status line? + let key_string = String::from_utf8_lossy(key); + if key_string == "Status" { + let value_string = String::from_utf8_lossy(value); + if let Some((raw_code, _raw_msg)) = value_string.trim().split_once(' ') { + let code: u16 = raw_code.parse().unwrap(); + response.status = code; + } + } + + // Body next? + let next_nl = curr[nl + 1] == b'\n'; + let next_crlf = curr[nl + 1] == b'\r' && curr[nl + 2] == b'\n'; + if next_nl || next_crlf { + let offset = if next_nl { 2 } else { 3 }; + let body = &curr[nl + offset..]; + if body.len() > 0 { + response.body = Some(body.to_vec()); + } + + println!("!!! before call_and_parse return ({path})"); + return response; + } + + // Move past the newline + curr = &curr[nl + 1..]; + } + } +} + +struct CgiResponse { + /// The Status header of the CGI response + status: u16, + /// Headers except "Status" + headers: Vec<(Vec<u8>, Vec<u8>)>, + /// CGI response body + body: Option<Vec<u8>>, +} + +fn path_to_name(path: &str) -> String { + let mut ret = String::with_capacity(path.len()); + for ch in path.chars() { + match ch { + '/' => ret.push('-'), + ch => ret.push(ch), + } + } + ret +} |