about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/count.rs67
-rw-r--r--src/main.rs138
2 files changed, 129 insertions, 76 deletions
diff --git a/src/count.rs b/src/count.rs
new file mode 100644
index 0000000..c3d92bf
--- /dev/null
+++ b/src/count.rs
@@ -0,0 +1,67 @@
+const U16_MAX: usize = u16::MAX as usize;
+
+pub struct Count {
+	tcp_tx: Box<[usize]>,
+	tcp_rx: Box<[usize]>,
+	udp_tx: Box<[usize]>,
+	udp_rx: Box<[usize]>,
+}
+
+impl Count {
+	pub fn new() -> Self {
+		Self {
+			tcp_tx: Box::new([0; U16_MAX]),
+			tcp_rx: Box::new([0; U16_MAX]),
+			udp_tx: Box::new([0; U16_MAX]),
+			udp_rx: Box::new([0; U16_MAX]),
+		}
+	}
+
+	pub fn push_tcp(&mut self, tx_flag: bool, src: u16, dst: u16, len: usize) {
+		// The unsafe code here is bad and not good. As far as I am aware, this
+		// is safe except for allowing race conditions.
+		if tx_flag {
+			let bad_mut = unsafe { &mut *(self.tcp_tx.as_ptr() as *mut [usize; U16_MAX]) };
+			bad_mut[src as usize] += len;
+		} else {
+			let bad_mut = unsafe { &mut *(self.tcp_rx.as_ptr() as *mut [usize; U16_MAX]) };
+			bad_mut[dst as usize] += len;
+		}
+	}
+
+	pub fn push_udp(&mut self, tx_flag: bool, src: u16, dst: u16, len: usize) {
+		// The unsafe code here is bad and not good. As far as I am aware, this
+		// is safe except for allowing race conditions.
+		if tx_flag {
+			let bad_mut = unsafe { &mut *(self.udp_tx.as_ptr() as *mut [usize; U16_MAX]) };
+			bad_mut[src as usize] += len;
+		} else {
+			let bad_mut = unsafe { &mut *(self.udp_rx.as_ptr() as *mut [usize; U16_MAX]) };
+			bad_mut[dst as usize] += len;
+		}
+	}
+
+	pub fn tcp_tx(&self, port: u16) -> usize {
+		self.tcp_tx[port as usize]
+	}
+
+	pub fn tcp_rx(&self, port: u16) -> usize {
+		self.tcp_rx[port as usize]
+	}
+
+	pub fn udp_tx(&self, port: u16) -> usize {
+		self.udp_tx[port as usize]
+	}
+
+	pub fn udp_rx(&self, port: u16) -> usize {
+		self.udp_rx[port as usize]
+	}
+
+	pub fn many_tcp_tx(&self, ports: &[u16]) -> usize {
+		ports.iter().fold(0, |acc, port| acc + self.tcp_tx[*port as usize])
+	}
+
+	pub fn many_tcp_rx(&self, ports: &[u16]) -> usize {
+		ports.iter().fold(0, |acc, port| acc + self.tcp_rx[*port as usize])
+	}
+}
diff --git a/src/main.rs b/src/main.rs
index defd70b..3b0f63c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,17 +1,19 @@
-use std::{
-	error::Error,
-	net::{Ipv4Addr, Shutdown},
-	sync::mpsc::{self, Receiver, Sender, channel},
-	thread::JoinHandle,
-	time::{Duration, Instant},
-};
+use std::{cell::UnsafeCell, error::Error, net::Ipv4Addr, rc::Rc, thread::JoinHandle};
 
+use count::Count;
 use ethertype::EtherType;
 use ippacket::{IpNextHeader, Ipv4Packet};
 use layer3::{Tcp, Udp};
 use pnet_datalink::{Channel, Config};
 use scurvy::{Argument, Scurvy};
+use smol::{
+	LocalExecutor,
+	channel::{self, Receiver, Sender},
+	io::AsyncWriteExt,
+	net::TcpListener,
+};
 
+mod count;
 mod ethertype;
 mod ippacket;
 mod layer3;
