From 2602759ef92a840e2723e1795c6c4677a2d11ab4 Mon Sep 17 00:00:00 2001 From: gennyble Date: Mon, 16 Jun 2025 18:51:34 -0500 Subject: use authorization header to match ips between requests --- src/main.rs | 327 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 316 insertions(+), 11 deletions(-) diff --git a/src/main.rs b/src/main.rs index b9393f0..a3da041 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,20 @@ use core::fmt; -use std::{net::SocketAddr, sync::Arc}; +use std::{ + borrow::{Borrow, Cow}, + clone, + collections::HashMap, + io::{self, BufRead, BufReader, BufWriter, Write}, + net::{IpAddr, SocketAddr}, + path::PathBuf, + str::FromStr, + sync::Arc, +}; use httparse::{Request, Status}; use smol::{ LocalExecutor, future, io::{AsyncReadExt, AsyncWriteExt}, + lock::RwLock, net::{TcpListener, TcpStream}, }; @@ -12,23 +22,32 @@ fn main() { let exec = LocalExecutor::new(); let arxc = Arc::new(exec); + let mut db = Database::new("database".into()); + if db.path.exists() { + db.load().unwrap(); + } + + let arcdb = Arc::new(RwLock::new(db)); + let moving_arxc = Arc::clone(&arxc); - future::block_on(arxc.run(async move { async_main(moving_arxc).await })); + let moving_arcdb = Arc::clone(&arcdb); + future::block_on(arxc.run(async move { async_main(moving_arxc, moving_arcdb).await })); } -async fn async_main(ex: Arc>) { +async fn async_main(ex: Arc>, db: Arc>) { let addr = SocketAddr::from(([0, 0, 0, 0], 12345)); let listen = TcpListener::bind(addr).await.unwrap(); loop { let (stream, caddr) = listen.accept().await.unwrap(); println!("accepted!"); - ex.spawn(async move { handle_request(stream, caddr).await }) + let moving_db = Arc::clone(&db); + ex.spawn(async move { handle_request(stream, caddr, moving_db).await }) .detach(); } } -async fn handle_request(mut stream: TcpStream, caddr: SocketAddr) { +async fn handle_request(mut stream: TcpStream, caddr: SocketAddr, db: Arc>) { let mut buffer = vec![0; 1024]; let mut data_end = 0; loop { @@ -72,7 +91,7 @@ async fn handle_request(mut stream: TcpStream, caddr: SocketAddr) { // Fall through the match. We should guarantee that, if we end up here, // we got a Status::Complete from the parser. - let response = form_response(req, caddr).await; + let response = form_response(req, caddr, db).await; if let Err(err) = stream.write_all(&response).await { eprintln!("Error in handle_request sending response back. Bailing.\nError: {err}"); @@ -105,9 +124,26 @@ fn get_header<'h>( .transpose() } +struct Response { + /// IP the request came from. This is very important in our application. + client_addr: IpAddr, + /// HTTP minor version. Read this from the request so we can respond properly. + http_minor: u8, + /// Response status code + status: u16, + /// Response status message + message: &'static str, + /// Indicates if this is a GET or a HEAD request + is_head: bool, +} + /// Create a response, whatever it is. This function forms responses for /// success, and errors. -async fn form_response<'req>(request: Request<'req, 'req>, caddr: SocketAddr) -> Vec { +async fn form_response<'req>( + request: Request<'req, 'req>, + caddr: SocketAddr, + db: Arc>, +) -> Vec { let is_head = match request.method.unwrap() { "GET" => false, "HEAD" => true, @@ -128,15 +164,81 @@ async fn form_response<'req>(request: Request<'req, 'req>, caddr: SocketAddr) -> Ok(Some(_)) => return error(415, "unsupported content-type"), }; + let response = Response { + client_addr: caddr.ip(), + http_minor: minor, + status, + message, + is_head, + }; + + match get_header(request.headers, "authorization") { + Err(e) => return e, + Ok(None) => return make_ip_response(requested_type, response), + Ok(Some(auth)) => { + println!("Authorization: {auth}"); + + let aure = match handle_auth(caddr.ip(), auth.trim(), db).await { + Err(RuntimeError::AuthInvalid) => { + return error( + 400, + "Auth Invalid; the auth code contained invalid characters", + ); + } + Err(_e) => { + return error(500, "Internal Server Error"); + } + Ok(a) => a, + }; + + match aure { + AuthResponse::NewAuth | AuthResponse::DifferentIp { .. } => { + println!("New/Different"); + + return make_ip_response(requested_type, response); + } + AuthResponse::MatchingIp => { + println!("Matches"); + + let mut response = + format!("HTTP/1.{minor} 304 Not Modified; Your IP is the same\n"); + + macro_rules! header { + ($content:expr) => { + response.push_str(&format!($content)); + response.push('\n'); + }; + } + + header!("Content-Length: 0"); + header!("Cache-Control: no-store"); + response.push('\n'); + + return response.into_bytes(); + } + } + } + } +} + +fn make_ip_response(requested_type: ContentType, response: Response) -> Vec { + let Response { + client_addr, + http_minor, + status, + message, + is_head, + } = response; + // Make our response body! let body = match requested_type { - ContentType::Plain => format!("{}", caddr.ip()), - ContentType::Json => format!(r#"{{"ip": "{}"}}"#, caddr.ip()), + ContentType::Plain => format!("{client_addr}"), + ContentType::Json => format!(r#"{{"ip": "{client_addr}"}}"#), }; let length = body.len(); // And now we form the request :D - let mut response = format!("HTTP/1.{minor} {status} {message}\n"); + let mut response = format!("HTTP/1.{http_minor} {status} {message}\n"); macro_rules! header { ($content:expr) => { @@ -149,14 +251,60 @@ async fn form_response<'req>(request: Request<'req, 'req>, caddr: SocketAddr) -> header!("Content-Type: {requested_type}"); header!("Cache-Control: no-store"); + response.push('\n'); if !is_head { - response.push('\n'); response.push_str(&body); } response.into_bytes() } +enum AuthResponse { + /// Returned when the key has never been seen before. + NewAuth, + /// Returned when the IP address matches the last IP that made a request + /// with the provided auth key + MatchingIp, + /// The IP held by the auth key is not the same one the client is + /// requesting from. + DifferentIp { old: IpAddr }, +} + +async fn handle_auth( + client_addr: IpAddr, + auth: &str, + db: Arc>, +) -> Result { + let key = match AuthKey::from_str(auth) { + Ok(ak) => ak, + Err(_) => { + return Err(RuntimeError::AuthInvalid); + } + }; + + let lock = db.read().await; + match lock.get_by_key(key.clone()) { + None => { + drop(lock); + let mut write = db.write().await; + write.update_ip(key, client_addr); + + return Ok(AuthResponse::NewAuth); + } + Some(addr) => { + if addr == client_addr { + return Ok(AuthResponse::MatchingIp); + } else { + drop(lock); + let mut write = db.write().await; + write.update_ip(key, client_addr); + + return Ok(AuthResponse::DifferentIp { old: addr }); + } + } + } +} + enum ContentType { Plain, Json, @@ -170,3 +318,160 @@ impl fmt::Display for ContentType { } } } + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct AuthKey<'a>(Cow<'a, str>); +const BASE64: &'static str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890+/="; + +impl<'a> AuthKey<'a> { + fn from_str_owned(s: &str) -> Result { + Self::check_str(s)?; + Ok(AuthKey(Cow::Owned(s.to_string()))) + } + + fn from_str(s: &'a str) -> Result { + Self::check_str(s)?; + Ok(AuthKey(Cow::Borrowed(s))) + } + + fn check_str(s: &str) -> Result<(), ()> { + for ch in s.chars() { + if !BASE64.contains(ch) { + return Err(()); + } + } + + Ok(()) + } + + fn as_static_owned(self) -> AuthKey<'static> { + let string = self.0.to_string(); + AuthKey(Cow::Owned(string)) + } +} + +impl<'a> fmt::Display for AuthKey<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +pub struct Database { + path: PathBuf, + map: HashMap, IpAddr>, +} + +impl Database { + pub fn new(path: PathBuf) -> Self { + Self { + path, + map: HashMap::new(), + } + } + + fn load(&mut self) -> Result<(), RuntimeError> { + let file = + std::fs::File::open(&self.path).map_err(|io| RuntimeError::io(io, &self.path))?; + + let bufr = BufReader::new(file); + for (ln, line) in bufr.lines().enumerate() { + let line = line.map_err(|io| RuntimeError::io(io, &self.path))?; + + match line.split_once(' ') { + None => { + return Err(RuntimeError::malformed_database( + &self.path, + ln, + DatabaseMalformation::Line, + )); + } + Some((raw_key, raw_ip)) => { + let auth = AuthKey::from_str_owned(raw_key).map_err(|_| { + RuntimeError::malformed_database( + &self.path, + ln, + DatabaseMalformation::AuthKey, + ) + })?; + + let ip = IpAddr::from_str(raw_ip).map_err(|_| { + RuntimeError::malformed_database( + &self.path, + ln, + DatabaseMalformation::IpAddr, + ) + })?; + + self.map.insert(auth, ip); + } + } + } + + Ok(()) + } + + fn save(&self) -> Result<(), RuntimeError> { + let file = + std::fs::File::create(&self.path).map_err(|io| RuntimeError::io(io, &self.path))?; + + let mut bufw = BufWriter::new(file); + for (key, val) in self.map.iter() { + let line = format!("{key} {val}"); + + bufw.write_all(line.as_bytes()) + .map_err(|io| RuntimeError::io(io, &self.path))?; + } + + Ok(()) + } + + fn get_by_key(&self, key: AuthKey<'_>) -> Option { + self.map.get(&key).cloned() + } + + fn update_ip(&mut self, key: AuthKey<'_>, ip: IpAddr) { + self.map.insert(key.as_static_owned(), ip); + } +} + +#[derive(Clone, Copy, Debug)] +enum DatabaseMalformation { + AuthKey, + IpAddr, + Line, +} + +#[derive(Debug)] +enum RuntimeError { + IoError { + source: io::Error, + path: PathBuf, + }, + MalformedDatabase { + path: PathBuf, + ln: usize, + kind: DatabaseMalformation, + }, + AuthInvalid, +} + +impl RuntimeError { + pub fn io>(source: io::Error, path: P) -> Self { + RuntimeError::IoError { + source, + path: path.into(), + } + } + + pub fn malformed_database>( + path: P, + ln: usize, + kind: DatabaseMalformation, + ) -> Self { + RuntimeError::MalformedDatabase { + path: path.into(), + ln, + kind, + } + } +} -- cgit 1.4.1-3-g733a5