about summary refs log tree commit diff
path: root/library/std/src/sys/sgx/waitqueue/mod.rs
diff options
context:
space:
mode:
authorJethro Beekman <jethro@fortanix.com>2021-05-07 13:40:43 +0200
committerJethro Beekman <jethro@fortanix.com>2021-05-07 13:55:03 +0200
commitbfa84842e546a5fb2d1c69ba499ddc94bee1e01a (patch)
tree6f139af6fcd62774e43a769664e2be98579d51a2 /library/std/src/sys/sgx/waitqueue/mod.rs
parentac888e8675182c703c2cd097957878faf88dad94 (diff)
downloadrust-bfa84842e546a5fb2d1c69ba499ddc94bee1e01a.tar.gz
rust-bfa84842e546a5fb2d1c69ba499ddc94bee1e01a.zip
Rearrange SGX split module files
In #75979 several inlined modules were split out into multiple files.
This PR keeps the multiple files but moves a few things around to
organize things in a coherent way.
Diffstat (limited to 'library/std/src/sys/sgx/waitqueue/mod.rs')
-rw-r--r--library/std/src/sys/sgx/waitqueue/mod.rs240
1 files changed, 240 insertions, 0 deletions
diff --git a/library/std/src/sys/sgx/waitqueue/mod.rs b/library/std/src/sys/sgx/waitqueue/mod.rs
new file mode 100644
index 00000000000..61bb11d9a6f
--- /dev/null
+++ b/library/std/src/sys/sgx/waitqueue/mod.rs
@@ -0,0 +1,240 @@
+//! A simple queue implementation for synchronization primitives.
+//!
+//! This queue is used to implement condition variable and mutexes.
+//!
+//! Users of this API are expected to use the `WaitVariable<T>` type. Since
+//! that type is not `Sync`, it needs to be protected by e.g., a `SpinMutex` to
+//! allow shared access.
+//!
+//! Since userspace may send spurious wake-ups, the wakeup event state is
+//! recorded in the enclave. The wakeup event state is protected by a spinlock.
+//! The queue and associated wait state are stored in a `WaitVariable`.
+
+#[cfg(test)]
+mod tests;
+
+mod spin_mutex;
+mod unsafe_list;
+
+use crate::num::NonZeroUsize;
+use crate::ops::{Deref, DerefMut};
+use crate::time::Duration;
+
+use super::abi::thread;
+use super::abi::usercalls;
+use fortanix_sgx_abi::{Tcs, EV_UNPARK, WAIT_INDEFINITE};
+
+pub use self::spin_mutex::{try_lock_or_false, SpinMutex, SpinMutexGuard};
+use self::unsafe_list::{UnsafeList, UnsafeListEntry};
+
+/// An queue entry in a `WaitQueue`.
+struct WaitEntry {
+    /// TCS address of the thread that is waiting
+    tcs: Tcs,
+    /// Whether this thread has been notified to be awoken
+    wake: bool,
+}
+
+/// Data stored with a `WaitQueue` alongside it. This ensures accesses to the
+/// queue and the data are synchronized, since the type itself is not `Sync`.
+///
+/// Consumers of this API should use a synchronization primitive for shared
+/// access, such as `SpinMutex`.
+#[derive(Default)]
+pub struct WaitVariable<T> {
+    queue: WaitQueue,
+    lock: T,
+}
+
+impl<T> WaitVariable<T> {
+    pub const fn new(var: T) -> Self {
+        WaitVariable { queue: WaitQueue::new(), lock: var }
+    }
+
+    pub fn queue_empty(&self) -> bool {
+        self.queue.is_empty()
+    }
+
+    pub fn lock_var(&self) -> &T {
+        &self.lock
+    }
+
+    pub fn lock_var_mut(&mut self) -> &mut T {
+        &mut self.lock
+    }
+}
+
+#[derive(Copy, Clone)]
+pub enum NotifiedTcs {
+    Single(Tcs),
+    All { count: NonZeroUsize },
+}
+
+/// An RAII guard that will notify a set of target threads as well as unlock
+/// a mutex on drop.
+pub struct WaitGuard<'a, T: 'a> {
+    mutex_guard: Option<SpinMutexGuard<'a, WaitVariable<T>>>,
+    notified_tcs: NotifiedTcs,
+}
+
+/// A queue of threads that are waiting on some synchronization primitive.
+///
+/// `UnsafeList` entries are allocated on the waiting thread's stack. This
+/// avoids any global locking that might happen in the heap allocator. This is
+/// safe because the waiting thread will not return from that stack frame until
+/// after it is notified. The notifying thread ensures to clean up any
+/// references to the list entries before sending the wakeup event.
+pub struct WaitQueue {
+    // We use an inner Mutex here to protect the data in the face of spurious
+    // wakeups.
+    inner: UnsafeList<SpinMutex<WaitEntry>>,
+}
+unsafe impl Send for WaitQueue {}
+
+impl Default for WaitQueue {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<'a, T> WaitGuard<'a, T> {
+    /// Returns which TCSes will be notified when this guard drops.
+    pub fn notified_tcs(&self) -> NotifiedTcs {
+        self.notified_tcs
+    }
+
+    /// Drop this `WaitGuard`, after dropping another `guard`.
+    pub fn drop_after<U>(self, guard: U) {
+        drop(guard);
+        drop(self);
+    }
+}
+
+impl<'a, T> Deref for WaitGuard<'a, T> {
+    type Target = SpinMutexGuard<'a, WaitVariable<T>>;
+
+    fn deref(&self) -> &Self::Target {
+        self.mutex_guard.as_ref().unwrap()
+    }
+}
+
+impl<'a, T> DerefMut for WaitGuard<'a, T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.mutex_guard.as_mut().unwrap()
+    }
+}
+
+impl<'a, T> Drop for WaitGuard<'a, T> {
+    fn drop(&mut self) {
+        drop(self.mutex_guard.take());
+        let target_tcs = match self.notified_tcs {
+            NotifiedTcs::Single(tcs) => Some(tcs),
+            NotifiedTcs::All { .. } => None,
+        };
+        rtunwrap!(Ok, usercalls::send(EV_UNPARK, target_tcs));
+    }
+}
+
+impl WaitQueue {
+    pub const fn new() -> Self {
+        WaitQueue { inner: UnsafeList::new() }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
+
+    /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait
+    /// until a wakeup event.
+    ///
+    /// This function does not return until this thread has been awoken.
+    pub fn wait<T, F: FnOnce()>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>, before_wait: F) {
+        // very unsafe: check requirements of UnsafeList::push
+        unsafe {
+            let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
+                tcs: thread::current(),
+                wake: false,
+            }));
+            let entry = guard.queue.inner.push(&mut entry);
+            drop(guard);
+            before_wait();
+            while !entry.lock().wake {
+                // don't panic, this would invalidate `entry` during unwinding
+                let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE));
+                rtassert!(eventset & EV_UNPARK == EV_UNPARK);
+            }
+        }
+    }
+
+    /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait
+    /// until a wakeup event or timeout. If event was observed, returns true.
+    /// If not, it will remove the calling thread from the wait queue.
+    pub fn wait_timeout<T, F: FnOnce()>(
+        lock: &SpinMutex<WaitVariable<T>>,
+        timeout: Duration,
+        before_wait: F,
+    ) -> bool {
+        // very unsafe: check requirements of UnsafeList::push
+        unsafe {
+            let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
+                tcs: thread::current(),
+                wake: false,
+            }));
+            let entry_lock = lock.lock().queue.inner.push(&mut entry);
+            before_wait();
+            usercalls::wait_timeout(EV_UNPARK, timeout, || entry_lock.lock().wake);
+            // acquire the wait queue's lock first to avoid deadlock.
+            let mut guard = lock.lock();
+            let success = entry_lock.lock().wake;
+            if !success {
+                // nobody is waking us up, so remove our entry from the wait queue.
+                guard.queue.inner.remove(&mut entry);
+            }
+            success
+        }
+    }
+
+    /// Either find the next waiter on the wait queue, or return the mutex
+    /// guard unchanged.
+    ///
+    /// If a waiter is found, a `WaitGuard` is returned which will notify the
+    /// waiter when it is dropped.
+    pub fn notify_one<T>(
+        mut guard: SpinMutexGuard<'_, WaitVariable<T>>,
+    ) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> {
+        unsafe {
+            if let Some(entry) = guard.queue.inner.pop() {
+                let mut entry_guard = entry.lock();
+                let tcs = entry_guard.tcs;
+                entry_guard.wake = true;
+                drop(entry);
+                Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::Single(tcs) })
+            } else {
+                Err(guard)
+            }
+        }
+    }
+
+    /// Either find any and all waiters on the wait queue, or return the mutex
+    /// guard unchanged.
+    ///
+    /// If at least one waiter is found, a `WaitGuard` is returned which will
+    /// notify all waiters when it is dropped.
+    pub fn notify_all<T>(
+        mut guard: SpinMutexGuard<'_, WaitVariable<T>>,
+    ) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> {
+        unsafe {
+            let mut count = 0;
+            while let Some(entry) = guard.queue.inner.pop() {
+                count += 1;
+                let mut entry_guard = entry.lock();
+                entry_guard.wake = true;
+            }
+            if let Some(count) = NonZeroUsize::new(count) {
+                Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::All { count } })
+            } else {
+                Err(guard)
+            }
+        }
+    }
+}