about summary refs log tree commit diff
path: root/library/std/src/sys/net/connection/socket/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'library/std/src/sys/net/connection/socket/mod.rs')
-rw-r--r--library/std/src/sys/net/connection/socket/mod.rs199
1 files changed, 132 insertions, 67 deletions
diff --git a/library/std/src/sys/net/connection/socket/mod.rs b/library/std/src/sys/net/connection/socket/mod.rs
index 1dd06e97bba..d0a4a2fab49 100644
--- a/library/std/src/sys/net/connection/socket/mod.rs
+++ b/library/std/src/sys/net/connection/socket/mod.rs
@@ -3,6 +3,7 @@ mod tests;
 
 use crate::ffi::{c_int, c_void};
 use crate::io::{self, BorrowedCursor, ErrorKind, IoSlice, IoSliceMut};
+use crate::mem::MaybeUninit;
 use crate::net::{
     Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs,
 };
@@ -177,6 +178,18 @@ fn socket_addr_to_c(addr: &SocketAddr) -> (SocketAddrCRepr, c::socklen_t) {
     }
 }
 
+fn addr_family(addr: &SocketAddr) -> c_int {
+    match addr {
+        SocketAddr::V4(..) => c::AF_INET,
+        SocketAddr::V6(..) => c::AF_INET6,
+    }
+}
+
+/// Converts the C socket address stored in `storage` to a Rust `SocketAddr`.
+///
+/// # Safety
+/// * `storage` must contain a valid C socket address whose length is no larger
+///   than `len`.
 unsafe fn socket_addr_from_c(
     storage: *const c::sockaddr_storage,
     len: usize,
@@ -202,49 +215,85 @@ unsafe fn socket_addr_from_c(
 // sockaddr and misc bindings
 ////////////////////////////////////////////////////////////////////////////////
 
-pub fn setsockopt<T>(
+/// Sets the value of a socket option.
+///
+/// # Safety
+/// `T` must be the type associated with the given socket option.
+pub unsafe fn setsockopt<T>(
     sock: &Socket,
     level: c_int,
     option_name: c_int,
     option_value: T,
 ) -> io::Result<()> {
-    unsafe {
-        cvt(c::setsockopt(
+    let option_len = size_of::<T>() as c::socklen_t;
+    // SAFETY:
+    // * `sock` is opened for the duration of this call, as `sock` owns the socket.
+    // * the pointer to `option_value` is readable at a size of `size_of::<T>`
+    //   bytes
+    // * the value of `option_value` has a valid type for the given socket option
+    //   (guaranteed by caller).
+    cvt(unsafe {
+        c::setsockopt(
             sock.as_raw(),
             level,
             option_name,
             (&raw const option_value) as *const _,
-            size_of::<T>() as c::socklen_t,
-        ))?;
-        Ok(())
-    }
+            option_len,
+        )
+    })?;
+    Ok(())
 }
 
-pub fn getsockopt<T: Copy>(sock: &Socket, level: c_int, option_name: c_int) -> io::Result<T> {
-    unsafe {
-        let mut option_value: T = mem::zeroed();
-        let mut option_len = size_of::<T>() as c::socklen_t;
-        cvt(c::getsockopt(
+/// Gets the value of a socket option.
+///
+/// # Safety
+/// `T` must be the type associated with the given socket option.
+pub unsafe fn getsockopt<T: Copy>(
+    sock: &Socket,
+    level: c_int,
+    option_name: c_int,
+) -> io::Result<T> {
+    let mut option_value = MaybeUninit::<T>::zeroed();
+    let mut option_len = size_of::<T>() as c::socklen_t;
+
+    // SAFETY:
+    // * `sock` is opened for the duration of this call, as `sock` owns the socket.
+    // * the pointer to `option_value` is writable and the stack allocation has
+    //   space for `size_of::<T>` bytes.
+    cvt(unsafe {
+        c::getsockopt(
             sock.as_raw(),
             level,
             option_name,
-            (&raw mut option_value) as *mut _,
+            option_value.as_mut_ptr().cast(),
             &mut option_len,
-        ))?;
-        Ok(option_value)
-    }
-}
-
-fn sockname<F>(f: F) -> io::Result<SocketAddr>
+        )
+    })?;
+
+    // SAFETY: the `getsockopt` call succeeded and the caller guarantees that
+    //         `T` is the type of this option, thus `option_value` must have
+    //         been initialized by the system.
+    Ok(unsafe { option_value.assume_init() })
+}
+
+/// Wraps a call to a platform function that returns a socket address.
+///
+/// # Safety
+/// * if `f` returns a success (i.e. `cvt` returns `Ok` when called on the
+///   return value), the buffer provided to `f` must have been initialized
+///   with a valid C socket address, the length of which must be written
+///   to the second argument.
+unsafe fn sockname<F>(f: F) -> io::Result<SocketAddr>
 where
     F: FnOnce(*mut c::sockaddr, *mut c::socklen_t) -> c_int,
 {
-    unsafe {
-        let mut storage: c::sockaddr_storage = mem::zeroed();
-        let mut len = size_of_val(&storage) as c::socklen_t;
-        cvt(f((&raw mut storage) as *mut _, &mut len))?;
-        socket_addr_from_c(&storage, len as usize)
-    }
+    let mut storage = MaybeUninit::<c::sockaddr_storage>::zeroed();
+    let mut len = size_of::<c::sockaddr_storage>() as c::socklen_t;
+    cvt(f(storage.as_mut_ptr().cast(), &mut len))?;
+    // SAFETY:
+    // The caller guarantees that the storage has been successfully initialized
+    // and its size written to `len` if `f` returns a success.
+    unsafe { socket_addr_from_c(storage.as_ptr(), len as usize) }
 }
 
 #[cfg(target_os = "android")]
@@ -322,7 +371,7 @@ impl TcpStream {
         return each_addr(addr, inner);
 
         fn inner(addr: &SocketAddr) -> io::Result<TcpStream> {
-            let sock = Socket::new(addr, c::SOCK_STREAM)?;
+            let sock = Socket::new(addr_family(addr), c::SOCK_STREAM)?;
             sock.connect(addr)?;
             Ok(TcpStream { inner: sock })
         }
@@ -331,7 +380,7 @@ impl TcpStream {
     pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
         init();
 
-        let sock = Socket::new(addr, c::SOCK_STREAM)?;
+        let sock = Socket::new(addr_family(addr), c::SOCK_STREAM)?;
         sock.connect_timeout(addr, timeout)?;
         Ok(TcpStream { inner: sock })
     }
@@ -400,11 +449,11 @@ impl TcpStream {
     }
 
     pub fn peer_addr(&self) -> io::Result<SocketAddr> {
-        sockname(|buf, len| unsafe { c::getpeername(self.inner.as_raw(), buf, len) })
+        unsafe { sockname(|buf, len| c::getpeername(self.inner.as_raw(), buf, len)) }
     }
 
     pub fn socket_addr(&self) -> io::Result<SocketAddr> {
-        sockname(|buf, len| unsafe { c::getsockname(self.inner.as_raw(), buf, len) })
+        unsafe { sockname(|buf, len| c::getsockname(self.inner.as_raw(), buf, len)) }
     }
 
     pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
@@ -432,11 +481,11 @@ impl TcpStream {
     }
 
     pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
-        setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int)
+        unsafe { setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int) }
     }
 
     pub fn ttl(&self) -> io::Result<u32> {
-        let raw: c_int = getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)?;
+        let raw: c_int = unsafe { getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)? };
         Ok(raw as u32)
     }
 
