From df11f96162f89c740aba87d62e4512cfc46049bb Mon Sep 17 00:00:00 2001 From: gennyble Date: Mon, 16 Jun 2025 21:45:59 -0500 Subject: allow getting ip-address from header --- src/main.rs | 45 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 8 deletions(-) (limited to 'src/main.rs') diff --git a/src/main.rs b/src/main.rs index 6423ad7..3859754 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ use std::{ }; use httparse::{Request, Status}; +use scurvy::{Argument, Scurvy}; use smol::{ LocalExecutor, future, io::{AsyncReadExt, AsyncWriteExt}, @@ -18,6 +19,10 @@ use smol::{ }; fn main() { + let args = vec![Argument::new(&["ip-header"]).arg("name")]; + let scurvy = Scurvy::make(args); + let ip_header = scurvy.get("ip-header"); + let exec = LocalExecutor::new(); let arxc = Arc::new(exec); @@ -30,12 +35,18 @@ fn main() { let moving_arxc = Arc::clone(&arxc); let moving_arcdb = Arc::clone(&arcdb); - future::block_on(arxc.run(async move { async_main(moving_arxc, moving_arcdb).await })); + future::block_on( + arxc.run(async move { async_main(moving_arxc, moving_arcdb, ip_header).await }), + ); } /// This function is what's run during the lifetime of the program by our /// [LocalExecutor] -async fn async_main(ex: Arc>, db: Arc>) { +async fn async_main<'a>( + ex: Arc>, + db: Arc>, + ip_header: Option<&'a str>, +) { let addr = SocketAddr::from(([0, 0, 0, 0], 12345)); let listen = TcpListener::bind(addr).await.unwrap(); @@ -43,13 +54,18 @@ async fn async_main(ex: Arc>, db: Arc>) { let (stream, caddr) = listen.accept().await.unwrap(); println!("accepted!"); let moving_db = Arc::clone(&db); - ex.spawn(async move { handle_request(stream, caddr, moving_db).await }) + ex.spawn(async move { handle_request(stream, caddr, moving_db, ip_header).await }) .detach(); } } /// Handles a single connection from a client. -async fn handle_request(mut stream: TcpStream, caddr: SocketAddr, db: Arc>) { +async fn handle_request( + mut stream: TcpStream, + caddr: SocketAddr, + db: Arc>, + ip_header: Option<&str>, +) { let mut buffer = vec![0; 1024]; let mut data_end = 0; loop { @@ -93,7 +109,7 @@ async fn handle_request(mut stream: TcpStream, caddr: SocketAddr, db: Arc Vec { /// - Err(Utf8Error) if the header value could not be parsed as a utf8 string fn get_header<'h>( headers: &'h [httparse::Header<'h>], - key: &'static str, + key: &str, ) -> Result, Vec> { headers .iter() - .find(|h| h.name.to_lowercase() == key) + .find(|h| h.name.to_lowercase() == key.to_lowercase()) .map(|v| str::from_utf8(v.value).map_err(|_| error(400, "malformed header"))) .transpose() } @@ -151,6 +167,7 @@ async fn form_response<'req>( request: Request<'req, 'req>, caddr: SocketAddr, db: Arc>, + ip_header: Option<&str>, ) -> Vec { let is_head = match request.method.unwrap() { "GET" => false, @@ -169,8 +186,20 @@ async fn form_response<'req>( Ok(Some(_)) => return error(415, "unsupported content-type"), }; + let client_addr = match ip_header { + None => caddr.ip(), + Some(header_name) => match get_header(request.headers, header_name) { + Err(_e) => panic!("panic on ip header parse failure"), + Ok(None) => panic!("panic on ip header missing"), + Ok(Some(addr)) => match IpAddr::from_str(addr) { + Err(_e) => panic!("panic on ip header non-ipaddr value"), + Ok(addr) => addr, + }, + }, + }; + let response = Response { - client_addr: caddr.ip(), + client_addr, http_minor: minor, is_head, }; -- cgit 1.4.1-3-g733a5