diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 45 |
1 files changed, 37 insertions, 8 deletions
diff --git a/src/main.rs b/src/main.rs index 6423ad7..3859754 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ use std::{ }; use httparse::{Request, Status}; +use scurvy::{Argument, Scurvy}; use smol::{ LocalExecutor, future, io::{AsyncReadExt, AsyncWriteExt}, @@ -18,6 +19,10 @@ use smol::{ }; fn main() { + let args = vec![Argument::new(&["ip-header"]).arg("name")]; + let scurvy = Scurvy::make(args); + let ip_header = scurvy.get("ip-header"); + let exec = LocalExecutor::new(); let arxc = Arc::new(exec); @@ -30,12 +35,18 @@ fn main() { let moving_arxc = Arc::clone(&arxc); let moving_arcdb = Arc::clone(&arcdb); - future::block_on(arxc.run(async move { async_main(moving_arxc, moving_arcdb).await })); + future::block_on( + arxc.run(async move { async_main(moving_arxc, moving_arcdb, ip_header).await }), + ); } /// This function is what's run during the lifetime of the program by our /// [LocalExecutor] -async fn async_main(ex: Arc<LocalExecutor<'_>>, db: Arc<RwLock<Database>>) { +async fn async_main<'a>( + ex: Arc<LocalExecutor<'a>>, + db: Arc<RwLock<Database>>, + ip_header: Option<&'a str>, +) { let addr = SocketAddr::from(([0, 0, 0, 0], 12345)); let listen = TcpListener::bind(addr).await.unwrap(); @@ -43,13 +54,18 @@ async fn async_main(ex: Arc<LocalExecutor<'_>>, db: Arc<RwLock<Database>>) { let (stream, caddr) = listen.accept().await.unwrap(); println!("accepted!"); let moving_db = Arc::clone(&db); - ex.spawn(async move { handle_request(stream, caddr, moving_db).await }) + ex.spawn(async move { handle_request(stream, caddr, moving_db, ip_header).await }) .detach(); } } /// Handles a single connection from a client. -async fn handle_request(mut stream: TcpStream, caddr: SocketAddr, db: Arc<RwLock<Database>>) { +async fn handle_request( + mut stream: TcpStream, + caddr: SocketAddr, + db: Arc<RwLock<Database>>, + ip_header: Option<&str>, +) { let mut buffer = vec![0; 1024]; let mut data_end = 0; loop { @@ -93,7 +109,7 @@ async fn handle_request(mut stream: TcpStream, caddr: SocketAddr, db: Arc<RwLock // Fall through the match. We should guarantee that, if we end up here, // we got a Status::Complete from the parser. - let response = form_response(req, caddr, db).await; + let response = form_response(req, caddr, db, ip_header).await; if let Err(err) = stream.write_all(&response).await { eprintln!("Error in handle_request sending response back. Bailing.\nError: {err}"); @@ -125,11 +141,11 @@ fn error(code: u16, msg: &'static str) -> Vec<u8> { /// - Err(Utf8Error) if the header value could not be parsed as a utf8 string fn get_header<'h>( headers: &'h [httparse::Header<'h>], - key: &'static str, + key: &str, ) -> Result<Option<&'h str>, Vec<u8>> { headers .iter() - .find(|h| h.name.to_lowercase() == key) + .find(|h| h.name.to_lowercase() == key.to_lowercase()) .map(|v| str::from_utf8(v.value).map_err(|_| error(400, "malformed header"))) .transpose() } @@ -151,6 +167,7 @@ async fn form_response<'req>( request: Request<'req, 'req>, caddr: SocketAddr, db: Arc<RwLock<Database>>, + ip_header: Option<&str>, ) -> Vec<u8> { let is_head = match request.method.unwrap() { "GET" => false, @@ -169,8 +186,20 @@ async fn form_response<'req>( Ok(Some(_)) => return error(415, "unsupported content-type"), }; + let client_addr = match ip_header { + None => caddr.ip(), + Some(header_name) => match get_header(request.headers, header_name) { + Err(_e) => panic!("panic on ip header parse failure"), + Ok(None) => panic!("panic on ip header missing"), + Ok(Some(addr)) => match IpAddr::from_str(addr) { + Err(_e) => panic!("panic on ip header non-ipaddr value"), + Ok(addr) => addr, + }, + }, + }; + let response = Response { - client_addr: caddr.ip(), + client_addr, http_minor: minor, is_head, }; |