@@ -493,7 +542,7 @@ impl TcpListener {
         return each_addr(addr, inner);
 
         fn inner(addr: &SocketAddr) -> io::Result<TcpListener> {
-            let sock = Socket::new(addr, c::SOCK_STREAM)?;
+            let sock = Socket::new(addr_family(addr), c::SOCK_STREAM)?;
 
             // On platforms with Berkeley-derived sockets, this allows to quickly
             // rebind a socket, without needing to wait for the OS to clean up the
@@ -503,7 +552,9 @@ impl TcpListener {
             // which allows “socket hijacking”, so we explicitly don't set it here.
             // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
             #[cfg(not(windows))]
-            setsockopt(&sock, c::SOL_SOCKET, c::SO_REUSEADDR, 1 as c_int)?;
+            unsafe {
+                setsockopt(&sock, c::SOL_SOCKET, c::SO_REUSEADDR, 1 as c_int)?
+            };
 
             // Bind our new socket
             let (addr, len) = socket_addr_to_c(addr);
@@ -539,15 +590,15 @@ impl TcpListener {
     }
 
     pub fn socket_addr(&self) -> io::Result<SocketAddr> {
-        sockname(|buf, len| unsafe { c::getsockname(self.inner.as_raw(), buf, len) })
+        unsafe { sockname(|buf, len| c::getsockname(self.inner.as_raw(), buf, len)) }
     }
 
     pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
         // The `accept` function will fill in the storage with the address,
         // so we don't need to zero it here.
         // reference: https://linux.die.net/man/2/accept4
-        let mut storage: mem::MaybeUninit<c::sockaddr_storage> = mem::MaybeUninit::uninit();
-        let mut len = size_of_val(&storage) as c::socklen_t;
+        let mut storage = MaybeUninit::<c::sockaddr_storage>::uninit();
+        let mut len = size_of::<c::sockaddr_storage>() as c::socklen_t;
         let sock = self.inner.accept(storage.as_mut_ptr() as *mut _, &mut len)?;
         let addr = unsafe { socket_addr_from_c(storage.as_ptr(), len as usize)? };
         Ok((TcpStream { inner: sock }, addr))
@@ -558,20 +609,20 @@ impl TcpListener {
     }
 
     pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
-        setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int)
+        unsafe { setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int) }
     }
 
     pub fn ttl(&self) -> io::Result<u32> {
-        let raw: c_int = getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)?;
+        let raw: c_int = unsafe { getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)? };
         Ok(raw as u32)
     }
 
     pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
