about summary refs log tree commit diff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs101
1 files changed, 77 insertions, 24 deletions
diff --git a/src/main.rs b/src/main.rs
index d288cba..e352fe5 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,5 @@
-use std::{net::SocketAddr, sync::Arc};
+use core::fmt;
+use std::{borrow::Cow, fmt::format, net::SocketAddr, sync::Arc};
 
 use httparse::{Request, Status};
 use smol::{
@@ -69,32 +70,11 @@ async fn handle_request(mut stream: TcpStream, caddr: SocketAddr) {
 			}
 		}
 
-		let minor = req.version.unwrap();
-		let status = 200;
-		let message = "Gotcha";
-
 		// Fall through the match. We should guarantee that, if we end up here,
 		// we got a Status::Complete from the parser
-		let mut response = format!("HTTP/1.{minor} {status} {message}\n");
-
-		let body = format!("{}", caddr.ip());
-		let length = body.len();
-
-		macro_rules! header {
-			($content:expr) => {
-				response.push_str(&format!($content));
-				response.push('\n');
-			};
-		}
-
-		header!("Content-Length: {length}");
-		header!("Content-Type: text/plain");
-		header!("Cache-Control: no-store");
-
-		response.push('\n');
-		response.push_str(&body);
+		let response = form_response(req, caddr).await;
 
-		if let Err(err) = stream.write_all(response.as_bytes()).await {
+		if let Err(err) = stream.write_all(&response).await {
 			eprintln!("Error in handle_request sending response back. Bailing.\nError: {err}");
 			return;
 		} else {
@@ -102,3 +82,76 @@ async fn handle_request(mut stream: TcpStream, caddr: SocketAddr) {
 		}
 	}
 }
+
+fn error(code: u16, msg: &'static str) -> Vec<u8> {
+	format!(
+		"HTTP/1.1 {code} {msg}\n\
+		server: piper/0.1\n\
+		content-length: 0\n\
+		cache-control: no-store\n\n"
+	)
+	.as_bytes()
+	.to_vec()
+}
+
+fn get_header<'h>(
+	headers: &'h [httparse::Header<'h>],
+	key: &'static str,
+) -> Result<Option<&'h str>, Vec<u8>> {
+	headers
+		.iter()
+		.find(|h| h.name.to_lowercase() == key)
+		.map(|v| str::from_utf8(v.value).map_err(|_| error(400, "malformed header")))
+		.transpose()
+}
+
+async fn form_response<'req>(request: Request<'req, 'req>, caddr: SocketAddr) -> Vec<u8> {
+	let minor = request.version.unwrap();
+	let status = 200;
+	let message = "Gotcha";
+
+	let mut response = format!("HTTP/1.{minor} {status} {message}\n");
+
+	let requested_type = match get_header(request.headers, "content-type") {
+		Err(e) => return e,
+		Ok(None) | Ok(Some("text/plain")) => ContentType::Plain,
+		Ok(Some("application/json")) => ContentType::Json,
+		Ok(Some(_)) => return error(400, "unknown content-type"),
+	};
+
+	let body = match requested_type {
+		ContentType::Plain => format!("{}", caddr.ip()),
+		ContentType::Json => format!(r#"{{"ip": "{}"}}"#, caddr.ip()),
+	};
+	let length = body.len();
+
+	macro_rules! header {
+		($content:expr) => {
+			response.push_str(&format!($content));
+			response.push('\n');
+		};
+	}
+
+	header!("Content-Length: {length}");
+	header!("Content-Type: {requested_type}");
+	header!("Cache-Control: no-store");
+
+	response.push('\n');
+	response.push_str(&body);
+
+	response.into_bytes()
+}
+
+enum ContentType {
+	Plain,
+	Json,
+}
+
+impl fmt::Display for ContentType {
+	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+		match self {
+			ContentType::Plain => write!(f, "text/plain"),
+			ContentType::Json => write!(f, "application/json"),
+		}
+	}
+}