@@ -41,12 +43,6 @@ fn main() -> Result<(), Box<dyn Error>> {
 
 	let stat = stat_thread();
 
-	let tx_clone = stat.tx.clone();
-	std::thread::spawn(move || {
-		std::thread::sleep(Duration::from_secs(30));
-		tx_clone.send(Meow::Show).unwrap();
-	});
-
 	loop {
 		let pkt = channel.next()?;
 
@@ -58,6 +54,14 @@ fn main() -> Result<(), Box<dyn Error>> {
 				let ip = Ipv4Packet::new(eth_payload);
 				let ip_payload = ip.get_payload();
 
+				let ip_tx = if ip.src == ip_want {
+					true
+				} else if ip.dst == ip_want {
+					false
+				} else {
+					continue;
+				};
+
 				// 6 byte per MAC (x2), 2 byte ethertype, 2 byte crc
 				let total_l2_len = ip.get_packet_len() + 18;
 
@@ -65,16 +69,8 @@ fn main() -> Result<(), Box<dyn Error>> {
 					IpNextHeader::Tcp => {
 						let tcp = Tcp::new(ip_payload);
 
-						let tcp_tx = if ip.src == ip_want {
-							true
-						} else if ip.dst == ip_want {
-							false
-						} else {
-							continue;
-						};
-
-						stat.tx.send(Meow::Tcp {
-							tx: tcp_tx,
+						stat.tx.send_blocking(Meow::Tcp {
+							tx: ip_tx,
 							src: tcp.source_port(),
 							dst: tcp.destination_port(),
 							len: total_l2_len,
@@ -83,16 +79,8 @@ fn main() -> Result<(), Box<dyn Error>> {
 					IpNextHeader::Udp => {
 						let udp = Udp::new(ip_payload);
 
-						let udp_tx = if ip.src == ip_want {
-							true
-						} else if ip.dst == ip_want {
-							false
-						} else {
-							continue;
-						};
-
-						stat.tx.send(Meow::Udp {
-							tx: udp_tx,
+						stat.tx.send_blocking(Meow::Udp {
+							tx: ip_tx,
 							src: udp.source_port(),
 							dst: udp.destination_port(),
 							len: total_l2_len,
@@ -106,6 +94,7 @@ fn main() -> Result<(), Box<dyn Error>> {
 	}
 }
 
+#[allow(dead_code)]
 struct StatHandle {
 	hwnd: JoinHandle<()>,
 	tx: Sender<Meow>,
@@ -124,12 +113,10 @@ enum Meow {
 		dst: u16,
 		len: usize,
 	},
-	Show,
-	Shutdown,
 }
 
 fn stat_thread() -> StatHandle {
-	let (tx, rx) = channel();
+	let (tx, rx) = channel::unbounded();
 
 	let hwnd = std::thread::spawn(|| stat(rx));
 
@@ -137,50 +124,49 @@ fn stat_thread() -> StatHandle {
 }
 
 fn stat(rx: Receiver<Meow>) {
-	let mut tcp_tx = vec![0; u16::MAX as usize];
-	let mut tcp_rx = vec![0; u16::MAX as usize];
-
-	let mut last_print = Instant::now();
-	loop {
-		if last_print.elapsed() > Duration::from_secs(10) {
-			let http_s = tcp_rx[80] + tcp_rx[443];
-			let http_s_kb = http_s / 1000;
-
-			let http_s_tx = tcp_tx[80] + tcp_tx[443];
-			let http_s_tx_kb = http_s_tx / 1000;
+	let cat = Rc::new(UnsafeCell::new(Count::new()));
 
-			println!("HTTP(S) rx {}kB // tx {}kB", http_s_kb, http_s_tx_kb);
+	let ex = Rc::new(LocalExecutor::new());
 
-			last_print = Instant::now();
-		}
-
-		match rx.recv() {
-			Err(_e) => {
-				eprintln!("error receiving! breaking from loop");
-				break;
-			}
-			Ok(Meow::Show) => {
-				if last_print.elapsed() > Duration::from_secs(30) {
-					println!("Activated by Meow::Show");
+	ex.spawn(async {
+		loop {
+			match rx.recv().await {
+				Err(_e) => {
+					eprintln!("error receiving! breaking from loop");
+					break;
 				}
-			}
-			Ok(Meow::Shutdown) => {
-				eprintln!("got shutdown! breaking from loop");
-				break;
-			}
-			Ok(Meow::Tcp {
-				tx: tcp_tx_flag,
-				src,
-				dst,
-				len,
-			}) => {
-				if tcp_tx_flag {
-					tcp_tx[src as usize] += len;
-				} else {
-					tcp_rx[dst as usize] += len;
+				Ok(Meow::Tcp { tx, src, dst, len }) => {
+					unsafe { &mut *cat.get() }.push_tcp(tx, src, dst, len);
+				}
+				Ok(Meow::Udp { tx, src, dst, len }) => {
+					unsafe { &mut *cat.get() }.push_udp(tx, src, dst, len);
 				}
 			}
-			_ => (),
 		}
-	}
+	})
+	.detach();
+
+	let run_cat = Rc::clone(&cat);
+	let run_ex = Rc::clone(&ex);
+	let run = ex.run(async move {
+		let socket = TcpListener::bind("127.0.0.1:6666").await.unwrap();
+
+		loop {
+			let (mut stream, _caddr) = socket.accept().await.unwrap();
+			let per_spawn_cat = Rc::clone(&run_cat);
+
+			run_ex
+				.spawn(async move {
+					let cat = unsafe { &mut *per_spawn_cat.get() };
+					let http_tx = cat.many_tcp_tx(&[80, 443]);
+					let http_rx = cat.many_tcp_rx(&[80, 443]);
+
+					let str = format!("{http_tx}\n{http_rx}");
+					stream.write_all(str.as_bytes()).await.unwrap();
+				})
+				.detach();
+		}
+	});
+
+	smol::block_on(run);
 }