about summary refs log tree commit diff
path: root/src/libstd/sys/windows
diff options
context:
space:
mode:
authorSteven Fackler <sfackler@gmail.com>2017-07-04 23:46:24 -0700
committerSteven Fackler <sfackler@gmail.com>2017-07-06 19:35:49 -0700
commit8c92da3c518111b28f0b5fb297ab719bf353cdc6 (patch)
treee68b05e25cf9a0a16bd9d6e72d161f92b843beac /src/libstd/sys/windows
parentccf401f8f7847413ac43f38e3c86016becf67b07 (diff)
downloadrust-8c92da3c518111b28f0b5fb297ab719bf353cdc6.tar.gz
rust-8c92da3c518111b28f0b5fb297ab719bf353cdc6.zip
Implement TcpStream::connect_timeout
This breaks the "single syscall rule", but it's really annoying to hand
write and is pretty foundational.
Diffstat (limited to 'src/libstd/sys/windows')
-rw-r--r--src/libstd/sys/windows/c.rs27
-rw-r--r--src/libstd/sys/windows/net.rs56
2 files changed, 82 insertions, 1 deletions
diff --git a/src/libstd/sys/windows/c.rs b/src/libstd/sys/windows/c.rs
index 1646f8cce72..4785cefd6b4 100644
--- a/src/libstd/sys/windows/c.rs
+++ b/src/libstd/sys/windows/c.rs
@@ -298,6 +298,8 @@ pub const PIPE_TYPE_BYTE: DWORD = 0x00000000;
 pub const PIPE_REJECT_REMOTE_CLIENTS: DWORD = 0x00000008;
 pub const PIPE_READMODE_BYTE: DWORD = 0x00000000;
 
+pub const FD_SETSIZE: usize = 64;
+
 #[repr(C)]
 #[cfg(target_arch = "x86")]
 pub struct WSADATA {
@@ -837,6 +839,26 @@ pub struct CONSOLE_READCONSOLE_CONTROL {
 }
 pub type PCONSOLE_READCONSOLE_CONTROL = *mut CONSOLE_READCONSOLE_CONTROL;
 
+#[repr(C)]
+#[derive(Copy)]
+pub struct fd_set {
+    pub fd_count: c_uint,
+    pub fd_array: [SOCKET; FD_SETSIZE],
+}
+
+impl Clone for fd_set {
+    fn clone(&self) -> fd_set {
+        *self
+    }
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct timeval {
+    pub tv_sec: c_long,
+    pub tv_usec: c_long,
+}
+
 extern "system" {
     pub fn WSAStartup(wVersionRequested: WORD,
                       lpWSAData: LPWSADATA) -> c_int;
@@ -1125,6 +1147,11 @@ extern "system" {
                                lpOverlapped: LPOVERLAPPED,
                                lpNumberOfBytesTransferred: LPDWORD,
                                bWait: BOOL) -> BOOL;
+    pub fn select(nfds: c_int,
+                  readfds: *mut fd_set,
+                  writefds: *mut fd_set,
+                  exceptfds: *mut fd_set,
+                  timeout: *const timeval) -> c_int;
 }
 
 // Functions that aren't available on Windows XP, but we still use them and just
diff --git a/src/libstd/sys/windows/net.rs b/src/libstd/sys/windows/net.rs
index f2a2793425d..cd8acff6b0c 100644
--- a/src/libstd/sys/windows/net.rs
+++ b/src/libstd/sys/windows/net.rs
@@ -12,7 +12,7 @@
 
 use cmp;
 use io::{self, Read};
-use libc::{c_int, c_void, c_ulong};
+use libc::{c_int, c_void, c_ulong, c_long};
 use mem;
 use net::{SocketAddr, Shutdown};
 use ptr;
@@ -115,6 +115,60 @@ impl Socket {
         Ok(socket)
     }
 
+    pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
+        self.set_nonblocking(true)?;
+        let r = unsafe {
+            let (addrp, len) = addr.into_inner();
+            cvt(c::connect(self.0, addrp, len))
+        };
+        self.set_nonblocking(false)?;
+
+        match r {
+            Ok(_) => return Ok(()),
+            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
+            Err(e) => return Err(e),
+        }
+
+        if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
+            return Err(io::Error::new(io::ErrorKind::InvalidInput,
+                                      "cannot set a 0 duration timeout"));
+        }
+
+        let mut timeout = c::timeval {
+            tv_sec: timeout.as_secs() as c_long,
+            tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
+        };
+        if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
+            timeout.tv_usec = 1;
+        }
+
+        let fds = unsafe {
+            let mut fds = mem::zeroed::<c::fd_set>();
+            fds.fd_count = 1;
+            fds.fd_array[0] = self.0;
+            fds
+        };
+
+        let mut writefds = fds;
+        let mut errorfds = fds;
+
+        let n = unsafe {
+            cvt(c::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout))?
+        };
+
+        match n {
+            0 => Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
+            _ => {
+                if writefds.fd_count != 1 {
+                    if let Some(e) = self.take_error()? {
+                        return Err(e);
+                    }
+                }
+                Ok(())
+            }
+        }
+    }
+
     pub fn accept(&self, storage: *mut c::SOCKADDR,
                   len: *mut c_int) -> io::Result<Socket> {
         let socket = unsafe {