use core::fmt; use std::{ borrow::Cow, collections::HashMap, fmt::format, io::{self, BufRead, BufReader, BufWriter, Write}, net::{IpAddr, SocketAddr}, path::PathBuf, str::FromStr, sync::Arc, }; use httparse::{Request, Status}; use scurvy::{Argument, Scurvy}; use smol::{ LocalExecutor, future, io::{AsyncReadExt, AsyncWriteExt}, lock::RwLock, net::{TcpListener, TcpStream}, }; fn main() { let args = vec![Argument::new(&["ip-header"]).arg("name")]; let scurvy = Scurvy::make(args); let ip_header = scurvy.get("ip-header"); let exec = LocalExecutor::new(); let arxc = Arc::new(exec); let mut db = Database::new("database".into()); if db.path.exists() { db.load().unwrap(); } let arcdb = Arc::new(RwLock::new(db)); let moving_arxc = Arc::clone(&arxc); let moving_arcdb = Arc::clone(&arcdb); future::block_on( arxc.run(async move { async_main(moving_arxc, moving_arcdb, ip_header).await }), ); } /// This function is what's run during the lifetime of the program by our /// [LocalExecutor] async fn async_main<'a>( ex: Arc>, db: Arc>, ip_header: Option<&'a str>, ) { let addr = SocketAddr::from(([0, 0, 0, 0], 12345)); let listen = TcpListener::bind(addr).await.unwrap(); loop { let (stream, caddr) = listen.accept().await.unwrap(); println!("accepted!"); let moving_db = Arc::clone(&db); ex.spawn(async move { handle_request(stream, caddr, moving_db, ip_header).await }) .detach(); } } /// Handles a single connection from a client. async fn handle_request( mut stream: TcpStream, caddr: SocketAddr, db: Arc>, ip_header: Option<&str>, ) { let mut buffer = vec![0; 1024]; let mut data_end = 0; loop { println!("Waiting on read"); match stream.read(&mut buffer[data_end..]).await { Err(_err) => { eprintln!("Error in handle_request during read. Bailing."); return; } Ok(bytes_read) => { data_end += bytes_read; println!("read {bytes_read}"); } } let mut headers = [httparse::EMPTY_HEADER; 64]; let mut req = Request::new(&mut headers); match req.parse(&buffer[..data_end]) { Err(err) => { eprintln!("Error in handle_request during parsing. Bailing.\nerror: {err}"); return; } Ok(Status::Partial) => { // Going around again. // Make sure we have room in the buffer. 512 seems good :) if buffer.len() == data_end { buffer.resize(buffer.len() + 512, 0); } println!("partial!"); continue; } Ok(Status::Complete(body_offset)) => { println!("Got entire request!"); println!("method: {}", req.method.unwrap()); println!("path: {}", req.path.unwrap()); println!("body_offset: {body_offset}"); } } // Fall through the match. We should guarantee that, if we end up here, // we got a Status::Complete from the parser. let response = form_response(req, caddr, db, ip_header).await; if let Err(err) = stream.write_all(&response).await { eprintln!("Error in handle_request sending response back. Bailing.\nError: {err}"); return; } else { return; } } } /// Form a simple HTTP/1.1 error response with no body and the specified status /// code and message. fn error(code: u16, msg: &'static str) -> Vec { format!( "HTTP/1.1 {code} {msg}\n\ server: piper/0.1\n\ content-length: 0\n\ cache-control: no-store\n\n" ) .as_bytes() .to_vec() } /// Get the request header with the case-insensetive `key`. /// /// Returns: /// - Ok(None) the header was not present /// - Ok(Some(value)) the header value, interpreted as a utf8 string /// - Err(Utf8Error) if the header value could not be parsed as a utf8 string fn get_header<'h>( headers: &'h [httparse::Header<'h>], key: &str, ) -> Result, Vec> { headers .iter() .find(|h| h.name.to_lowercase() == key.to_lowercase()) .map(|v| str::from_utf8(v.value).map_err(|_| error(400, "malformed header"))) .transpose() } /// Small struct containing details from the request we might want in the /// response struct Response { /// IP the request came from. This is very important in our application. client_addr: IpAddr, /// HTTP minor version. Read this from the request so we can respond properly. http_minor: u8, /// Indicates if this is a GET or a HEAD request is_head: bool, } /// Create a response, whatever it is. This function forms responses for /// success, and errors. async fn form_response<'req>( request: Request<'req, 'req>, caddr: SocketAddr, db: Arc>, ip_header: Option<&str>, ) -> Vec { let is_head = match request.method.unwrap() { "GET" => false, "HEAD" => true, // This is not strictly correct, but it's correct enough? :) // 405 is "The request method is known by the server but // is not supported by the target resource" _ => return error(405, "we do not support that request method"), }; let minor = request.version.unwrap(); let requested_type = match get_header(request.headers, "content-type") { Err(e) => return e, Ok(None) | Ok(Some("text/plain")) => ContentType::Plain, Ok(Some("application/json")) => ContentType::Json, Ok(Some("application/xml")) | Ok(Some("text/xml")) => ContentType::Xml, Ok(Some(_)) => return error(415, "unsupported content-type"), }; let client_addr = match ip_header { None => caddr.ip(), Some(header_name) => match get_header(request.headers, header_name) { Err(_e) => panic!("panic on ip header parse failure"), Ok(None) => panic!("panic on ip header missing"), Ok(Some(addr)) => match IpAddr::from_str(addr) { Err(_e) => panic!("panic on ip header non-ipaddr value"), Ok(addr) => addr, }, }, }; let response = Response { client_addr, http_minor: minor, is_head, }; match get_header(request.headers, "authorization") { Err(e) => return e, Ok(None) => return make_ip_response(requested_type, response), Ok(Some(auth)) => { println!("Authorization: {auth}"); let aure = match handle_auth(caddr.ip(), auth.trim(), db).await { Err(RuntimeError::AuthInvalid) => { return error( 400, "Auth Invalid; the auth code contained invalid characters", ); } Err(_e) => { return error(500, "Internal Server Error"); } Ok(a) => a, }; match aure { AuthResponse::NewAuth | AuthResponse::DifferentIp { .. } => { println!("New/Different"); return make_ip_response(requested_type, response); } AuthResponse::MatchingIp => { println!("Matches"); let mut response = format!("HTTP/1.{minor} 304 Not Modified; Your IP is the same\n"); macro_rules! header { ($content:expr) => { response.push_str(&format!($content)); response.push('\n'); }; } header!("Content-Length: 0"); header!("Cache-Control: no-store"); response.push('\n'); return response.into_bytes(); } } } } } /// Form a `200 Gotcha` response containing the client's address and with the /// correct content-type. fn make_ip_response(requested_type: ContentType, response: Response) -> Vec { let Response { client_addr, http_minor, is_head, } = response; // Make our response body! let body = match requested_type { ContentType::Plain => format!("{client_addr}"), ContentType::Json => format!(r#"{{"ip": "{client_addr}"}}"#), ContentType::Xml => { format!("\n{client_addr}") } }; let length = body.len(); // And now we form the request :D let mut response = format!("HTTP/1.{http_minor} 200 Gotcha\n"); macro_rules! header { ($content:expr) => { response.push_str(&format!($content)); response.push('\n'); }; } header!("Content-Length: {length}"); header!("Content-Type: {requested_type}"); header!("Cache-Control: no-store"); response.push('\n'); if !is_head { response.push_str(&body); } response.into_bytes() } /// Possible responses when a client includes an `Authorization` header enum AuthResponse { /// Returned when the key has never been seen before. NewAuth, /// Returned when the IP address matches the last IP that made a request /// with the provided auth key MatchingIp, /// The IP held by the auth key is not the same one the client is /// requesting from. DifferentIp { old: IpAddr }, } async fn handle_auth( client_addr: IpAddr, auth: &str, db: Arc>, ) -> Result { let key = match AuthKey::from_str(auth) { Ok(ak) => ak, Err(_) => { return Err(RuntimeError::AuthInvalid); } }; let lock = db.read().await; match lock.get_by_key(key.clone()) { None => { drop(lock); let mut write = db.write().await; write.update_ip(key, client_addr); return Ok(AuthResponse::NewAuth); } Some(addr) => { if addr == client_addr { return Ok(AuthResponse::MatchingIp); } else { drop(lock); let mut write = db.write().await; write.update_ip(key, client_addr); return Ok(AuthResponse::DifferentIp { old: addr }); } } } } #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct AuthKey<'a>(Cow<'a, str>); const BASE64: &'static str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890+/="; impl<'a> AuthKey<'a> { fn from_str_owned(s: &str) -> Result { Self::check_str(s)?; Ok(AuthKey(Cow::Owned(s.to_string()))) } fn from_str(s: &'a str) -> Result { Self::check_str(s)?; Ok(AuthKey(Cow::Borrowed(s))) } fn check_str(s: &str) -> Result<(), ()> { for ch in s.chars() { if !BASE64.contains(ch) { return Err(()); } } Ok(()) } fn as_static_owned(self) -> AuthKey<'static> { let string = self.0.to_string(); AuthKey(Cow::Owned(string)) } } impl<'a> fmt::Display for AuthKey<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } pub struct Database { path: PathBuf, map: HashMap, IpAddr>, } impl Database { pub fn new(path: PathBuf) -> Self { Self { path, map: HashMap::new(), } } fn load(&mut self) -> Result<(), RuntimeError> { let file = std::fs::File::open(&self.path).map_err(|io| RuntimeError::io(io, &self.path))?; let bufr = BufReader::new(file); for (ln, line) in bufr.lines().enumerate() { let line = line.map_err(|io| RuntimeError::io(io, &self.path))?; match line.split_once(' ') { None => { return Err(RuntimeError::malformed_database( &self.path, ln, DatabaseMalformation::Line, )); } Some((raw_key, raw_ip)) => { let auth = AuthKey::from_str_owned(raw_key).map_err(|_| { RuntimeError::malformed_database( &self.path, ln, DatabaseMalformation::AuthKey, ) })?; let ip = IpAddr::from_str(raw_ip).map_err(|_| { RuntimeError::malformed_database( &self.path, ln, DatabaseMalformation::IpAddr, ) })?; self.map.insert(auth, ip); } } } Ok(()) } fn save(&self) -> Result<(), RuntimeError> { let file = std::fs::File::create(&self.path).map_err(|io| RuntimeError::io(io, &self.path))?; let mut bufw = BufWriter::new(file); for (key, val) in self.map.iter() { let line = format!("{key} {val}"); bufw.write_all(line.as_bytes()) .map_err(|io| RuntimeError::io(io, &self.path))?; } Ok(()) } fn get_by_key(&self, key: AuthKey<'_>) -> Option { self.map.get(&key).cloned() } fn update_ip(&mut self, key: AuthKey<'_>, ip: IpAddr) { self.map.insert(key.as_static_owned(), ip); } } /// Content types we know about enum ContentType { Plain, Json, Xml, } impl fmt::Display for ContentType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ContentType::Plain => write!(f, "text/plain"), ContentType::Json => write!(f, "application/json"), ContentType::Xml => write!(f, "application/xml"), } } } #[derive(Debug)] enum RuntimeError { IoError { source: io::Error, path: PathBuf, }, MalformedDatabase { path: PathBuf, ln: usize, kind: DatabaseMalformation, }, AuthInvalid, } impl RuntimeError { pub fn io>(source: io::Error, path: P) -> Self { RuntimeError::IoError { source, path: path.into(), } } pub fn malformed_database>( path: P, ln: usize, kind: DatabaseMalformation, ) -> Self { RuntimeError::MalformedDatabase { path: path.into(), ln, kind, } } } #[derive(Clone, Copy, Debug)] enum DatabaseMalformation { AuthKey, IpAddr, Line, }