about summary refs log tree commit diff
path: root/corgi/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'corgi/src/main.rs')
-rw-r--r--corgi/src/main.rs177
1 files changed, 19 insertions, 158 deletions
diff --git a/corgi/src/main.rs b/corgi/src/main.rs
index 7f94ba2..1772d68 100644
--- a/corgi/src/main.rs
+++ b/corgi/src/main.rs
@@ -1,109 +1,34 @@
 use core::fmt;
-use std::{
-	net::{IpAddr, SocketAddr},
-	path::PathBuf,
-	pin::Pin,
-	sync::Arc,
-};
+use std::{net::SocketAddr, pin::Pin, sync::Arc};
 
 use caller::HttpRequest;
-use confindent::{Confindent, Value, ValueParseError};
 use http_body_util::{BodyExt, Full};
 use hyper::{
 	HeaderMap, Request, Response, StatusCode,
 	body::{Bytes, Incoming},
-	header::HeaderValue,
 	server::conn::http1,
 	service::Service,
 };
 use hyper_util::rt::TokioIo;
-use regex_lite::Regex;
+use settings::{Script, Settings};
 use stats::Stats;
 use tokio::{net::TcpListener, runtime::Runtime};
+use util::owned_header;
 
 mod caller;
+mod settings;
 mod stats;
-
-#[derive(Clone, Debug)]
-pub struct Settings {
-	port: u16,
-	scripts: Vec<Script>,
-}
-
-#[derive(Clone, Debug)]
-pub struct Script {
-	name: String,
-	regex: Option<Regex>,
-	filename: String,
-	env: Vec<(String, String)>,
-}
-
-const CONF_DEFAULT: &str = "/etc/corgi.conf";
+mod util;
 
 fn main() {
-	let conf_path = std::env::args()
-		.nth(1)
-		.unwrap_or(String::from(CONF_DEFAULT));
-	let conf = Confindent::from_file(conf_path).expect("failed to open conf");
-
-	let mut settings = Settings {
-		port: 26744,
-		scripts: conf
-			.children("Script")
-			.into_iter()
-			.map(parse_script_conf)
-			.collect(),
-	};
-
-	if let Some(server) = conf.child("Server") {
-		match server.child_parse("Port") {
-			Err(ValueParseError::NoValue) => (),
-			Err(err) => {
-				eprintln!("Server.Port is malformed: {err}");
-				std::process::exit(1);
-			}
-			Ok(port) => settings.port = port,
-		}
-	}
-
-	let stats = Stats::new(PathBuf::from(
-		conf.get("Server/StatsDb").unwrap().to_owned(),
-	));
+	let settings = Settings::get();
+	let stats = Stats::new(&settings.stats_path);
 	stats.create_tables();
 
 	let rt = Runtime::new().unwrap();
 	rt.block_on(async { run(settings, stats).await });
 }
 
