diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 101 |
1 files changed, 77 insertions, 24 deletions
diff --git a/src/main.rs b/src/main.rs index d288cba..e352fe5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ -use std::{net::SocketAddr, sync::Arc}; +use core::fmt; +use std::{borrow::Cow, fmt::format, net::SocketAddr, sync::Arc}; use httparse::{Request, Status}; use smol::{ @@ -69,32 +70,11 @@ async fn handle_request(mut stream: TcpStream, caddr: SocketAddr) { } } - let minor = req.version.unwrap(); - let status = 200; - let message = "Gotcha"; - // Fall through the match. We should guarantee that, if we end up here, // we got a Status::Complete from the parser - let mut response = format!("HTTP/1.{minor} {status} {message}\n"); - - let body = format!("{}", caddr.ip()); - let length = body.len(); - - macro_rules! header { - ($content:expr) => { - response.push_str(&format!($content)); - response.push('\n'); - }; - } - - header!("Content-Length: {length}"); - header!("Content-Type: text/plain"); - header!("Cache-Control: no-store"); - - response.push('\n'); - response.push_str(&body); + let response = form_response(req, caddr).await; - if let Err(err) = stream.write_all(response.as_bytes()).await { + if let Err(err) = stream.write_all(&response).await { eprintln!("Error in handle_request sending response back. Bailing.\nError: {err}"); return; } else { @@ -102,3 +82,76 @@ async fn handle_request(mut stream: TcpStream, caddr: SocketAddr) { } } } + +fn error(code: u16, msg: &'static str) -> Vec<u8> { + format!( + "HTTP/1.1 {code} {msg}\n\ + server: piper/0.1\n\ + content-length: 0\n\ + cache-control: no-store\n\n" + ) + .as_bytes() + .to_vec() +} + +fn get_header<'h>( + headers: &'h [httparse::Header<'h>], + key: &'static str, +) -> Result<Option<&'h str>, Vec<u8>> { + headers + .iter() + .find(|h| h.name.to_lowercase() == key) + .map(|v| str::from_utf8(v.value).map_err(|_| error(400, "malformed header"))) + .transpose() +} + +async fn form_response<'req>(request: Request<'req, 'req>, caddr: SocketAddr) -> Vec<u8> { + let minor = request.version.unwrap(); + let status = 200; + let message = "Gotcha"; + + let mut response = format!("HTTP/1.{minor} {status} {message}\n"); + + let requested_type = match get_header(request.headers, "content-type") { + Err(e) => return e, + Ok(None) | Ok(Some("text/plain")) => ContentType::Plain, + Ok(Some("application/json")) => ContentType::Json, + Ok(Some(_)) => return error(400, "unknown content-type"), + }; + + let body = match requested_type { + ContentType::Plain => format!("{}", caddr.ip()), + ContentType::Json => format!(r#"{{"ip": "{}"}}"#, caddr.ip()), + }; + let length = body.len(); + + macro_rules! header { + ($content:expr) => { + response.push_str(&format!($content)); + response.push('\n'); + }; + } + + header!("Content-Length: {length}"); + header!("Content-Type: {requested_type}"); + header!("Cache-Control: no-store"); + + response.push('\n'); + response.push_str(&body); + + response.into_bytes() +} + +enum ContentType { + Plain, + Json, +} + +impl fmt::Display for ContentType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ContentType::Plain => write!(f, "text/plain"), + ContentType::Json => write!(f, "application/json"), + } + } +} |