about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main.rs327
1 files changed, 316 insertions, 11 deletions
diff --git a/src/main.rs b/src/main.rs
index b9393f0..a3da041 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,10 +1,20 @@
 use core::fmt;
-use std::{net::SocketAddr, sync::Arc};
+use std::{
+	borrow::{Borrow, Cow},
+	clone,
+	collections::HashMap,
+	io::{self, BufRead, BufReader, BufWriter, Write},
+	net::{IpAddr, SocketAddr},
+	path::PathBuf,
+	str::FromStr,
+	sync::Arc,
+};
 
 use httparse::{Request, Status};
 use smol::{
 	LocalExecutor, future,
 	io::{AsyncReadExt, AsyncWriteExt},
+	lock::RwLock,
 	net::{TcpListener, TcpStream},
 };
 
@@ -12,23 +22,32 @@ fn main() {
 	let exec = LocalExecutor::new();
 	let arxc = Arc::new(exec);
 
+	let mut db = Database::new("database".into());
+	if db.path.exists() {
+		db.load().unwrap();
+	}
+
+	let arcdb = Arc::new(RwLock::new(db));
+
 	let moving_arxc = Arc::clone(&arxc);
-	future::block_on(arxc.run(async move { async_main(moving_arxc).await }));
+	let moving_arcdb = Arc::clone(&arcdb);
+	future::block_on(arxc.run(async move { async_main(moving_arxc, moving_arcdb).await }));
 }
 
