about summary refs log tree commit diff
path: root/src/libstd/sys
diff options
context:
space:
mode:
authorTyler Julian <tjulian@uber.com>2017-01-10 19:11:56 -0800
committerTyler Julian <tjulian@uber.com>2017-02-04 12:00:19 -0800
commita40be0857c7bf48e39f815417b0b5293cd8ed1aa (patch)
treef142739d4a6f706f583380436bcb28e4ca6f8ffd /src/libstd/sys
parent9749df52b7ecfc8123e392b9d49786e2abf20320 (diff)
downloadrust-a40be0857c7bf48e39f815417b0b5293cd8ed1aa.tar.gz
rust-a40be0857c7bf48e39f815417b0b5293cd8ed1aa.zip
libstd/net: Add `peek` APIs to UdpSocket and TcpStream
These methods enable socket reads without side-effects. That is,
repeated calls to peek() return identical data. This is accomplished
by providing the POSIX flag MSG_PEEK to the underlying socket read
operations.

This also moves the current implementation of recv_from out of the
platform-independent sys_common and into respective sys/windows and
sys/unix implementations. This allows for more platform-dependent
implementations.
Diffstat (limited to 'src/libstd/sys')
-rw-r--r--src/libstd/sys/unix/net.rs45
-rw-r--r--src/libstd/sys/windows/c.rs1
-rw-r--r--src/libstd/sys/windows/net.rs44
3 files changed, 85 insertions, 5 deletions
diff --git a/src/libstd/sys/unix/net.rs b/src/libstd/sys/unix/net.rs
index ad287bbec38..5efddca110f 100644
--- a/src/libstd/sys/unix/net.rs
+++ b/src/libstd/sys/unix/net.rs
@@ -10,12 +10,13 @@
 
 use ffi::CStr;
 use io;
-use libc::{self, c_int, size_t, sockaddr, socklen_t, EAI_SYSTEM};
+use libc::{self, c_int, c_void, size_t, sockaddr, socklen_t, EAI_SYSTEM, MSG_PEEK};
+use mem;
 use net::{SocketAddr, Shutdown};
 use str;
 use sys::fd::FileDesc;
 use sys_common::{AsInner, FromInner, IntoInner};
-use sys_common::net::{getsockopt, setsockopt};
+use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr};
 use time::Duration;
 
 pub use sys::{cvt, cvt_r};
@@ -155,8 +156,46 @@ impl Socket {
         self.0.duplicate().map(Socket)
     }
 
+    fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
+        let ret = cvt(unsafe {
+            libc::recv(self.0.raw(),
+                       buf.as_mut_ptr() as *mut c_void,
+                       buf.len(),
+                       flags)
+        })?;
+        Ok(ret as usize)
+    }
+
     pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
-        self.0.read(buf)
+        self.recv_with_flags(buf, 0)
+    }
+
+    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
+        self.recv_with_flags(buf, MSG_PEEK)
+    }
+
+    fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
+                            -> io::Result<(usize, SocketAddr)> {
+        let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
+        let mut addrlen = mem::size_of_val(&storage) as libc::socklen_t;
+
+        let n = cvt(unsafe {
+            libc::recvfrom(self.0.raw(),
+                        buf.as_mut_ptr() as *mut c_void,
+                        buf.len(),
+                        flags,
+                        &mut storage as *mut _ as *mut _,
+                        &mut addrlen)
+        })?;
+        Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
+    }
+
+    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+        self.recv_from_with_flags(buf, 0)
+    }
+
+    pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+        self.recv_from_with_flags(buf, MSG_PEEK)
     }
 
     pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
diff --git a/src/libstd/sys/windows/c.rs b/src/libstd/sys/windows/c.rs
index dc7b2fc9a6b..9f03f5c9717 100644
--- a/src/libstd/sys/windows/c.rs
+++ b/src/libstd/sys/windows/c.rs
@@ -244,6 +244,7 @@ pub const IP_ADD_MEMBERSHIP: c_int = 12;
 pub const IP_DROP_MEMBERSHIP: c_int = 13;
 pub const IPV6_ADD_MEMBERSHIP: c_int = 12;
 pub const IPV6_DROP_MEMBERSHIP: c_int = 13;
+pub const MSG_PEEK: c_int = 0x2;
 
 #[repr(C)]
 pub struct ip_mreq {
diff --git a/src/libstd/sys/windows/net.rs b/src/libstd/sys/windows/net.rs
index aca6994503f..adf6210d82e 100644
--- a/src/libstd/sys/windows/net.rs
+++ b/src/libstd/sys/windows/net.rs
@@ -147,12 +147,12 @@ impl Socket {
         Ok(socket)
     }
 
-    pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
+    fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
         // On unix when a socket is shut down all further reads return 0, so we
         // do the same on windows to map a shut down socket to returning EOF.
         let len = cmp::min(buf.len(), i32::max_value() as usize) as i32;
         unsafe {
-            match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, 0) {
+            match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, flags) {
                 -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0),
                 -1 => Err(last_error()),
                 n => Ok(n as usize)
@@ -160,6 +160,46 @@ impl Socket {
         }
     }
 
+    pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
+        self.recv_with_flags(buf, 0)
+    }
+
+    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
+        self.recv_with_flags(buf, c::MSG_PEEK)
+    }
+
+    fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
+                            -> io::Result<(usize, SocketAddr)> {
+        let mut storage: c::SOCKADDR_STORAGE_LH = unsafe { mem::zeroed() };
+        let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
+        let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
+
+        // On unix when a socket is shut down all further reads return 0, so we
+        // do the same on windows to map a shut down socket to returning EOF.
+        unsafe {
+            match c::recvfrom(self.0,
+                              buf.as_mut_ptr() as *mut c_void,
+                              len,
+                              flags,
+                              &mut storage as *mut _ as *mut _,
+                              &mut addrlen) {
+                -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => {
+                    Ok((0, net::sockaddr_to_addr(&storage, addrlen as usize)?))
+                },
+                -1 => Err(last_error()),
+                n => Ok((n as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)),
+            }
+        }
+    }
+
+    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+        self.recv_from_with_flags(buf, 0)
+    }
+
+    pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+        self.recv_from_with_flags(buf, c::MSG_PEEK)
+    }
+
     pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
         let mut me = self;
         (&mut me).read_to_end(buf)