about summary refs log tree commit diff
diff options
context:
space:
mode:
author gennyble <gen@nyble.dev> 2025-04-04 04:38:33 -0500
committer gennyble <gen@nyble.dev> 2025-04-04 04:38:33 -0500
commit a5d0dd1653b564a56a224fe374503cb63c511cc1 (patch)
tree 88c0cd7e4d52cdd0703a436a7948db9e329a5c9e
parent 92ccd386407e095f72523428febccea20abeb1c8 (diff)
download corgi-a5d0dd1653b564a56a224fe374503cb63c511cc1.tar.gz
corgi-a5d0dd1653b564a56a224fe374503cb63c511cc1.zip
refactor
-rw-r--r-- corgi/src/main.rs 177
-rw-r--r-- corgi/src/settings.rs 82
-rw-r--r-- corgi/src/stats.rs 4
-rw-r--r-- corgi/src/util.rs 66
4 files changed, 169 insertions, 160 deletions
diff --git a/corgi/src/main.rs b/corgi/src/main.rs
index 7f94ba2..1772d68 100644
--- a/corgi/src/main.rs
+++ b/corgi/src/main.rs
@@ -1,109 +1,34 @@
 use core::fmt;
-use std::{
-	net::{IpAddr, SocketAddr},
-	path::PathBuf,
-	pin::Pin,
-	sync::Arc,
-};
+use std::{net::SocketAddr, pin::Pin, sync::Arc};
 
 use caller::HttpRequest;
-use confindent::{Confindent, Value, ValueParseError};
 use http_body_util::{BodyExt, Full};
 use hyper::{
 	HeaderMap, Request, Response, StatusCode,
 	body::{Bytes, Incoming},
-	header::HeaderValue,
 	server::conn::http1,
 	service::Service,
 };
 use hyper_util::rt::TokioIo;
-use regex_lite::Regex;
+use settings::{Script, Settings};
 use stats::Stats;
 use tokio::{net::TcpListener, runtime::Runtime};
+use util::owned_header;
 
 mod caller;
+mod settings;
 mod stats;
-
-#[derive(Clone, Debug)]
-pub struct Settings {
-	port: u16,
-	scripts: Vec<Script>,
-}
-
-#[derive(Clone, Debug)]
-pub struct Script {
-	name: String,
-	regex: Option<Regex>,
-	filename: String,
-	env: Vec<(String, String)>,
-}
-
-const CONF_DEFAULT: &str = "/etc/corgi.conf";
+mod util;
 
 fn main() {
-	let conf_path = std::env::args()
-		.nth(1)
-		.unwrap_or(String::from(CONF_DEFAULT));
-	let conf = Confindent::from_file(conf_path).expect("failed to open conf");
-
-	let mut settings = Settings {
-		port: 26744,
-		scripts: conf
-			.children("Script")
-			.into_iter()
-			.map(parse_script_conf)
-			.collect(),
-	};
-
-	if let Some(server) = conf.child("Server") {
-		match server.child_parse("Port") {
-			Err(ValueParseError::NoValue) => (),
-			Err(err) => {
-				eprintln!("Server.Port is malformed: {err}");
-				std::process::exit(1);
-			}
-			Ok(port) => settings.port = port,
-		}
-	}
-
-	let stats = Stats::new(PathBuf::from(
-		conf.get("Server/StatsDb").unwrap().to_owned(),
-	));
+	let settings = Settings::get();
+	let stats = Stats::new(&settings.stats_path);
 	stats.create_tables();
 
 	let rt = Runtime::new().unwrap();
 	rt.block_on(async { run(settings, stats).await });
 }
 