-        setsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY, only_v6 as c_int)
+        unsafe { setsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY, only_v6 as c_int) }
     }
 
     pub fn only_v6(&self) -> io::Result<bool> {
-        let raw: c_int = getsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY)?;
+        let raw: c_int = unsafe { getsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_V6ONLY)? };
         Ok(raw != 0)
     }
 
@@ -617,7 +668,7 @@ impl UdpSocket {
         return each_addr(addr, inner);
 
         fn inner(addr: &SocketAddr) -> io::Result<UdpSocket> {
-            let sock = Socket::new(addr, c::SOCK_DGRAM)?;
+            let sock = Socket::new(addr_family(addr), c::SOCK_DGRAM)?;
             let (addr, len) = socket_addr_to_c(addr);
             cvt(unsafe { c::bind(sock.as_raw(), addr.as_ptr(), len as _) })?;
             Ok(UdpSocket { inner: sock })
@@ -634,11 +685,11 @@ impl UdpSocket {
     }
 
     pub fn peer_addr(&self) -> io::Result<SocketAddr> {
-        sockname(|buf, len| unsafe { c::getpeername(self.inner.as_raw(), buf, len) })
+        unsafe { sockname(|buf, len| c::getpeername(self.inner.as_raw(), buf, len)) }
     }
 
     pub fn socket_addr(&self) -> io::Result<SocketAddr> {
-        sockname(|buf, len| unsafe { c::getsockname(self.inner.as_raw(), buf, len) })
+        unsafe { sockname(|buf, len| c::getsockname(self.inner.as_raw(), buf, len)) }
     }
 
     pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
@@ -686,48 +737,62 @@ impl UdpSocket {
     }
 
     pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
-        setsockopt(&self.inner, c::SOL_SOCKET, c::SO_BROADCAST, broadcast as c_int)
+        unsafe { setsockopt(&self.inner, c::SOL_SOCKET, c::SO_BROADCAST, broadcast as c_int) }
     }
 
     pub fn broadcast(&self) -> io::Result<bool> {
-        let raw: c_int = getsockopt(&self.inner, c::SOL_SOCKET, c::SO_BROADCAST)?;
+        let raw: c_int = unsafe { getsockopt(&self.inner, c::SOL_SOCKET, c::SO_BROADCAST)? };
         Ok(raw != 0)
     }
 
     pub fn set_multicast_loop_v4(&self, multicast_loop_v4: bool) -> io::Result<()> {
-        setsockopt(
-            &self.inner,
-            c::IPPROTO_IP,
-            c::IP_MULTICAST_LOOP,
-            multicast_loop_v4 as IpV4MultiCastType,
-        )
+        unsafe {
+            setsockopt(
+                &self.inner,
+                c::IPPROTO_IP,
+                c::IP_MULTICAST_LOOP,
+                multicast_loop_v4 as IpV4MultiCastType,
+            )
+        }
     }
 
     pub fn multicast_loop_v4(&self) -> io::Result<bool> {
-        let raw: IpV4MultiCastType = getsockopt(&self.inner, c::IPPROTO_IP, c::IP_MULTICAST_LOOP)?;
+        let raw: IpV4MultiCastType =
+            unsafe { getsockopt(&self.inner, c::IPPROTO_IP, c::IP_MULTICAST_LOOP)? };
         Ok(raw != 0)
     }
 
     pub fn set_multicast_ttl_v4(&self, multicast_ttl_v4: u32) -> io::Result<()> {
-        setsockopt(
-            &self.inner,
-            c::IPPROTO_IP,
-            c::IP_MULTICAST_TTL,
-            multicast_ttl_v4 as IpV4MultiCastType,
-        )
+        unsafe {
+            setsockopt(
+                &self.inner,
+                c::IPPROTO_IP,
+                c::IP_MULTICAST_TTL,
+                multicast_ttl_v4 as IpV4MultiCastType,
+            )
+        }
     }
 
     pub fn multicast_ttl_v4(&self) -> io::Result<u32> {
-        let raw: IpV4MultiCastType = getsockopt(&self.inner, c::IPPROTO_IP, c::IP_MULTICAST_TTL)?;
+        let raw: IpV4MultiCastType =
+            unsafe { getsockopt(&self.inner, c::IPPROTO_IP, c::IP_MULTICAST_TTL)? };
         Ok(raw as u32)
     }
 
     pub fn set_multicast_loop_v6(&self, multicast_loop_v6: bool) -> io::Result<()> {
-        setsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_MULTICAST_LOOP, multicast_loop_v6 as c_int)
+        unsafe {
+            setsockopt(
+                &self.inner,
+                c::IPPROTO_IPV6,
+                c::IPV6_MULTICAST_LOOP,
+                multicast_loop_v6 as c_int,
+            )
+        }
     }
 
     pub fn multicast_loop_v6(&self) -> io::Result<bool> {
-        let raw: c_int = getsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_MULTICAST_LOOP)?;
+        let raw: c_int =
+            unsafe { getsockopt(&self.inner, c::IPPROTO_IPV6, c::IPV6_MULTICAST_LOOP)? };
         Ok(raw != 0)
     }
 
