about summary refs log tree commit diff
path: root/library/std/src/sys/sgx/waitqueue.rs
diff options
context:
space:
mode:
Diffstat (limited to 'library/std/src/sys/sgx/waitqueue.rs')
-rw-r--r--library/std/src/sys/sgx/waitqueue.rs619
1 files changed, 619 insertions, 0 deletions
diff --git a/library/std/src/sys/sgx/waitqueue.rs b/library/std/src/sys/sgx/waitqueue.rs
new file mode 100644
index 00000000000..070afa55f30
--- /dev/null
+++ b/library/std/src/sys/sgx/waitqueue.rs
@@ -0,0 +1,619 @@
+//! A simple queue implementation for synchronization primitives.
+//!
+//! This queue is used to implement condition variable and mutexes.
+//!
+//! Users of this API are expected to use the `WaitVariable<T>` type. Since
+//! that type is not `Sync`, it needs to be protected by e.g., a `SpinMutex` to
+//! allow shared access.
+//!
+//! Since userspace may send spurious wake-ups, the wakeup event state is
+//! recorded in the enclave. The wakeup event state is protected by a spinlock.
+//! The queue and associated wait state are stored in a `WaitVariable`.
+use crate::num::NonZeroUsize;
+use crate::ops::{Deref, DerefMut};
+use crate::time::Duration;
+
+use super::abi::thread;
+use super::abi::usercalls;
+use fortanix_sgx_abi::{Tcs, EV_UNPARK, WAIT_INDEFINITE};
+
+pub use self::spin_mutex::{try_lock_or_false, SpinMutex, SpinMutexGuard};
+use self::unsafe_list::{UnsafeList, UnsafeListEntry};
+
+/// An queue entry in a `WaitQueue`.
+struct WaitEntry {
+    /// TCS address of the thread that is waiting
+    tcs: Tcs,
+    /// Whether this thread has been notified to be awoken
+    wake: bool,
+}
+
+/// Data stored with a `WaitQueue` alongside it. This ensures accesses to the
+/// queue and the data are synchronized, since the type itself is not `Sync`.
+///
+/// Consumers of this API should use a synchronization primitive for shared
+/// access, such as `SpinMutex`.
+#[derive(Default)]
+pub struct WaitVariable<T> {
+    queue: WaitQueue,
+    lock: T,
+}
+
+impl<T> WaitVariable<T> {
+    pub const fn new(var: T) -> Self {
+        WaitVariable { queue: WaitQueue::new(), lock: var }
+    }
+
+    pub fn queue_empty(&self) -> bool {
+        self.queue.is_empty()
+    }
+
+    pub fn lock_var(&self) -> &T {
+        &self.lock
+    }
+
+    pub fn lock_var_mut(&mut self) -> &mut T {
+        &mut self.lock
+    }
+}
+
+#[derive(Copy, Clone)]
+pub enum NotifiedTcs {
+    Single(Tcs),
+    All { count: NonZeroUsize },
+}
+
+/// An RAII guard that will notify a set of target threads as well as unlock
+/// a mutex on drop.
+pub struct WaitGuard<'a, T: 'a> {
+    mutex_guard: Option<SpinMutexGuard<'a, WaitVariable<T>>>,
+    notified_tcs: NotifiedTcs,
+}
+
+/// A queue of threads that are waiting on some synchronization primitive.
+///
+/// `UnsafeList` entries are allocated on the waiting thread's stack. This
+/// avoids any global locking that might happen in the heap allocator. This is
+/// safe because the waiting thread will not return from that stack frame until
+/// after it is notified. The notifying thread ensures to clean up any
+/// references to the list entries before sending the wakeup event.
+pub struct WaitQueue {
+    // We use an inner Mutex here to protect the data in the face of spurious
+    // wakeups.
+    inner: UnsafeList<SpinMutex<WaitEntry>>,
+}
+unsafe impl Send for WaitQueue {}
+
+impl Default for WaitQueue {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<'a, T> WaitGuard<'a, T> {
+    /// Returns which TCSes will be notified when this guard drops.
+    pub fn notified_tcs(&self) -> NotifiedTcs {
+        self.notified_tcs
+    }
+
+    /// Drop this `WaitGuard`, after dropping another `guard`.
+    pub fn drop_after<U>(self, guard: U) {
+        drop(guard);
+        drop(self);
+    }
+}
+
+impl<'a, T> Deref for WaitGuard<'a, T> {
+    type Target = SpinMutexGuard<'a, WaitVariable<T>>;
+
+    fn deref(&self) -> &Self::Target {
+        self.mutex_guard.as_ref().unwrap()
+    }
+}
+
+impl<'a, T> DerefMut for WaitGuard<'a, T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.mutex_guard.as_mut().unwrap()
+    }
+}
+
+impl<'a, T> Drop for WaitGuard<'a, T> {
+    fn drop(&mut self) {
+        drop(self.mutex_guard.take());
+        let target_tcs = match self.notified_tcs {
+            NotifiedTcs::Single(tcs) => Some(tcs),
+            NotifiedTcs::All { .. } => None,
+        };
+        rtunwrap!(Ok, usercalls::send(EV_UNPARK, target_tcs));
+    }
+}
+
+impl WaitQueue {
+    pub const fn new() -> Self {
+        WaitQueue { inner: UnsafeList::new() }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
+
+    /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait
+    /// until a wakeup event.
+    ///
+    /// This function does not return until this thread has been awoken.
+    pub fn wait<T, F: FnOnce()>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>, before_wait: F) {
+        // very unsafe: check requirements of UnsafeList::push
+        unsafe {
+            let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
+                tcs: thread::current(),
+                wake: false,
+            }));
+            let entry = guard.queue.inner.push(&mut entry);
+            drop(guard);
+            before_wait();
+            while !entry.lock().wake {
+                // don't panic, this would invalidate `entry` during unwinding
+                let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE));
+                rtassert!(eventset & EV_UNPARK == EV_UNPARK);
+            }
+        }
+    }
+
+    /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait
+    /// until a wakeup event or timeout. If event was observed, returns true.
+    /// If not, it will remove the calling thread from the wait queue.
+    pub fn wait_timeout<T, F: FnOnce()>(
+        lock: &SpinMutex<WaitVariable<T>>,
+        timeout: Duration,
+        before_wait: F,
+    ) -> bool {
+        // very unsafe: check requirements of UnsafeList::push
+        unsafe {
+            let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
+                tcs: thread::current(),
+                wake: false,
+            }));
+            let entry_lock = lock.lock().queue.inner.push(&mut entry);
+            before_wait();
+            usercalls::wait_timeout(EV_UNPARK, timeout, || entry_lock.lock().wake);
+            // acquire the wait queue's lock first to avoid deadlock.
+            let mut guard = lock.lock();
+            let success = entry_lock.lock().wake;
+            if !success {
+                // nobody is waking us up, so remove our entry from the wait queue.
+                guard.queue.inner.remove(&mut entry);
+            }
+            success
+        }
+    }
+
+    /// Either find the next waiter on the wait queue, or return the mutex
+    /// guard unchanged.
+    ///
+    /// If a waiter is found, a `WaitGuard` is returned which will notify the
+    /// waiter when it is dropped.
+    pub fn notify_one<T>(
+        mut guard: SpinMutexGuard<'_, WaitVariable<T>>,
+    ) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> {
+        unsafe {
+            if let Some(entry) = guard.queue.inner.pop() {
+                let mut entry_guard = entry.lock();
+                let tcs = entry_guard.tcs;
+                entry_guard.wake = true;
+                drop(entry);
+                Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::Single(tcs) })
+            } else {
+                Err(guard)
+            }
+        }
+    }
+
+    /// Either find any and all waiters on the wait queue, or return the mutex
+    /// guard unchanged.
+    ///
+    /// If at least one waiter is found, a `WaitGuard` is returned which will
+    /// notify all waiters when it is dropped.
+    pub fn notify_all<T>(
+        mut guard: SpinMutexGuard<'_, WaitVariable<T>>,
+    ) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> {
+        unsafe {
+            let mut count = 0;
+            while let Some(entry) = guard.queue.inner.pop() {
+                count += 1;
+                let mut entry_guard = entry.lock();
+                entry_guard.wake = true;
+            }
+            if let Some(count) = NonZeroUsize::new(count) {
+                Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::All { count } })
+            } else {
+                Err(guard)
+            }
+        }
+    }
+}
+
+/// A doubly-linked list where callers are in charge of memory allocation
+/// of the nodes in the list.
+mod unsafe_list {
+    use crate::mem;
+    use crate::ptr::NonNull;
+
+    pub struct UnsafeListEntry<T> {
+        next: NonNull<UnsafeListEntry<T>>,
+        prev: NonNull<UnsafeListEntry<T>>,
+        value: Option<T>,
+    }
+
+    impl<T> UnsafeListEntry<T> {
+        fn dummy() -> Self {
+            UnsafeListEntry { next: NonNull::dangling(), prev: NonNull::dangling(), value: None }
+        }
+
+        pub fn new(value: T) -> Self {
+            UnsafeListEntry { value: Some(value), ..Self::dummy() }
+        }
+    }
+
+    pub struct UnsafeList<T> {
+        head_tail: NonNull<UnsafeListEntry<T>>,
+        head_tail_entry: Option<UnsafeListEntry<T>>,
+    }
+
+    impl<T> UnsafeList<T> {
+        pub const fn new() -> Self {
+            unsafe {
+                UnsafeList { head_tail: NonNull::new_unchecked(1 as _), head_tail_entry: None }
+            }
+        }
+
+        unsafe fn init(&mut self) {
+            if self.head_tail_entry.is_none() {
+                self.head_tail_entry = Some(UnsafeListEntry::dummy());
+                self.head_tail = NonNull::new_unchecked(self.head_tail_entry.as_mut().unwrap());
+                self.head_tail.as_mut().next = self.head_tail;
+                self.head_tail.as_mut().prev = self.head_tail;
+            }
+        }
+
+        pub fn is_empty(&self) -> bool {
+            unsafe {
+                if self.head_tail_entry.is_some() {
+                    let first = self.head_tail.as_ref().next;
+                    if first == self.head_tail {
+                        // ,-------> /---------\ next ---,
+                        // |         |head_tail|         |
+                        // `--- prev \---------/ <-------`
+                        rtassert!(self.head_tail.as_ref().prev == first);
+                        true
+                    } else {
+                        false
+                    }
+                } else {
+                    true
+                }
+            }
+        }
+
+        /// Pushes an entry onto the back of the list.
+        ///
+        /// # Safety
+        ///
+        /// The entry must remain allocated until the entry is removed from the
+        /// list AND the caller who popped is done using the entry. Special
+        /// care must be taken in the caller of `push` to ensure unwinding does
+        /// not destroy the stack frame containing the entry.
+        pub unsafe fn push<'a>(&mut self, entry: &'a mut UnsafeListEntry<T>) -> &'a T {
+            self.init();
+
+            // BEFORE:
+            //     /---------\ next ---> /---------\
+            // ... |prev_tail|           |head_tail| ...
+            //     \---------/ <--- prev \---------/
+            //
+            // AFTER:
+            //     /---------\ next ---> /-----\ next ---> /---------\
+            // ... |prev_tail|           |entry|           |head_tail| ...
+            //     \---------/ <--- prev \-----/ <--- prev \---------/
+            let mut entry = NonNull::new_unchecked(entry);
+            let mut prev_tail = mem::replace(&mut self.head_tail.as_mut().prev, entry);
+            entry.as_mut().prev = prev_tail;
+            entry.as_mut().next = self.head_tail;
+            prev_tail.as_mut().next = entry;
+            // unwrap ok: always `Some` on non-dummy entries
+            (*entry.as_ptr()).value.as_ref().unwrap()
+        }
+
+        /// Pops an entry from the front of the list.
+        ///
+        /// # Safety
+        ///
+        /// The caller must make sure to synchronize ending the borrow of the
+        /// return value and deallocation of the containing entry.
+        pub unsafe fn pop<'a>(&mut self) -> Option<&'a T> {
+            self.init();
+
+            if self.is_empty() {
+                None
+            } else {
+                // BEFORE:
+                //     /---------\ next ---> /-----\ next ---> /------\
+                // ... |head_tail|           |first|           |second| ...
+                //     \---------/ <--- prev \-----/ <--- prev \------/
+                //
+                // AFTER:
+                //     /---------\ next ---> /------\
+                // ... |head_tail|           |second| ...
+                //     \---------/ <--- prev \------/
+                let mut first = self.head_tail.as_mut().next;
+                let mut second = first.as_mut().next;
+                self.head_tail.as_mut().next = second;
+                second.as_mut().prev = self.head_tail;
+                first.as_mut().next = NonNull::dangling();
+                first.as_mut().prev = NonNull::dangling();
+                // unwrap ok: always `Some` on non-dummy entries
+                Some((*first.as_ptr()).value.as_ref().unwrap())
+            }
+        }
+
+        /// Removes an entry from the list.
+        ///
+        /// # Safety
+        ///
+        /// The caller must ensure that `entry` has been pushed onto `self`
+        /// prior to this call and has not moved since then.
+        pub unsafe fn remove(&mut self, entry: &mut UnsafeListEntry<T>) {
+            rtassert!(!self.is_empty());
+            // BEFORE:
+            //     /----\ next ---> /-----\ next ---> /----\
+            // ... |prev|           |entry|           |next| ...
+            //     \----/ <--- prev \-----/ <--- prev \----/
+            //
+            // AFTER:
+            //     /----\ next ---> /----\
+            // ... |prev|           |next| ...
+            //     \----/ <--- prev \----/
+            let mut prev = entry.prev;
+            let mut next = entry.next;
+            prev.as_mut().next = next;
+            next.as_mut().prev = prev;
+            entry.next = NonNull::dangling();
+            entry.prev = NonNull::dangling();
+        }
+    }
+
+    #[cfg(test)]
+    mod tests {
+        use super::*;
+        use crate::cell::Cell;
+
+        unsafe fn assert_empty<T>(list: &mut UnsafeList<T>) {
+            assert!(list.pop().is_none(), "assertion failed: list is not empty");
+        }
+
+        #[test]
+        fn init_empty() {
+            unsafe {
+                assert_empty(&mut UnsafeList::<i32>::new());
+            }
+        }
+
+        #[test]
+        fn push_pop() {
+            unsafe {
+                let mut node = UnsafeListEntry::new(1234);
+                let mut list = UnsafeList::new();
+                assert_eq!(list.push(&mut node), &1234);
+                assert_eq!(list.pop().unwrap(), &1234);
+                assert_empty(&mut list);
+            }
+        }
+
+        #[test]
+        fn push_remove() {
+            unsafe {
+                let mut node = UnsafeListEntry::new(1234);
+                let mut list = UnsafeList::new();
+                assert_eq!(list.push(&mut node), &1234);
+                list.remove(&mut node);
+                assert_empty(&mut list);
+            }
+        }
+
+        #[test]
+        fn push_remove_pop() {
+            unsafe {
+                let mut node1 = UnsafeListEntry::new(11);
+                let mut node2 = UnsafeListEntry::new(12);
+                let mut node3 = UnsafeListEntry::new(13);
+                let mut node4 = UnsafeListEntry::new(14);
+                let mut node5 = UnsafeListEntry::new(15);
+                let mut list = UnsafeList::new();
+                assert_eq!(list.push(&mut node1), &11);
+                assert_eq!(list.push(&mut node2), &12);
+                assert_eq!(list.push(&mut node3), &13);
+                assert_eq!(list.push(&mut node4), &14);
+                assert_eq!(list.push(&mut node5), &15);
+
+                list.remove(&mut node1);
+                assert_eq!(list.pop().unwrap(), &12);
+                list.remove(&mut node3);
+                assert_eq!(list.pop().unwrap(), &14);
+                list.remove(&mut node5);
+                assert_empty(&mut list);
+
+                assert_eq!(list.push(&mut node1), &11);
+                assert_eq!(list.pop().unwrap(), &11);
+                assert_empty(&mut list);
+
+                assert_eq!(list.push(&mut node3), &13);
+                assert_eq!(list.push(&mut node4), &14);
+                list.remove(&mut node3);
+                list.remove(&mut node4);
+                assert_empty(&mut list);
+            }
+        }
+
+        #[test]
+        fn complex_pushes_pops() {
+            unsafe {
+                let mut node1 = UnsafeListEntry::new(1234);
+                let mut node2 = UnsafeListEntry::new(4567);
+                let mut node3 = UnsafeListEntry::new(9999);
+                let mut node4 = UnsafeListEntry::new(8642);
+                let mut list = UnsafeList::new();
+                list.push(&mut node1);
+                list.push(&mut node2);
+                assert_eq!(list.pop().unwrap(), &1234);
+                list.push(&mut node3);
+                assert_eq!(list.pop().unwrap(), &4567);
+                assert_eq!(list.pop().unwrap(), &9999);
+                assert_empty(&mut list);
+                list.push(&mut node4);
+                assert_eq!(list.pop().unwrap(), &8642);
+                assert_empty(&mut list);
+            }
+        }
+
+        #[test]
+        fn cell() {
+            unsafe {
+                let mut node = UnsafeListEntry::new(Cell::new(0));
+                let mut list = UnsafeList::new();
+                let noderef = list.push(&mut node);
+                assert_eq!(noderef.get(), 0);
+                list.pop().unwrap().set(1);
+                assert_empty(&mut list);
+                assert_eq!(noderef.get(), 1);
+            }
+        }
+    }
+}
+
+/// Trivial spinlock-based implementation of `sync::Mutex`.
+// FIXME: Perhaps use Intel TSX to avoid locking?
+mod spin_mutex {
+    use crate::cell::UnsafeCell;
+    use crate::ops::{Deref, DerefMut};
+    use crate::sync::atomic::{spin_loop_hint, AtomicBool, Ordering};
+
+    #[derive(Default)]
+    pub struct SpinMutex<T> {
+        value: UnsafeCell<T>,
+        lock: AtomicBool,
+    }
+
+    unsafe impl<T: Send> Send for SpinMutex<T> {}
+    unsafe impl<T: Send> Sync for SpinMutex<T> {}
+
+    pub struct SpinMutexGuard<'a, T: 'a> {
+        mutex: &'a SpinMutex<T>,
+    }
+
+    impl<'a, T> !Send for SpinMutexGuard<'a, T> {}
+    unsafe impl<'a, T: Sync> Sync for SpinMutexGuard<'a, T> {}
+
+    impl<T> SpinMutex<T> {
+        pub const fn new(value: T) -> Self {
+            SpinMutex { value: UnsafeCell::new(value), lock: AtomicBool::new(false) }
+        }
+
+        #[inline(always)]
+        pub fn lock(&self) -> SpinMutexGuard<'_, T> {
+            loop {
+                match self.try_lock() {
+                    None => {
+                        while self.lock.load(Ordering::Relaxed) {
+                            spin_loop_hint()
+                        }
+                    }
+                    Some(guard) => return guard,
+                }
+            }
+        }
+
+        #[inline(always)]
+        pub fn try_lock(&self) -> Option<SpinMutexGuard<'_, T>> {
+            if !self.lock.compare_and_swap(false, true, Ordering::Acquire) {
+                Some(SpinMutexGuard { mutex: self })
+            } else {
+                None
+            }
+        }
+    }
+
+    /// Lock the Mutex or return false.
+    pub macro try_lock_or_false($e:expr) {
+        if let Some(v) = $e.try_lock() { v } else { return false }
+    }
+
+    impl<'a, T> Deref for SpinMutexGuard<'a, T> {
+        type Target = T;
+
+        fn deref(&self) -> &T {
+            unsafe { &*self.mutex.value.get() }
+        }
+    }
+
+    impl<'a, T> DerefMut for SpinMutexGuard<'a, T> {
+        fn deref_mut(&mut self) -> &mut T {
+            unsafe { &mut *self.mutex.value.get() }
+        }
+    }
+
+    impl<'a, T> Drop for SpinMutexGuard<'a, T> {
+        fn drop(&mut self) {
+            self.mutex.lock.store(false, Ordering::Release)
+        }
+    }
+
+    #[cfg(test)]
+    mod tests {
+        #![allow(deprecated)]
+
+        use super::*;
+        use crate::sync::Arc;
+        use crate::thread;
+        use crate::time::Duration;
+
+        #[test]
+        fn sleep() {
+            let mutex = Arc::new(SpinMutex::<i32>::default());
+            let mutex2 = mutex.clone();
+            let guard = mutex.lock();
+            let t1 = thread::spawn(move || {
+                *mutex2.lock() = 1;
+            });
+
+            thread::sleep(Duration::from_millis(50));
+
+            assert_eq!(*guard, 0);
+            drop(guard);
+            t1.join().unwrap();
+            assert_eq!(*mutex.lock(), 1);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::sync::Arc;
+    use crate::thread;
+
+    #[test]
+    fn queue() {
+        let wq = Arc::new(SpinMutex::<WaitVariable<()>>::default());
+        let wq2 = wq.clone();
+
+        let locked = wq.lock();
+
+        let t1 = thread::spawn(move || {
+            // if we obtain the lock, the main thread should be waiting
+            assert!(WaitQueue::notify_one(wq2.lock()).is_ok());
+        });
+
+        WaitQueue::wait(locked, || {});
+
+        t1.join().unwrap();
+    }
+}