about summary refs log tree commit diff
path: root/src/main.rs
diff options
context:
space:
mode:
authorgennyble <gen@nyble.dev>2025-06-16 21:45:59 -0500
committergennyble <gen@nyble.dev>2025-06-16 21:45:59 -0500
commitdf11f96162f89c740aba87d62e4512cfc46049bb (patch)
treee06c5b66f0cdd74fe6d588ef93d9f8244b268e45 /src/main.rs
parent764f1cdf94182ba84d500013b000f718f6fe37a7 (diff)
downloadpiper-df11f96162f89c740aba87d62e4512cfc46049bb.tar.gz
piper-df11f96162f89c740aba87d62e4512cfc46049bb.zip
allow getting ip-address from header
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs45
1 files changed, 37 insertions, 8 deletions
diff --git a/src/main.rs b/src/main.rs
index 6423ad7..3859754 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -10,6 +10,7 @@ use std::{
 };
 
 use httparse::{Request, Status};
+use scurvy::{Argument, Scurvy};
 use smol::{
 	LocalExecutor, future,
 	io::{AsyncReadExt, AsyncWriteExt},
@@ -18,6 +19,10 @@ use smol::{
 };
 
 fn main() {
+	let args = vec![Argument::new(&["ip-header"]).arg("name")];
+	let scurvy = Scurvy::make(args);
+	let ip_header = scurvy.get("ip-header");
+
 	let exec = LocalExecutor::new();
 	let arxc = Arc::new(exec);
 
@@ -30,12 +35,18 @@ fn main() {
 
 	let moving_arxc = Arc::clone(&arxc);
 	let moving_arcdb = Arc::clone(&arcdb);
-	future::block_on(arxc.run(async move { async_main(moving_arxc, moving_arcdb).await }));
+	future::block_on(
+		arxc.run(async move { async_main(moving_arxc, moving_arcdb, ip_header).await }),
+	);
 }
 
 /// This function is what's run during the lifetime of the program by our
 /// [LocalExecutor]
-async fn async_main(ex: Arc<LocalExecutor<'_>>, db: Arc<RwLock<Database>>) {
+async fn async_main<'a>(
+	ex: Arc<LocalExecutor<'a>>,
+	db: Arc<RwLock<Database>>,
+	ip_header: Option<&'a str>,
+) {
 	let addr = SocketAddr::from(([0, 0, 0, 0], 12345));
 	let listen = TcpListener::bind(addr).await.unwrap();
 
@@ -43,13 +54,18 @@ async fn async_main(ex: Arc<LocalExecutor<'_>>, db: Arc<RwLock<Database>>) {
 		let (stream, caddr) = listen.accept().await.unwrap();
 		println!("accepted!");
 		let moving_db = Arc::clone(&db);
-		ex.spawn(async move { handle_request(stream, caddr, moving_db).await })
+		ex.spawn(async move { handle_request(stream, caddr, moving_db, ip_header).await })
 			.detach();
 	}
 }
 
 /// Handles a single connection from a client.
-async fn handle_request(mut stream: TcpStream, caddr: SocketAddr, db: Arc<RwLock<Database>>) {
+async fn handle_request(
+	mut stream: TcpStream,
+	caddr: SocketAddr,
+	db: Arc<RwLock<Database>>,
+	ip_header: Option<&str>,
+) {
 	let mut buffer = vec![0; 1024];
 	let mut data_end = 0;
 	loop {
@@ -93,7 +109,7 @@ async fn handle_request(mut stream: TcpStream, caddr: SocketAddr, db: Arc<RwLock
 
 		// Fall through the match. We should guarantee that, if we end up here,
 		// we got a Status::Complete from the parser.
-		let response = form_response(req, caddr, db).await;
+		let response = form_response(req, caddr, db, ip_header).await;
 
 		if let Err(err) = stream.write_all(&response).await {
 			eprintln!("Error in handle_request sending response back. Bailing.\nError: {err}");
@@ -125,11 +141,11 @@ fn error(code: u16, msg: &'static str) -> Vec<u8> {
 /// - Err(Utf8Error) if the header value could not be parsed as a utf8 string
 fn get_header<'h>(
 	headers: &'h [httparse::Header<'h>],
-	key: &'static str,
+	key: &str,
 ) -> Result<Option<&'h str>, Vec<u8>> {
 	headers
 		.iter()
-		.find(|h| h.name.to_lowercase() == key)
+		.find(|h| h.name.to_lowercase() == key.to_lowercase())
 		.map(|v| str::from_utf8(v.value).map_err(|_| error(400, "malformed header")))
 		.transpose()
 }
@@ -151,6 +167,7 @@ async fn form_response<'req>(
 	request: Request<'req, 'req>,
 	caddr: SocketAddr,
 	db: Arc<RwLock<Database>>,
+	ip_header: Option<&str>,
 ) -> Vec<u8> {
 	let is_head = match request.method.unwrap() {
 		"GET" => false,
@@ -169,8 +186,20 @@ async fn form_response<'req>(
 		Ok(Some(_)) => return error(415, "unsupported content-type"),
 	};
 
+	let client_addr = match ip_header {
+		None => caddr.ip(),
+		Some(header_name) => match get_header(request.headers, header_name) {
+			Err(_e) => panic!("panic on ip header parse failure"),
+			Ok(None) => panic!("panic on ip header missing"),
+			Ok(Some(addr)) => match IpAddr::from_str(addr) {
+				Err(_e) => panic!("panic on ip header non-ipaddr value"),
+				Ok(addr) => addr,
+			},
+		},
+	};
+
 	let response = Response {
-		client_addr: caddr.ip(),
+		client_addr,
 		http_minor: minor,
 		is_head,
 	};