-fn parse_script_conf(conf: &Value) -> Script {
-	let name = conf.value_owned().expect("Missing value for 'Script' key");
-	let filename = conf.child_owned("Path").expect("Missing 'Path' key");
-	let environment = conf.child("Environment");
-	let env = environment.map(|e| {
-		e.values()
-			.map(|v| (v.key_owned(), v.value_owned().unwrap()))
-			.collect()
-	});
-
-	let regex = match conf.get("Match/Regex") {
-		None => None,
-		Some(restr) => match Regex::new(restr) {
-			Err(err) => {
-				eprintln!("Failed to compile regex: {restr}\nerror: {err}");
-				std::process::exit(1);
-			}
-			Ok(re) => Some(re),
-		},
-	};
-
-	Script {
-		name,
-		regex,
-		filename,
-		env: env.unwrap_or_default(),
-	}
-}
-
 // We have tokio::main at home :)
 async fn run(settings: Settings, stats: Stats) {
 	let addr = SocketAddr::from(([0, 0, 0, 0], settings.port));
@@ -169,44 +94,33 @@ impl Svc {
 		// Collect things we need from the request before we eat it's body
 		let method = req.method().as_str().to_ascii_uppercase();
 		let version = req.version();
-		let path = url_decode(req.uri().path(), false).ok_or(RuntimeError::MalformedRequest)?;
-
+		let path = util::url_decode(req.uri().path(), false)?;
 		let query = req
 			.uri()
 			.query()
-			.map(|s| url_decode(s, false).ok_or(RuntimeError::MalformedRequest))
+			.map(|s| util::url_decode(s, false))
 			.transpose()?
 			.unwrap_or_default();
 
-		let headers = req.headers().clone();
-
-		let body = req.into_body().collect().await.unwrap().to_bytes().to_vec();
-		let content_length = body.len();
-
 		let script = Self::select_script(&settings, &path).ok_or(RuntimeError::NoScript)?;
 
-		let content_type = headers
-			.get("content-type")
-			.map(|s| s.to_str().ok())
-			.flatten()
-			.unwrap_or_default()
-			.to_owned();
-
-		let uagent = headers
-			.get("user-agent")
-			.map(|s| s.to_str().ok())
-			.flatten()
-			.unwrap_or_default()
-			.to_owned();
+		// Clone the headers and extract what we need
+		let headers = req.headers().clone();
+		let content_type = owned_header(headers.get("content-type")).unwrap_or_default();
+		let uagent = owned_header(headers.get("user-agent")).unwrap_or_default();
 
 		// Find the client address
 		let client_addr = {
-			let x_forward = Self::parse_addr_from_header(headers.get("x-forwarded-for"));
-			let forward = Self::parse_addr_from_header(headers.get("forwarded-for"));
+			let x_forward = util::parse_from_header(headers.get("x-forwarded-for"));
+			let forward = util::parse_from_header(headers.get("forwarded-for"));
 
 			forward.unwrap_or(x_forward.unwrap_or(caddr.ip()))
 		};
 
+		// Finally, get the body which consumes the request
+		let body = req.into_body().collect().await.unwrap().to_bytes().to_vec();
+		let content_length = body.len();
+
 		let server_name = headers
 			.get("Host")
 			.expect("no http host header set")
@@ -258,14 +172,6 @@ impl Svc {
 		Ok(response.body(Full::new(response_body)).unwrap())
 	}
 
-	fn parse_addr_from_header(maybe_hval: Option<&HeaderValue>) -> Option<IpAddr> {
-		maybe_hval
-			.map(|h| h.to_str().ok())
-			.flatten()
-			.map(|s| s.parse().ok())
-			.flatten()
-	}
-
 	fn select_script<'s>(settings: &'s Settings, path: &str) -> Option<&'s Script> {
 		for script in &settings.scripts {
 			if let Some(regex) = script.regex.as_ref() {
@@ -332,51 +238,6 @@ fn status_page<D: fmt::Display>(status: u16, msg: D) -> Response<Full<Bytes>> {
 		.unwrap()
 }
 
-// Ripped and modified from gennyble/mavourings query.rs
-fn url_decode<S: AsRef<str>>(urlencoded: S, plus_as_space: bool) -> Option<String> {
-	let urlencoded = urlencoded.as_ref();
-	let mut uncoded: Vec<u8> = Vec::with_capacity(urlencoded.len());
-
-	let mut chars = urlencoded.chars().peekable();
-	loop {
-		let mut utf8_bytes = [0; 4];
-		match chars.next() {
-			Some('+') => match plus_as_space {
-				true => uncoded.push(b' '),
-				false => uncoded.push(b'+'),
-			},
-			Some('%') => match chars.peek() {
-				Some(c) if c.is_ascii_hexdigit() => {
-					let upper = chars.next().unwrap();
-
-					if let Some(lower) = chars.peek() {
-						if lower.is_ascii_hexdigit() {
-							let upper = upper.to_digit(16).unwrap();
-							let lower = chars.next().unwrap().to_digit(16).unwrap();
-
-							uncoded.push(upper as u8 * 16 + lower as u8);
-							continue;
-						}
-					}
-
-					uncoded.push(b'%');
-					uncoded.extend_from_slice(upper.encode_utf8(&mut utf8_bytes).as_bytes());
-				}
-				_ => {
-					uncoded.push(b'%');
-				}
-			},
-			Some(c) => {
-				uncoded.extend_from_slice(c.encode_utf8(&mut utf8_bytes).as_bytes());
-			}
-			None => {
-				uncoded.shrink_to_fit();
-				return String::from_utf8(uncoded).ok();
-			}
-		}
-	}
-}
-
 enum RuntimeError {
 	MalformedRequest,
 	NoScript,