-async fn async_main(ex: Arc<LocalExecutor<'_>>) {
+async fn async_main(ex: Arc<LocalExecutor<'_>>, db: Arc<RwLock<Database>>) {
 	let addr = SocketAddr::from(([0, 0, 0, 0], 12345));
 	let listen = TcpListener::bind(addr).await.unwrap();
 
 	loop {
 		let (stream, caddr) = listen.accept().await.unwrap();
 		println!("accepted!");
-		ex.spawn(async move { handle_request(stream, caddr).await })
+		let moving_db = Arc::clone(&db);
+		ex.spawn(async move { handle_request(stream, caddr, moving_db).await })
 			.detach();
 	}
 }
 
-async fn handle_request(mut stream: TcpStream, caddr: SocketAddr) {
+async fn handle_request(mut stream: TcpStream, caddr: SocketAddr, db: Arc<RwLock<Database>>) {
 	let mut buffer = vec![0; 1024];
 	let mut data_end = 0;
 	loop {
@@ -72,7 +91,7 @@ async fn handle_request(mut stream: TcpStream, caddr: SocketAddr) {
 
 		// Fall through the match. We should guarantee that, if we end up here,
 		// we got a Status::Complete from the parser.
-		let response = form_response(req, caddr).await;
+		let response = form_response(req, caddr, db).await;
 
 		if let Err(err) = stream.write_all(&response).await {
 			eprintln!("Error in handle_request sending response back. Bailing.\nError: {err}");
@@ -105,9 +124,26 @@ fn get_header<'h>(
 		.transpose()
 }
 
+struct Response {
+	/// IP the request came from. This is very important in our application.
+	client_addr: IpAddr,
+	/// HTTP minor version. Read this from the request so we can respond properly.
+	http_minor: u8,
+	/// Response status code
+	status: u16,
+	/// Response status message
+	message: &'static str,
+	/// Indicates if this is a GET or a HEAD request
+	is_head: bool,
+}
+
 /// Create a response, whatever it is. This function forms responses for
 /// success, and errors.
-async fn form_response<'req>(request: Request<'req, 'req>, caddr: SocketAddr) -> Vec<u8> {
+async fn form_response<'req>(
+	request: Request<'req, 'req>,
+	caddr: SocketAddr,
+	db: Arc<RwLock<Database>>,
+) -> Vec<u8> {
 	let is_head = match request.method.unwrap() {
 		"GET" => false,
 		"HEAD" => true,
@@ -128,15 +164,81 @@ async fn form_response<'req>(request: Request<'req, 'req>, caddr: SocketAddr) ->
 		Ok(Some(_)) => return error(415, "unsupported content-type"),
 	};
 
+	let response = Response {
+		client_addr: caddr.ip(),
+		http_minor: minor,
+		status,
+		message,
+		is_head,
+	};
+
+	match get_header(request.headers, "authorization") {
+		Err(e) => return e,
+		Ok(None) => return make_ip_response(requested_type, response),
+		Ok(Some(auth)) => {
+			println!("Authorization: {auth}");
+
+			let aure = match handle_auth(caddr.ip(), auth.trim(), db).await {
+				Err(RuntimeError::AuthInvalid) => {
+					return error(
+						400,
+						"Auth Invalid; the auth code contained invalid characters",
+					);
+				}
+				Err(_e) => {
+					return error(500, "Internal Server Error");
+				}
+				Ok(a) => a,
+			};
+
+			match aure {
+				AuthResponse::NewAuth | AuthResponse::DifferentIp { .. } => {
+					println!("New/Different");
+
+					return make_ip_response(requested_type, response);
+				}
+				AuthResponse::MatchingIp => {
+					println!("Matches");
+
+					let mut response =
+						format!("HTTP/1.{minor} 304 Not Modified; Your IP is the same\n");
+
+					macro_rules! header {
+						($content:expr) => {
+							response.push_str(&format!($content));
+							response.push('\n');
+						};
+					}
+
+					header!("Content-Length: 0");
+					header!("Cache-Control: no-store");
+					response.push('\n');
+
+					return response.into_bytes();
+				}
+			}
+		}
+	}
+}
+
+fn make_ip_response(requested_type: ContentType, response: Response) -> Vec<u8> {
+	let Response {
+		client_addr,
+		http_minor,
+		status,
+		message,
+		is_head,
+	} = response;
+
 	// Make our response body!
 	let body = match requested_type {
-		ContentType::Plain => format!("{}", caddr.ip()),
-		ContentType::Json => format!(r#"{{"ip": "{}"}}"#, caddr.ip()),
+		ContentType::Plain => format!("{client_addr}"),
+		ContentType::Json => format!(r#"{{"ip": "{client_addr}"}}"#),
 	};
 	let length = body.len();
 
 	// And now we form the request :D
-	let mut response = format!("HTTP/1.{minor} {status} {message}\n");
+	let mut response = format!("HTTP/1.{http_minor} {status} {message}\n");
 
 	macro_rules! header {
 		($content:expr) => {
@@ -149,14 +251,60 @@ async fn form_response<'req>(request: Request<'req, 'req>, caddr: SocketAddr) ->
 	header!("Content-Type: {requested_type}");
 	header!("Cache-Control: no-store");
 
+	response.push('\n');
 	if !is_head {
-		response.push('\n');
 		response.push_str(&body);
 	}
 
 	response.into_bytes()
 }
 
+enum AuthResponse {
+	/// Returned when the key has never been seen before.
+	NewAuth,
+	/// Returned when the IP address matches the last IP that made a request
+	/// with the provided auth key
+	MatchingIp,
+	/// The IP held by the auth key is not the same one the client is
+	/// requesting from.
+	DifferentIp { old: IpAddr },
+}
+
+async fn handle_auth(
+	client_addr: IpAddr,
+	auth: &str,
+	db: Arc<RwLock<Database>>,
+) -> Result<AuthResponse, RuntimeError> {
+	let key = match AuthKey::from_str(auth) {
+		Ok(ak) => ak,
+		Err(_) => {
+			return Err(RuntimeError::AuthInvalid);
+		}
+	};
+
+	let lock = db.read().await;
+	match lock.get_by_key(key.clone()) {
+		None => {
+			drop(lock);
+			let mut write = db.write().await;
+			write.update_ip(key, client_addr);
+
+			return Ok(AuthResponse::NewAuth);
+		}
+		Some(addr) => {
+			if addr == client_addr {
+				return Ok(AuthResponse::MatchingIp);
+			} else {
+				drop(lock);
+				let mut write = db.write().await;
+				write.update_ip(key, client_addr);
+
+				return Ok(AuthResponse::DifferentIp { old: addr });
+			}
+		}
+	}
+}
+
 enum ContentType {
 	Plain,
 	Json,
@@ -170,3 +318,160 @@ impl fmt::Display for ContentType {
 		}
 	}
 }
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub struct AuthKey<'a>(Cow<'a, str>);
+const BASE64: &'static str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890+/=";
+
+impl<'a> AuthKey<'a> {
+	fn from_str_owned(s: &str) -> Result<Self, ()> {
+		Self::check_str(s)?;
+		Ok(AuthKey(Cow::Owned(s.to_string())))
+	}
+
+	fn from_str(s: &'a str) -> Result<Self, ()> {
+		Self::check_str(s)?;
+		Ok(AuthKey(Cow::Borrowed(s)))
+	}
+
+	fn check_str(s: &str) -> Result<(), ()> {
+		for ch in s.chars() {
+			if !BASE64.contains(ch) {
+				return Err(());
+			}
+		}
+
+		Ok(())
+	}
+
+	fn as_static_owned(self) -> AuthKey<'static> {
+		let string = self.0.to_string();
+		AuthKey(Cow::Owned(string))
+	}
+}
+
+impl<'a> fmt::Display for AuthKey<'a> {
+	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+		write!(f, "{}", self.0)
+	}
+}
+
+pub struct Database {
+	path: PathBuf,
+	map: HashMap<AuthKey<'static>, IpAddr>,
+}
+
+impl Database {
+	pub fn new(path: PathBuf) -> Self {
+		Self {
+			path,
+			map: HashMap::new(),
+		}
+	}
+
+	fn load(&mut self) -> Result<(), RuntimeError> {
+		let file =
+			std::fs::File::open(&self.path).map_err(|io| RuntimeError::io(io, &self.path))?;
+
+		let bufr = BufReader::new(file);
+		for (ln, line) in bufr.lines().enumerate() {
+			let line = line.map_err(|io| RuntimeError::io(io, &self.path))?;
+
+			match line.split_once(' ') {
+				None => {
+					return Err(RuntimeError::malformed_database(
+						&self.path,
+						ln,
+						DatabaseMalformation::Line,
+					));
+				}
+				Some((raw_key, raw_ip)) => {
+					let auth = AuthKey::from_str_owned(raw_key).map_err(|_| {
+						RuntimeError::malformed_database(
+							&self.path,
+							ln,
+							DatabaseMalformation::AuthKey,
+						)
+					})?;
+
+					let ip = IpAddr::from_str(raw_ip).map_err(|_| {
+						RuntimeError::malformed_database(
+							&self.path,
+							ln,
+							DatabaseMalformation::IpAddr,
+						)
+					})?;
+
+					self.map.insert(auth, ip);
+				}
+			}
+		}
+
+		Ok(())
+	}
+
+	fn save(&self) -> Result<(), RuntimeError> {
+		let file =
+			std::fs::File::create(&self.path).map_err(|io| RuntimeError::io(io, &self.path))?;
+
+		let mut bufw = BufWriter::new(file);
+		for (key, val) in self.map.iter() {
+			let line = format!("{key} {val}");
+
+			bufw.write_all(line.as_bytes())
+				.map_err(|io| RuntimeError::io(io, &self.path))?;
+		}
+
+		Ok(())
+	}
+
+	fn get_by_key(&self, key: AuthKey<'_>) -> Option<IpAddr> {
+		self.map.get(&key).cloned()
+	}
+
+	fn update_ip(&mut self, key: AuthKey<'_>, ip: IpAddr) {
+		self.map.insert(key.as_static_owned(), ip);
+	}
+}
+
+#[derive(Clone, Copy, Debug)]
+enum DatabaseMalformation {
+	AuthKey,
+	IpAddr,
+	Line,
+}
+
+#[derive(Debug)]
+enum RuntimeError {
+	IoError {
+		source: io::Error,
+		path: PathBuf,
+	},
+	MalformedDatabase {
+		path: PathBuf,
+		ln: usize,
+		kind: DatabaseMalformation,
+	},
+	AuthInvalid,
+}
+
+impl RuntimeError {
+	pub fn io<P: Into<PathBuf>>(source: io::Error, path: P) -> Self {
+		RuntimeError::IoError {
+			source,
+			path: path.into(),
+		}
+	}
+
+	pub fn malformed_database<P: Into<PathBuf>>(
+		path: P,
+		ln: usize,
+		kind: DatabaseMalformation,
+	) -> Self {
+		RuntimeError::MalformedDatabase {
+			path: path.into(),
+			ln,
+			kind,
+		}
+	}
+}