about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2014-04-23 19:21:33 -0700
committerbors <bors@rust-lang.org>2014-04-23 19:21:33 -0700
commit3d05e7f9cdd76887de75f46b5e47d2685bec6520 (patch)
tree56e20faec42ce4ff66a2fdf0e7d92fa8ea1e55a1
parentd9103301726a4d91c622f8ae3d2d10ad225a0f65 (diff)
parente5d3e5180f667f8850cdd96af60fc5511746b1bd (diff)
downloadrust-3d05e7f9cdd76887de75f46b5e47d2685bec6520.tar.gz
rust-3d05e7f9cdd76887de75f46b5e47d2685bec6520.zip
auto merge of #13688 : alexcrichton/rust/accept-timeout, r=brson
This adds experimental support for timeouts when accepting sockets through
`TcpAcceptor::accept`. This does not add a separate `accept_timeout` function,
but rather it adds a `set_timeout` function instead. This second function is
intended to be used as a hard deadline after which all accepts will never block
and fail immediately.

This idea was derived from Go's SetDeadline() methods. We do not currently have
a robust time abstraction in the standard library, so I opted to have the
argument be a relative time in millseconds into the future. I believe a more
appropriate argument type is an absolute time, but this concept does not exist
yet (this is also why the function is marked #[experimental]).

The native support is built on select(), similarly to connect_timeout(), and the
green support is based on channel select and a timer.

cc #13523
-rw-r--r--src/libnative/io/c_win32.rs6
-rw-r--r--src/libnative/io/net.rs83
-rw-r--r--src/libnative/io/timer_win32.rs11
-rw-r--r--src/librustuv/net.rs88
-rw-r--r--src/librustuv/timer.rs19
-rw-r--r--src/libstd/io/net/tcp.rs86
-rw-r--r--src/libstd/rt/rtio.rs1
7 files changed, 253 insertions, 41 deletions
diff --git a/src/libnative/io/c_win32.rs b/src/libnative/io/c_win32.rs
index 8d75a673914..dbbb39b3b7b 100644
--- a/src/libnative/io/c_win32.rs
+++ b/src/libnative/io/c_win32.rs
@@ -50,9 +50,9 @@ extern "system" {
     pub fn ioctlsocket(s: libc::SOCKET, cmd: libc::c_long,
                        argp: *mut libc::c_ulong) -> libc::c_int;
     pub fn select(nfds: libc::c_int,
-                  readfds: *mut fd_set,
-                  writefds: *mut fd_set,
-                  exceptfds: *mut fd_set,
+                  readfds: *fd_set,
+                  writefds: *fd_set,
+                  exceptfds: *fd_set,
                   timeout: *libc::timeval) -> libc::c_int;
     pub fn getsockopt(sockfd: libc::SOCKET,
                       level: libc::c_int,
diff --git a/src/libnative/io/net.rs b/src/libnative/io/net.rs
index be597761b1a..93ec23e32ad 100644
--- a/src/libnative/io/net.rs
+++ b/src/libnative/io/net.rs
@@ -13,6 +13,7 @@ use std::cast;
 use std::io::net::ip;
 use std::io;
 use std::mem;
+use std::os;
 use std::ptr;
 use std::rt::rtio;
 use std::sync::arc::UnsafeArc;
@@ -144,6 +145,21 @@ fn last_error() -> io::IoError {
     super::last_error()
 }
 
+fn ms_to_timeval(ms: u64) -> libc::timeval {
+    libc::timeval {
+        tv_sec: (ms / 1000) as libc::time_t,
+        tv_usec: ((ms % 1000) * 1000) as libc::suseconds_t,
+    }
+}
+
+fn timeout(desc: &'static str) -> io::IoError {
+    io::IoError {
+        kind: io::TimedOut,
+        desc: desc,
+        detail: None,
+    }
+}
+
 #[cfg(windows)] unsafe fn close(sock: sock_t) { let _ = libc::closesocket(sock); }
 #[cfg(unix)]    unsafe fn close(sock: sock_t) { let _ = libc::close(sock); }
 
@@ -271,8 +287,7 @@ impl TcpStream {
     fn connect_timeout(fd: sock_t,
                        addrp: *libc::sockaddr,
                        len: libc::socklen_t,
-                       timeout: u64) -> IoResult<()> {
-        use std::os;
+                       timeout_ms: u64) -> IoResult<()> {
         #[cfg(unix)]    use INPROGRESS = libc::EINPROGRESS;
         #[cfg(windows)] use INPROGRESS = libc::WSAEINPROGRESS;
         #[cfg(unix)]    use WOULDBLOCK = libc::EWOULDBLOCK;
@@ -289,12 +304,8 @@ impl TcpStream {
                   os::errno() as int == WOULDBLOCK as int => {
                 let mut set: c::fd_set = unsafe { mem::init() };
                 c::fd_set(&mut set, fd);
-                match await(fd, &mut set, timeout) {
-                    0 => Err(io::IoError {
-                        kind: io::TimedOut,
-                        desc: "connection timed out",
-                        detail: None,
-                    }),
+                match await(fd, &mut set, timeout_ms) {
+                    0 => Err(timeout("connection timed out")),
                     -1 => Err(last_error()),
                     _ => {
                         let err: libc::c_int = try!(
@@ -338,22 +349,14 @@ impl TcpStream {
                 // Recalculate the timeout each iteration (it is generally
                 // undefined what the value of the 'tv' is after select
                 // returns EINTR).
-                let timeout = timeout - (::io::timer::now() - start);
-                let tv = libc::timeval {
-                    tv_sec: (timeout / 1000) as libc::time_t,
-                    tv_usec: ((timeout % 1000) * 1000) as libc::suseconds_t,
-                };
-                c::select(fd + 1, ptr::null(), set as *mut _ as *_,
-                          ptr::null(), &tv)
+                let tv = ms_to_timeval(timeout - (::io::timer::now() - start));
+                c::select(fd + 1, ptr::null(), &*set, ptr::null(), &tv)
             })
         }
         #[cfg(windows)]
         fn await(_fd: sock_t, set: &mut c::fd_set, timeout: u64) -> libc::c_int {
-            let tv = libc::timeval {
-                tv_sec: (timeout / 1000) as libc::time_t,
-                tv_usec: ((timeout % 1000) * 1000) as libc::suseconds_t,
-            };
-            unsafe { c::select(1, ptr::mut_null(), set, ptr::mut_null(), &tv) }
+            let tv = ms_to_timeval(timeout);
+            unsafe { c::select(1, ptr::null(), &*set, ptr::null(), &tv) }
         }
     }
 
@@ -467,7 +470,7 @@ impl Drop for Inner {
 ////////////////////////////////////////////////////////////////////////////////
 
 pub struct TcpListener {
-    inner: UnsafeArc<Inner>,
+    inner: Inner,
 }
 
 impl TcpListener {
@@ -477,7 +480,7 @@ impl TcpListener {
                 let (addr, len) = addr_to_sockaddr(addr);
                 let addrp = &addr as *libc::sockaddr_storage;
                 let inner = Inner { fd: fd };
-                let ret = TcpListener { inner: UnsafeArc::new(inner) };
+                let ret = TcpListener { inner: inner };
                 // On platforms with Berkeley-derived sockets, this allows
                 // to quickly rebind a socket, without needing to wait for
                 // the OS to clean up the previous one.
@@ -498,15 +501,12 @@ impl TcpListener {
         }
     }
 
-    pub fn fd(&self) -> sock_t {
-        // This is just a read-only arc so the unsafety is fine
-        unsafe { (*self.inner.get()).fd }
-    }
+    pub fn fd(&self) -> sock_t { self.inner.fd }
 
     pub fn native_listen(self, backlog: int) -> IoResult<TcpAcceptor> {
         match unsafe { libc::listen(self.fd(), backlog as libc::c_int) } {
             -1 => Err(last_error()),
-            _ => Ok(TcpAcceptor { listener: self })
+            _ => Ok(TcpAcceptor { listener: self, deadline: 0 })
         }
     }
 }
@@ -525,12 +525,16 @@ impl rtio::RtioSocket for TcpListener {
 
 pub struct TcpAcceptor {
     listener: TcpListener,
+    deadline: u64,
 }
 
 impl TcpAcceptor {
     pub fn fd(&self) -> sock_t { self.listener.fd() }
 
     pub fn native_accept(&mut self) -> IoResult<TcpStream> {
+        if self.deadline != 0 {
+            try!(self.accept_deadline());
+        }
         unsafe {
             let mut storage: libc::sockaddr_storage = mem::init();
             let storagep = &mut storage as *mut libc::sockaddr_storage;
@@ -546,6 +550,25 @@ impl TcpAcceptor {
             }
         }
     }
+
+    fn accept_deadline(&mut self) -> IoResult<()> {
+        let mut set: c::fd_set = unsafe { mem::init() };
+        c::fd_set(&mut set, self.fd());
+
+        match retry(|| {
+            // If we're past the deadline, then pass a 0 timeout to select() so
+            // we can poll the status of the socket.
+            let now = ::io::timer::now();
+            let ms = if self.deadline > now {0} else {self.deadline - now};
+            let tv = ms_to_timeval(ms);
+            let n = if cfg!(windows) {1} else {self.fd() as libc::c_int + 1};
+            unsafe { c::select(n, &set, ptr::null(), ptr::null(), &tv) }
+        }) {
+            -1 => Err(last_error()),
+            0 => Err(timeout("accept timed out")),
+            _ => return Ok(()),
+        }
+    }
 }
 
 impl rtio::RtioSocket for TcpAcceptor {
@@ -561,6 +584,12 @@ impl rtio::RtioTcpAcceptor for TcpAcceptor {
 
     fn accept_simultaneously(&mut self) -> IoResult<()> { Ok(()) }
     fn dont_accept_simultaneously(&mut self) -> IoResult<()> { Ok(()) }
+    fn set_timeout(&mut self, timeout: Option<u64>) {
+        self.deadline = match timeout {
+            None => 0,
+            Some(t) => ::io::timer::now() + t,
+        };
+    }
 }
 
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/src/libnative/io/timer_win32.rs b/src/libnative/io/timer_win32.rs
index a15898feb92..588ec367d81 100644
--- a/src/libnative/io/timer_win32.rs
+++ b/src/libnative/io/timer_win32.rs
@@ -89,6 +89,17 @@ fn helper(input: libc::HANDLE, messages: Receiver<Req>) {
     }
 }
 
+// returns the current time (in milliseconds)
+pub fn now() -> u64 {
+    let mut ticks_per_s = 0;
+    assert_eq!(unsafe { libc::QueryPerformanceFrequency(&mut ticks_per_s) }, 1);
+    let ticks_per_s = if ticks_per_s == 0 {1} else {ticks_per_s};
+    let mut ticks = 0;
+    assert_eq!(unsafe { libc::QueryPerformanceCounter(&mut ticks) }, 1);
+
+    return (ticks as u64 * 1000) / (ticks_per_s as u64);
+}
+
 impl Timer {
     pub fn new() -> IoResult<Timer> {
         timer_helper::boot(helper);
diff --git a/src/librustuv/net.rs b/src/librustuv/net.rs
index 69d978b2433..f8df9263be1 100644
--- a/src/librustuv/net.rs
+++ b/src/librustuv/net.rs
@@ -174,6 +174,9 @@ pub struct TcpListener {
 
 pub struct TcpAcceptor {
     listener: ~TcpListener,
+    timer: Option<TimerWatcher>,
+    timeout_tx: Option<Sender<()>>,
+    timeout_rx: Option<Receiver<()>>,
 }
 
 // TCP watchers (clients/streams)
@@ -459,7 +462,12 @@ impl rtio::RtioSocket for TcpListener {
 impl rtio::RtioTcpListener for TcpListener {
     fn listen(~self) -> Result<~rtio::RtioTcpAcceptor:Send, IoError> {
         // create the acceptor object from ourselves
-        let mut acceptor = ~TcpAcceptor { listener: self };
+        let mut acceptor = ~TcpAcceptor {
+            listener: self,
+            timer: None,
+            timeout_tx: None,
+            timeout_rx: None,
+        };
 
         let _m = acceptor.fire_homing_missile();
         // FIXME: the 128 backlog should be configurable
@@ -509,7 +517,37 @@ impl rtio::RtioSocket for TcpAcceptor {
 
 impl rtio::RtioTcpAcceptor for TcpAcceptor {
     fn accept(&mut self) -> Result<~rtio::RtioTcpStream:Send, IoError> {
-        self.listener.incoming.recv()
+        match self.timeout_rx {
+            None => self.listener.incoming.recv(),
+            Some(ref rx) => {
+                use std::comm::Select;
+
+                // Poll the incoming channel first (don't rely on the order of
+                // select just yet). If someone's pending then we should return
+                // them immediately.
+                match self.listener.incoming.try_recv() {
+                    Ok(data) => return data,
+                    Err(..) => {}
+                }
+
+                // Use select to figure out which channel gets ready first. We
+                // do some custom handling of select to ensure that we never
+                // actually drain the timeout channel (we'll keep seeing the
+                // timeout message in the future).
+                let s = Select::new();
+                let mut timeout = s.handle(rx);
+                let mut data = s.handle(&self.listener.incoming);
+                unsafe {
+                    timeout.add();
+                    data.add();
+                }
+                if s.wait() == timeout.id() {
+                    Err(uv_error_to_io_error(UvError(uvll::ECANCELED)))
+                } else {
+                    self.listener.incoming.recv()
+                }
+            }
+        }
     }
 
     fn accept_simultaneously(&mut self) -> Result<(), IoError> {
@@ -525,6 +563,52 @@ impl rtio::RtioTcpAcceptor for TcpAcceptor {
             uvll::uv_tcp_simultaneous_accepts(self.listener.handle, 0)
         })
     }
+
+    fn set_timeout(&mut self, ms: Option<u64>) {
+        // First, if the timeout is none, clear any previous timeout by dropping
+        // the timer and transmission channels
+        let ms = match ms {
+            None => {
+                return drop((self.timer.take(),
+                             self.timeout_tx.take(),
+                             self.timeout_rx.take()))
+            }
+            Some(ms) => ms,
+        };
+
+        // If we have a timeout, lazily initialize the timer which will be used
+        // to fire when the timeout runs out.
+        if self.timer.is_none() {
+            let _m = self.fire_homing_missile();
+            let loop_ = Loop::wrap(unsafe {
+                uvll::get_loop_for_uv_handle(self.listener.handle)
+            });
+            let mut timer = TimerWatcher::new_home(&loop_, self.home().clone());
+            unsafe {
+                timer.set_data(self as *mut _ as *TcpAcceptor);
+            }
+            self.timer = Some(timer);
+        }
+
+        // Once we've got a timer, stop any previous timeout, reset it for the
+        // current one, and install some new channels to send/receive data on
+        let timer = self.timer.get_mut_ref();
+        timer.stop();
+        timer.start(timer_cb, ms, 0);
+        let (tx, rx) = channel();
+        self.timeout_tx = Some(tx);
+        self.timeout_rx = Some(rx);
+
+        extern fn timer_cb(timer: *uvll::uv_timer_t, status: c_int) {
+            assert_eq!(status, 0);
+            let acceptor: &mut TcpAcceptor = unsafe {
+                &mut *(uvll::get_data_for_uv_handle(timer) as *mut TcpAcceptor)
+            };
+            // This send can never fail because if this timer is active then the
+            // receiving channel is guaranteed to be alive
+            acceptor.timeout_tx.get_ref().send(());
+        }
+    }
 }
 
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/src/librustuv/timer.rs b/src/librustuv/timer.rs
index 3710d97827f..65ab32c6965 100644
--- a/src/librustuv/timer.rs
+++ b/src/librustuv/timer.rs
@@ -14,7 +14,7 @@ use std::rt::rtio::RtioTimer;
 use std::rt::task::BlockedTask;
 
 use homing::{HomeHandle, HomingIO};
-use super::{UvHandle, ForbidUnwind, ForbidSwitch, wait_until_woken_after};
+use super::{UvHandle, ForbidUnwind, ForbidSwitch, wait_until_woken_after, Loop};
 use uvio::UvIoFactory;
 use uvll;
 
@@ -34,18 +34,21 @@ pub enum NextAction {
 
 impl TimerWatcher {
     pub fn new(io: &mut UvIoFactory) -> ~TimerWatcher {
+        let handle = io.make_handle();
+        let me = ~TimerWatcher::new_home(&io.loop_, handle);
+        me.install()
+    }
+
+    pub fn new_home(loop_: &Loop, home: HomeHandle) -> TimerWatcher {
         let handle = UvHandle::alloc(None::<TimerWatcher>, uvll::UV_TIMER);
-        assert_eq!(unsafe {
-            uvll::uv_timer_init(io.uv_loop(), handle)
-        }, 0);
-        let me = ~TimerWatcher {
+        assert_eq!(unsafe { uvll::uv_timer_init(loop_.handle, handle) }, 0);
+        TimerWatcher {
             handle: handle,
             action: None,
             blocker: None,
-            home: io.make_handle(),
+            home: home,
             id: 0,
-        };
-        return me.install();
+        }
     }
 
     pub fn start(&mut self, f: uvll::uv_timer_cb, msecs: u64, period: u64) {
diff --git a/src/libstd/io/net/tcp.rs b/src/libstd/io/net/tcp.rs
index 4f1e6bd7418..0619c89aac1 100644
--- a/src/libstd/io/net/tcp.rs
+++ b/src/libstd/io/net/tcp.rs
@@ -22,7 +22,7 @@ use io::IoResult;
 use io::net::ip::SocketAddr;
 use io::{Reader, Writer, Listener, Acceptor};
 use kinds::Send;
-use option::{None, Some};
+use option::{None, Some, Option};
 use rt::rtio::{IoFactory, LocalIo, RtioSocket, RtioTcpListener};
 use rt::rtio::{RtioTcpAcceptor, RtioTcpStream};
 
@@ -184,6 +184,56 @@ pub struct TcpAcceptor {
     obj: ~RtioTcpAcceptor:Send
 }
 
+impl TcpAcceptor {
+    /// Prevents blocking on all future accepts after `ms` milliseconds have
+    /// elapsed.
+    ///
+    /// This function is used to set a deadline after which this acceptor will
+    /// time out accepting any connections. The argument is the relative
+    /// distance, in milliseconds, to a point in the future after which all
+    /// accepts will fail.
+    ///
+    /// If the argument specified is `None`, then any previously registered
+    /// timeout is cleared.
+    ///
+    /// A timeout of `0` can be used to "poll" this acceptor to see if it has
+    /// any pending connections. All pending connections will be accepted,
+    /// regardless of whether the timeout has expired or not (the accept will
+    /// not block in this case).
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # #![allow(experimental)]
+    /// use std::io::net::tcp::TcpListener;
+    /// use std::io::net::ip::{SocketAddr, Ipv4Addr};
+    /// use std::io::{Listener, Acceptor, TimedOut};
+    ///
+    /// let addr = SocketAddr { ip: Ipv4Addr(127, 0, 0, 1), port: 8482 };
+    /// let mut a = TcpListener::bind(addr).listen().unwrap();
+    ///
+    /// // After 100ms have passed, all accepts will fail
+    /// a.set_timeout(Some(100));
+    ///
+    /// match a.accept() {
+    ///     Ok(..) => println!("accepted a socket"),
+    ///     Err(ref e) if e.kind == TimedOut => { println!("timed out!"); }
+    ///     Err(e) => println!("err: {}", e),
+    /// }
+    ///
+    /// // Reset the timeout and try again
+    /// a.set_timeout(Some(100));
+    /// let socket = a.accept();
+    ///
+    /// // Clear the timeout and block indefinitely waiting for a connection
+    /// a.set_timeout(None);
+    /// let socket = a.accept();
+    /// ```
+    #[experimental = "the type of the argument and name of this function are \
+                      subject to change"]
+    pub fn set_timeout(&mut self, ms: Option<u64>) { self.obj.set_timeout(ms); }
+}
+
 impl Acceptor<TcpStream> for TcpAcceptor {
     fn accept(&mut self) -> IoResult<TcpStream> {
         self.obj.accept().map(TcpStream::new)
@@ -191,6 +241,7 @@ impl Acceptor<TcpStream> for TcpAcceptor {
 }
 
 #[cfg(test)]
+#[allow(experimental)]
 mod test {
     use super::*;
     use io::net::ip::SocketAddr;
@@ -749,4 +800,37 @@ mod test {
         assert!(s.write([1]).is_err());
         assert_eq!(s.read_to_end(), Ok(vec!(1)));
     })
+
+    iotest!(fn accept_timeout() {
+        let addr = next_test_ip4();
+        let mut a = TcpListener::bind(addr).unwrap().listen().unwrap();
+
+        a.set_timeout(Some(10));
+
+        // Make sure we time out once and future invocations also time out
+        let err = a.accept().err().unwrap();
+        assert_eq!(err.kind, TimedOut);
+        let err = a.accept().err().unwrap();
+        assert_eq!(err.kind, TimedOut);
+
+        // Also make sure that even though the timeout is expired that we will
+        // continue to receive any pending connections.
+        let l = TcpStream::connect(addr).unwrap();
+        for i in range(0, 1001) {
+            match a.accept() {
+                Ok(..) => break,
+                Err(ref e) if e.kind == TimedOut => {}
+                Err(e) => fail!("error: {}", e),
+            }
+            if i == 1000 { fail!("should have a pending connection") }
+        }
+        drop(l);
+
+        // Unset the timeout and make sure that this always blocks.
+        a.set_timeout(None);
+        spawn(proc() {
+            drop(TcpStream::connect(addr));
+        });
+        a.accept().unwrap();
+    })
 }
diff --git a/src/libstd/rt/rtio.rs b/src/libstd/rt/rtio.rs
index 0f3fc9c21ce..5dd14834669 100644
--- a/src/libstd/rt/rtio.rs
+++ b/src/libstd/rt/rtio.rs
@@ -200,6 +200,7 @@ pub trait RtioTcpAcceptor : RtioSocket {
     fn accept(&mut self) -> IoResult<~RtioTcpStream:Send>;
     fn accept_simultaneously(&mut self) -> IoResult<()>;
     fn dont_accept_simultaneously(&mut self) -> IoResult<()>;
+    fn set_timeout(&mut self, timeout: Option<u64>);
 }
 
 pub trait RtioTcpStream : RtioSocket {