diff options
-rw-r--r-- | .rustfmt.toml | 1 | ||||
-rw-r--r-- | corgi/src/main.rs | 125 |
2 files changed, 95 insertions, 31 deletions
diff --git a/.rustfmt.toml b/.rustfmt.toml index 4639247..218e203 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,2 +1 @@ hard_tabs = true -chain_width = 100 diff --git a/corgi/src/main.rs b/corgi/src/main.rs index ba7195b..7f94ba2 100644 --- a/corgi/src/main.rs +++ b/corgi/src/main.rs @@ -1,3 +1,4 @@ +use core::fmt; use std::{ net::{IpAddr, SocketAddr}, path::PathBuf, @@ -40,12 +41,18 @@ pub struct Script { const CONF_DEFAULT: &str = "/etc/corgi.conf"; fn main() { - let conf_path = std::env::args().nth(1).unwrap_or(String::from(CONF_DEFAULT)); + let conf_path = std::env::args() + .nth(1) + .unwrap_or(String::from(CONF_DEFAULT)); let conf = Confindent::from_file(conf_path).expect("failed to open conf"); let mut settings = Settings { port: 26744, - scripts: conf.children("Script").into_iter().map(parse_script_conf).collect(), + scripts: conf + .children("Script") + .into_iter() + .map(parse_script_conf) + .collect(), }; if let Some(server) = conf.child("Server") { @@ -72,8 +79,11 @@ fn parse_script_conf(conf: &Value) -> Script { let name = conf.value_owned().expect("Missing value for 'Script' key"); let filename = conf.child_owned("Path").expect("Missing 'Path' key"); let environment = conf.child("Environment"); - let env = environment - .map(|e| e.values().map(|v| (v.key_owned(), v.value_owned().unwrap())).collect()); + let env = environment.map(|e| { + e.values() + .map(|v| (v.key_owned(), v.value_owned().unwrap())) + .collect() + }); let regex = match conf.get("Match/Regex") { None => None, @@ -144,36 +154,36 @@ impl Svc { caddr: SocketAddr, req: Request<Incoming>, ) -> Response<Full<Bytes>> { + match Self::handle_fallible(settings, stats, caddr, req).await { + Err(re) => re.into_response(), + Ok(response) => response, + } + } + + async fn handle_fallible( + settings: Settings, + stats: Arc<Stats>, + caddr: SocketAddr, + req: Request<Incoming>, + ) -> Result<Response<Full<Bytes>>, RuntimeError> { // Collect things we need from the request before we eat it's body let method = req.method().as_str().to_ascii_uppercase(); let version = req.version(); - let path = url_decode(req.uri().path().to_owned(), false).unwrap(); - let query = url_decode(req.uri().query().unwrap_or_default().to_owned(), false).unwrap(); + let path = url_decode(req.uri().path(), false).ok_or(RuntimeError::MalformedRequest)?; + + let query = req + .uri() + .query() + .map(|s| url_decode(s, false).ok_or(RuntimeError::MalformedRequest)) + .transpose()? + .unwrap_or_default(); + let headers = req.headers().clone(); let body = req.into_body().collect().await.unwrap().to_bytes().to_vec(); let content_length = body.len(); - let mut maybe_script = None; - for set_script in settings.scripts { - if let Some(regex) = set_script.regex.as_ref() { - if regex.is_match(&path) { - maybe_script = Some(set_script); - break; - } - } else { - maybe_script = Some(set_script); - break; - } - } - - let script = match maybe_script { - Some(script) => script, - None => { - eprintln!("path didn't match any script"); - panic!("TODO recover?"); - } - }; + let script = Self::select_script(&settings, &path).ok_or(RuntimeError::NoScript)?; let content_type = headers .get("content-type") @@ -218,8 +228,8 @@ impl Svc { }; let cgi_response = caller::call_and_parse_cgi(script.clone(), http_request).await; - let status = StatusCode::from_u16(cgi_response.status).unwrap(); + let mut response = Response::builder().status(status); for (key, value) in cgi_response.headers { @@ -240,12 +250,34 @@ impl Svc { stats.log_request(db_req); - let response_body = cgi_response.body.map(|v| Bytes::from(v)).unwrap_or(Bytes::new()); - response.body(Full::new(response_body)).unwrap() + let response_body = cgi_response + .body + .map(|v| Bytes::from(v)) + .unwrap_or(Bytes::new()); + + Ok(response.body(Full::new(response_body)).unwrap()) } fn parse_addr_from_header(maybe_hval: Option<&HeaderValue>) -> Option<IpAddr> { - maybe_hval.map(|h| h.to_str().ok()).flatten().map(|s| s.parse().ok()).flatten() + maybe_hval + .map(|h| h.to_str().ok()) + .flatten() + .map(|s| s.parse().ok()) + .flatten() + } + + fn select_script<'s>(settings: &'s Settings, path: &str) -> Option<&'s Script> { + for script in &settings.scripts { + if let Some(regex) = script.regex.as_ref() { + if regex.is_match(path) { + return Some(script); + } + } else { + return Some(script); + } + } + + None } fn build_http_vec(headers: HeaderMap) -> Vec<(String, String)> { @@ -281,6 +313,25 @@ impl Svc { } } +fn status_page<D: fmt::Display>(status: u16, msg: D) -> Response<Full<Bytes>> { + let body_str = format!( + "<html>\n\ + \t<head><title>{status}</title></head>\n\ + \t<body style='width: 20rem; padding: 0px; margin: 2rem;'>\n\ + \t\t<h1>{status}</h1>\n\ + \t\t<hr/>\n\ + \t\t<p>{msg}</p>\n\ + \t</body>\n\ + </html>" + ); + + Response::builder() + .status(status) + .header("Content-Type", "text/html") + .body(Full::new(body_str.into())) + .unwrap() +} + // Ripped and modified from gennyble/mavourings query.rs fn url_decode<S: AsRef<str>>(urlencoded: S, plus_as_space: bool) -> Option<String> { let urlencoded = urlencoded.as_ref(); @@ -325,3 +376,17 @@ fn url_decode<S: AsRef<str>>(urlencoded: S, plus_as_space: bool) -> Option<Strin } } } + +enum RuntimeError { + MalformedRequest, + NoScript, +} + +impl RuntimeError { + pub fn into_response(&self) -> Response<Full<Bytes>> { + match self { + Self::MalformedRequest => status_page(400, "bad request"), + Self::NoScript => status_page(404, "failed to route request"), + } + } +} |