-fn parse_script_conf(conf: &Value) -> Script {
-	let name = conf.value_owned().expect("Missing value for 'Script' key");
-	let filename = conf.child_owned("Path").expect("Missing 'Path' key");
-	let environment = conf.child("Environment");
-	let env = environment.map(|e| {
-		e.values()
-			.map(|v| (v.key_owned(), v.value_owned().unwrap()))
-			.collect()
-	});
-
-	let regex = match conf.get("Match/Regex") {
-		None => None,
-		Some(restr) => match Regex::new(restr) {
-			Err(err) => {
-				eprintln!("Failed to compile regex: {restr}\nerror: {err}");
-				std::process::exit(1);
-			}
-			Ok(re) => Some(re),
-		},
-	};
-
-	Script {
-		name,
-		regex,
-		filename,
-		env: env.unwrap_or_default(),
-	}
-}
-
 // We have tokio::main at home :)
 async fn run(settings: Settings, stats: Stats) {
 	let addr = SocketAddr::from(([0, 0, 0, 0], settings.port));
@@ -169,44 +94,33 @@ impl Svc {
 		// Collect things we need from the request before we eat it's body
 		let method = req.method().as_str().to_ascii_uppercase();
 		let version = req.version();
-		let path = url_decode(req.uri().path(), false).ok_or(RuntimeError::MalformedRequest)?;
-
+		let path = util::url_decode(req.uri().path(), false)?;
 		let query = req
 			.uri()
 			.query()
-			.map(|s| url_decode(s, false).ok_or(RuntimeError::MalformedRequest))
+			.map(|s| util::url_decode(s, false))
 			.transpose()?
 			.unwrap_or_default();
 
-		let headers = req.headers().clone();
-
-		let body = req.into_body().collect().await.unwrap().to_bytes().to_vec();
-		let content_length = body.len();
-
 		let script = Self::select_script(&settings, &path).ok_or(RuntimeError::NoScript)?;
 
-		let content_type = headers
-			.get("content-type")
-			.map(|s| s.to_str().ok())
-			.flatten()
-			.unwrap_or_default()
-			.to_owned();
-
-		let uagent = headers
-			.get("user-agent")
-			.map(|s| s.to_str().ok())
-			.flatten()
-			.unwrap_or_default()
-			.to_owned();
+		// Clone the headers and extract what we need
+		let headers = req.headers().clone();
+		let content_type = owned_header(headers.get("content-type")).unwrap_or_default();
+		let uagent = owned_header(headers.get("user-agent")).unwrap_or_default();
 
 		// Find the client address
 		let client_addr = {
-			let x_forward = Self::parse_addr_from_header(headers.get("x-forwarded-for"));
-			let forward = Self::parse_addr_from_header(headers.get("forwarded-for"));
+			let x_forward = util::parse_from_header(headers.get("x-forwarded-for"));
+			let forward = util::parse_from_header(headers.get("forwarded-for"));
 
 			forward.unwrap_or(x_forward.unwrap_or(caddr.ip()))
 		};
 
+		// Finally, get the body which consumes the request
+		let body = req.into_body().collect().await.unwrap().to_bytes().to_vec();
+		let content_length = body.len();
+
 		let server_name = headers
 			.get("Host")
 			.expect("no http host header set")
@@ -258,14 +172,6 @@ impl Svc {
 		Ok(response.body(Full::new(response_body)).unwrap())
 	}
 
-	fn parse_addr_from_header(maybe_hval: Option<&HeaderValue>) -> Option<IpAddr> {
-		maybe_hval
-			.map(|h| h.to_str().ok())
-			.flatten()
-			.map(|s| s.parse().ok())
-			.flatten()
-	}
-
 	fn select_script<'s>(settings: &'s Settings, path: &str) -> Option<&'s Script> {
 		for script in &settings.scripts {
 			if let Some(regex) = script.regex.as_ref() {
@@ -332,51 +238,6 @@ fn status_page<D: fmt::Display>(status: u16, msg: D) -> Response<Full<Bytes>> {
 		.unwrap()
 }
 
-// Ripped and modified from gennyble/mavourings query.rs
-fn url_decode<S: AsRef<str>>(urlencoded: S, plus_as_space: bool) -> Option<String> {
-	let urlencoded = urlencoded.as_ref();
-	let mut uncoded: Vec<u8> = Vec::with_capacity(urlencoded.len());
-
-	let mut chars = urlencoded.chars().peekable();
-	loop {
-		let mut utf8_bytes = [0; 4];
-		match chars.next() {
-			Some('+') => match plus_as_space {
-				true => uncoded.push(b' '),
-				false => uncoded.push(b'+'),
-			},
-			Some('%') => match chars.peek() {
-				Some(c) if c.is_ascii_hexdigit() => {
-					let upper = chars.next().unwrap();
-
-					if let Some(lower) = chars.peek() {
-						if lower.is_ascii_hexdigit() {
-							let upper = upper.to_digit(16).unwrap();
-							let lower = chars.next().unwrap().to_digit(16).unwrap();
-
-							uncoded.push(upper as u8 * 16 + lower as u8);
-							continue;
-						}
-					}
-
-					uncoded.push(b'%');
-					uncoded.extend_from_slice(upper.encode_utf8(&mut utf8_bytes).as_bytes());
-				}
-				_ => {
-					uncoded.push(b'%');
-				}
-			},
-			Some(c) => {
-				uncoded.extend_from_slice(c.encode_utf8(&mut utf8_bytes).as_bytes());
-			}
-			None => {
-				uncoded.shrink_to_fit();
-				return String::from_utf8(uncoded).ok();
-			}
-		}
-	}
-}
-
 enum RuntimeError {
 	MalformedRequest,
 	NoScript,
diff --git a/corgi/src/settings.rs b/corgi/src/settings.rs
new file mode 100644
index 0000000..ee701b0
--- /dev/null
+++ b/corgi/src/settings.rs
@@ -0,0 +1,82 @@
+use std::path::PathBuf;
+
+use confindent::{Confindent, Value, ValueParseError};
+use regex_lite::Regex;
+
+const CONF_DEFAULT: &str = "/etc/corgi.conf";
+
+#[derive(Clone, Debug)]
+pub struct Script {
+	pub name: String,
+	pub regex: Option<Regex>,
+	pub filename: String,
+	pub env: Vec<(String, String)>,
+}
+
+#[derive(Clone, Debug)]
+pub struct Settings {
+	pub port: u16,
+	pub scripts: Vec<Script>,
+	pub stats_path: PathBuf,
+}
+
+impl Settings {
+	pub fn get() -> Self {
+		let conf_path = std::env::args()
+			.nth(1)
+			.unwrap_or(String::from(CONF_DEFAULT));
+		let conf = Confindent::from_file(conf_path).expect("failed to open conf");
+
+		let mut settings = Settings {
+			port: 26744,
+			scripts: conf
+				.children("Script")
+				.into_iter()
+				.map(parse_script_conf)
+				.collect(),
+			stats_path: conf.get_parse("Server/StatsDb").unwrap(),
+		};
+
+		if let Some(server) = conf.child("Server") {
+			match server.child_parse("Port") {
+				Err(ValueParseError::NoValue) => (),
+				Err(err) => {
+					eprintln!("Server.Port is malformed: {err}");
+					std::process::exit(1);
+				}
+				Ok(port) => settings.port = port,
+			}
+		}
+
+		settings
+	}
+}
+
+fn parse_script_conf(conf: &Value) -> Script {
+	let name = conf.value_owned().expect("Missing value for 'Script' key");
+	let filename = conf.child_owned("Path").expect("Missing 'Path' key");
+	let environment = conf.child("Environment");
+	let env = environment.map(|e| {
+		e.values()
+			.map(|v| (v.key_owned(), v.value_owned().unwrap()))
+			.collect()
+	});
+
+	let regex = match conf.get("Match/Regex") {
+		None => None,
+		Some(restr) => match Regex::new(restr) {
+			Err(err) => {
+				eprintln!("Failed to compile regex: {restr}\nerror: {err}");
+				std::process::exit(1);
+			}
+			Ok(re) => Some(re),
+		},
+	};
+
+	Script {
+		name,
+		regex,
+		filename,
+		env: env.unwrap_or_default(),
+	}
+}
diff --git a/corgi/src/stats.rs b/corgi/src/stats.rs
index e475820..bf9c1ec 100644
--- a/corgi/src/stats.rs
+++ b/corgi/src/stats.rs
@@ -1,4 +1,4 @@
-use std::{net::IpAddr, path::PathBuf, sync::Mutex};
+use std::{net::IpAddr, path::Path, sync::Mutex};
 
 use base64::{Engine, prelude::BASE64_STANDARD_NO_PAD};
 use rusqlite::{Connection, OptionalExtension, params};
@@ -10,7 +10,7 @@ pub struct Stats {
 }
 
 impl Stats {
-	pub fn new(db_path: PathBuf) -> Self {
+	pub fn new(db_path: &Path) -> Self {
 		Self {
 			conn: Mutex::new(Connection::open(db_path).unwrap()),
 		}
diff --git a/corgi/src/util.rs b/corgi/src/util.rs
new file mode 100644
index 0000000..727c8c7
--- /dev/null
+++ b/corgi/src/util.rs
@@ -0,0 +1,66 @@
+use std::str::FromStr;
+
+use hyper::header::HeaderValue;
+
+use crate::RuntimeError;
+
+// Ripped and modified from gennyble/mavourings query.rs
+/// Decode a URL encoded string, optionally treating a plus, '+', as a space. If
+/// the final string is not UTF8, RuntimeError::MalformedRequest is returned
+pub fn url_decode(urlencoded: &str, plus_as_space: bool) -> Result<String, RuntimeError> {
+	let mut uncoded: Vec<u8> = Vec::with_capacity(urlencoded.len());
+
+	let mut chars = urlencoded.chars().peekable();
+	loop {
+		let mut utf8_bytes = [0; 4];
+		match chars.next() {
+			Some('+') => match plus_as_space {
+				true => uncoded.push(b' '),
+				false => uncoded.push(b'+'),
+			},
+			Some('%') => match chars.peek() {
+				Some(c) if c.is_ascii_hexdigit() => {
+					let upper = chars.next().unwrap();
+
+					if let Some(lower) = chars.peek() {
+						if lower.is_ascii_hexdigit() {
+							let upper = upper.to_digit(16).unwrap();
+							let lower = chars.next().unwrap().to_digit(16).unwrap();
+
+							uncoded.push(upper as u8 * 16 + lower as u8);
+							continue;
+						}
+					}
+
+					uncoded.push(b'%');
+					uncoded.extend_from_slice(upper.encode_utf8(&mut utf8_bytes).as_bytes());
+				}
+				_ => {
+					uncoded.push(b'%');
+				}
+			},
+			Some(c) => {
+				uncoded.extend_from_slice(c.encode_utf8(&mut utf8_bytes).as_bytes());
+			}
+			None => {
+				uncoded.shrink_to_fit();
+				return String::from_utf8(uncoded).map_err(|_| RuntimeError::MalformedRequest);
+			}
+		}
+	}
+}
+
+pub fn parse_from_header<T: FromStr>(maybe_hval: Option<&HeaderValue>) -> Option<T> {
+	maybe_hval
+		.map(|h| h.to_str().ok())
+		.flatten()
+		.map(|s| s.parse().ok())
+		.flatten()
+}
+
+pub fn owned_header(maybe_hval: Option<&HeaderValue>) -> Option<String> {
+	maybe_hval
+		.map(|h| h.to_str().ok())
+		.flatten()
+		.map(<_>::to_owned)
+}