@@ -736,7 +801,7 @@ impl UdpSocket {
             imr_multiaddr: ip_v4_addr_to_c(multiaddr),
             imr_interface: ip_v4_addr_to_c(interface),
         };
-        setsockopt(&self.inner, c::IPPROTO_IP, c::IP_ADD_MEMBERSHIP, mreq)
+        unsafe { setsockopt(&self.inner, c::IPPROTO_IP, c::IP_ADD_MEMBERSHIP, mreq) }
     }
 
     pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
@@ -744,7 +809,7 @@ impl UdpSocket {
             ipv6mr_multiaddr: ip_v6_addr_to_c(multiaddr),
             ipv6mr_interface: to_ipv6mr_interface(interface),
         };
-        setsockopt(&self.inner, c::IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, mreq)
+        unsafe { setsockopt(&self.inner, c::IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP, mreq) }
     }
 
     pub fn leave_multicast_v4(&self, multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> io::Result<()> {
@@ -752,7 +817,7 @@ impl UdpSocket {
             imr_multiaddr: ip_v4_addr_to_c(multiaddr),
             imr_interface: ip_v4_addr_to_c(interface),
         };
-        setsockopt(&self.inner, c::IPPROTO_IP, c::IP_DROP_MEMBERSHIP, mreq)
+        unsafe { setsockopt(&self.inner, c::IPPROTO_IP, c::IP_DROP_MEMBERSHIP, mreq) }
     }
 
     pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
@@ -760,15 +825,15 @@ impl UdpSocket {
             ipv6mr_multiaddr: ip_v6_addr_to_c(multiaddr),
             ipv6mr_interface: to_ipv6mr_interface(interface),
         };
-        setsockopt(&self.inner, c::IPPROTO_IPV6, IPV6_DROP_MEMBERSHIP, mreq)
+        unsafe { setsockopt(&self.inner, c::IPPROTO_IPV6, IPV6_DROP_MEMBERSHIP, mreq) }
     }
 
     pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
-        setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int)
+        unsafe { setsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL, ttl as c_int) }
     }
 
     pub fn ttl(&self) -> io::Result<u32> {
-        let raw: c_int = getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)?;
+        let raw: c_int = unsafe { getsockopt(&self.inner, c::IPPROTO_IP, c::IP_TTL)? };
         Ok(raw as u32)
     }