//! A simple queue implementation for synchronization primitives. //! //! This queue is used to implement condition variable and mutexes. //! //! Users of this API are expected to use the `WaitVariable` type. Since //! that type is not `Sync`, it needs to be protected by e.g., a `SpinMutex` to //! allow shared access. //! //! Since userspace may send spurious wake-ups, the wakeup event state is //! recorded in the enclave. The wakeup event state is protected by a spinlock. //! The queue and associated wait state are stored in a `WaitVariable`. use crate::num::NonZeroUsize; use crate::ops::{Deref, DerefMut}; use crate::time::Duration; use super::abi::thread; use super::abi::usercalls; use fortanix_sgx_abi::{Tcs, EV_UNPARK, WAIT_INDEFINITE}; pub use self::spin_mutex::{try_lock_or_false, SpinMutex, SpinMutexGuard}; use self::unsafe_list::{UnsafeList, UnsafeListEntry}; /// An queue entry in a `WaitQueue`. struct WaitEntry { /// TCS address of the thread that is waiting tcs: Tcs, /// Whether this thread has been notified to be awoken wake: bool, } /// Data stored with a `WaitQueue` alongside it. This ensures accesses to the /// queue and the data are synchronized, since the type itself is not `Sync`. /// /// Consumers of this API should use a synchronization primitive for shared /// access, such as `SpinMutex`. #[derive(Default)] pub struct WaitVariable { queue: WaitQueue, lock: T, } impl WaitVariable { pub const fn new(var: T) -> Self { WaitVariable { queue: WaitQueue::new(), lock: var } } pub fn queue_empty(&self) -> bool { self.queue.is_empty() } pub fn lock_var(&self) -> &T { &self.lock } pub fn lock_var_mut(&mut self) -> &mut T { &mut self.lock } } #[derive(Copy, Clone)] pub enum NotifiedTcs { Single(Tcs), All { count: NonZeroUsize }, } /// An RAII guard that will notify a set of target threads as well as unlock /// a mutex on drop. pub struct WaitGuard<'a, T: 'a> { mutex_guard: Option>>, notified_tcs: NotifiedTcs, } /// A queue of threads that are waiting on some synchronization primitive. /// /// `UnsafeList` entries are allocated on the waiting thread's stack. This /// avoids any global locking that might happen in the heap allocator. This is /// safe because the waiting thread will not return from that stack frame until /// after it is notified. The notifying thread ensures to clean up any /// references to the list entries before sending the wakeup event. pub struct WaitQueue { // We use an inner Mutex here to protect the data in the face of spurious // wakeups. inner: UnsafeList>, } unsafe impl Send for WaitQueue {} impl Default for WaitQueue { fn default() -> Self { Self::new() } } impl<'a, T> WaitGuard<'a, T> { /// Returns which TCSes will be notified when this guard drops. pub fn notified_tcs(&self) -> NotifiedTcs { self.notified_tcs } /// Drop this `WaitGuard`, after dropping another `guard`. pub fn drop_after(self, guard: U) { drop(guard); drop(self); } } impl<'a, T> Deref for WaitGuard<'a, T> { type Target = SpinMutexGuard<'a, WaitVariable>; fn deref(&self) -> &Self::Target { self.mutex_guard.as_ref().unwrap() } } impl<'a, T> DerefMut for WaitGuard<'a, T> { fn deref_mut(&mut self) -> &mut Self::Target { self.mutex_guard.as_mut().unwrap() } } impl<'a, T> Drop for WaitGuard<'a, T> { fn drop(&mut self) { drop(self.mutex_guard.take()); let target_tcs = match self.notified_tcs { NotifiedTcs::Single(tcs) => Some(tcs), NotifiedTcs::All { .. } => None, }; rtunwrap!(Ok, usercalls::send(EV_UNPARK, target_tcs)); } } impl WaitQueue { pub const fn new() -> Self { WaitQueue { inner: UnsafeList::new() } } pub fn is_empty(&self) -> bool { self.inner.is_empty() } /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait /// until a wakeup event. /// /// This function does not return until this thread has been awoken. pub fn wait(mut guard: SpinMutexGuard<'_, WaitVariable>, before_wait: F) { // very unsafe: check requirements of UnsafeList::push unsafe { let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry { tcs: thread::current(), wake: false, })); let entry = guard.queue.inner.push(&mut entry); drop(guard); before_wait(); while !entry.lock().wake { // don't panic, this would invalidate `entry` during unwinding let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE)); rtassert!(eventset & EV_UNPARK == EV_UNPARK); } } } /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait /// until a wakeup event or timeout. If event was observed, returns true. /// If not, it will remove the calling thread from the wait queue. pub fn wait_timeout( lock: &SpinMutex>, timeout: Duration, before_wait: F, ) -> bool { // very unsafe: check requirements of UnsafeList::push unsafe { let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry { tcs: thread::current(), wake: false, })); let entry_lock = lock.lock().queue.inner.push(&mut entry); before_wait(); usercalls::wait_timeout(EV_UNPARK, timeout, || entry_lock.lock().wake); // acquire the wait queue's lock first to avoid deadlock. let mut guard = lock.lock(); let success = entry_lock.lock().wake; if !success { // nobody is waking us up, so remove our entry from the wait queue. guard.queue.inner.remove(&mut entry); } success } } /// Either find the next waiter on the wait queue, or return the mutex /// guard unchanged. /// /// If a waiter is found, a `WaitGuard` is returned which will notify the /// waiter when it is dropped. pub fn notify_one( mut guard: SpinMutexGuard<'_, WaitVariable>, ) -> Result, SpinMutexGuard<'_, WaitVariable>> { unsafe { if let Some(entry) = guard.queue.inner.pop() { let mut entry_guard = entry.lock(); let tcs = entry_guard.tcs; entry_guard.wake = true; drop(entry); Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::Single(tcs) }) } else { Err(guard) } } } /// Either find any and all waiters on the wait queue, or return the mutex /// guard unchanged. /// /// If at least one waiter is found, a `WaitGuard` is returned which will /// notify all waiters when it is dropped. pub fn notify_all( mut guard: SpinMutexGuard<'_, WaitVariable>, ) -> Result, SpinMutexGuard<'_, WaitVariable>> { unsafe { let mut count = 0; while let Some(entry) = guard.queue.inner.pop() { count += 1; let mut entry_guard = entry.lock(); entry_guard.wake = true; } if let Some(count) = NonZeroUsize::new(count) { Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::All { count } }) } else { Err(guard) } } } } /// A doubly-linked list where callers are in charge of memory allocation /// of the nodes in the list. mod unsafe_list { use crate::mem; use crate::ptr::NonNull; pub struct UnsafeListEntry { next: NonNull>, prev: NonNull>, value: Option, } impl UnsafeListEntry { fn dummy() -> Self { UnsafeListEntry { next: NonNull::dangling(), prev: NonNull::dangling(), value: None } } pub fn new(value: T) -> Self { UnsafeListEntry { value: Some(value), ..Self::dummy() } } } pub struct UnsafeList { head_tail: NonNull>, head_tail_entry: Option>, } impl UnsafeList { pub const fn new() -> Self { unsafe { UnsafeList { head_tail: NonNull::new_unchecked(1 as _), head_tail_entry: None } } } unsafe fn init(&mut self) { if self.head_tail_entry.is_none() { self.head_tail_entry = Some(UnsafeListEntry::dummy()); self.head_tail = NonNull::new_unchecked(self.head_tail_entry.as_mut().unwrap()); self.head_tail.as_mut().next = self.head_tail; self.head_tail.as_mut().prev = self.head_tail; } } pub fn is_empty(&self) -> bool { unsafe { if self.head_tail_entry.is_some() { let first = self.head_tail.as_ref().next; if first == self.head_tail { // ,-------> /---------\ next ---, // | |head_tail| | // `--- prev \---------/ <-------` rtassert!(self.head_tail.as_ref().prev == first); true } else { false } } else { true } } } /// Pushes an entry onto the back of the list. /// /// # Safety /// /// The entry must remain allocated until the entry is removed from the /// list AND the caller who popped is done using the entry. Special /// care must be taken in the caller of `push` to ensure unwinding does /// not destroy the stack frame containing the entry. pub unsafe fn push<'a>(&mut self, entry: &'a mut UnsafeListEntry) -> &'a T { self.init(); // BEFORE: // /---------\ next ---> /---------\ // ... |prev_tail| |head_tail| ... // \---------/ <--- prev \---------/ // // AFTER: // /---------\ next ---> /-----\ next ---> /---------\ // ... |prev_tail| |entry| |head_tail| ... // \---------/ <--- prev \-----/ <--- prev \---------/ let mut entry = NonNull::new_unchecked(entry); let mut prev_tail = mem::replace(&mut self.head_tail.as_mut().prev, entry); entry.as_mut().prev = prev_tail; entry.as_mut().next = self.head_tail; prev_tail.as_mut().next = entry; // unwrap ok: always `Some` on non-dummy entries (*entry.as_ptr()).value.as_ref().unwrap() } /// Pops an entry from the front of the list. /// /// # Safety /// /// The caller must make sure to synchronize ending the borrow of the /// return value and deallocation of the containing entry. pub unsafe fn pop<'a>(&mut self) -> Option<&'a T> { self.init(); if self.is_empty() { None } else { // BEFORE: // /---------\ next ---> /-----\ next ---> /------\ // ... |head_tail| |first| |second| ... // \---------/ <--- prev \-----/ <--- prev \------/ // // AFTER: // /---------\ next ---> /------\ // ... |head_tail| |second| ... // \---------/ <--- prev \------/ let mut first = self.head_tail.as_mut().next; let mut second = first.as_mut().next; self.head_tail.as_mut().next = second; second.as_mut().prev = self.head_tail; first.as_mut().next = NonNull::dangling(); first.as_mut().prev = NonNull::dangling(); // unwrap ok: always `Some` on non-dummy entries Some((*first.as_ptr()).value.as_ref().unwrap()) } } /// Removes an entry from the list. /// /// # Safety /// /// The caller must ensure that `entry` has been pushed onto `self` /// prior to this call and has not moved since then. pub unsafe fn remove(&mut self, entry: &mut UnsafeListEntry) { rtassert!(!self.is_empty()); // BEFORE: // /----\ next ---> /-----\ next ---> /----\ // ... |prev| |entry| |next| ... // \----/ <--- prev \-----/ <--- prev \----/ // // AFTER: // /----\ next ---> /----\ // ... |prev| |next| ... // \----/ <--- prev \----/ let mut prev = entry.prev; let mut next = entry.next; prev.as_mut().next = next; next.as_mut().prev = prev; entry.next = NonNull::dangling(); entry.prev = NonNull::dangling(); } } #[cfg(test)] mod tests { use super::*; use crate::cell::Cell; unsafe fn assert_empty(list: &mut UnsafeList) { assert!(list.pop().is_none(), "assertion failed: list is not empty"); } #[test] fn init_empty() { unsafe { assert_empty(&mut UnsafeList::::new()); } } #[test] fn push_pop() { unsafe { let mut node = UnsafeListEntry::new(1234); let mut list = UnsafeList::new(); assert_eq!(list.push(&mut node), &1234); assert_eq!(list.pop().unwrap(), &1234); assert_empty(&mut list); } } #[test] fn push_remove() { unsafe { let mut node = UnsafeListEntry::new(1234); let mut list = UnsafeList::new(); assert_eq!(list.push(&mut node), &1234); list.remove(&mut node); assert_empty(&mut list); } } #[test] fn push_remove_pop() { unsafe { let mut node1 = UnsafeListEntry::new(11); let mut node2 = UnsafeListEntry::new(12); let mut node3 = UnsafeListEntry::new(13); let mut node4 = UnsafeListEntry::new(14); let mut node5 = UnsafeListEntry::new(15); let mut list = UnsafeList::new(); assert_eq!(list.push(&mut node1), &11); assert_eq!(list.push(&mut node2), &12); assert_eq!(list.push(&mut node3), &13); assert_eq!(list.push(&mut node4), &14); assert_eq!(list.push(&mut node5), &15); list.remove(&mut node1); assert_eq!(list.pop().unwrap(), &12); list.remove(&mut node3); assert_eq!(list.pop().unwrap(), &14); list.remove(&mut node5); assert_empty(&mut list); assert_eq!(list.push(&mut node1), &11); assert_eq!(list.pop().unwrap(), &11); assert_empty(&mut list); assert_eq!(list.push(&mut node3), &13); assert_eq!(list.push(&mut node4), &14); list.remove(&mut node3); list.remove(&mut node4); assert_empty(&mut list); } } #[test] fn complex_pushes_pops() { unsafe { let mut node1 = UnsafeListEntry::new(1234); let mut node2 = UnsafeListEntry::new(4567); let mut node3 = UnsafeListEntry::new(9999); let mut node4 = UnsafeListEntry::new(8642); let mut list = UnsafeList::new(); list.push(&mut node1); list.push(&mut node2); assert_eq!(list.pop().unwrap(), &1234); list.push(&mut node3); assert_eq!(list.pop().unwrap(), &4567); assert_eq!(list.pop().unwrap(), &9999); assert_empty(&mut list); list.push(&mut node4); assert_eq!(list.pop().unwrap(), &8642); assert_empty(&mut list); } } #[test] fn cell() { unsafe { let mut node = UnsafeListEntry::new(Cell::new(0)); let mut list = UnsafeList::new(); let noderef = list.push(&mut node); assert_eq!(noderef.get(), 0); list.pop().unwrap().set(1); assert_empty(&mut list); assert_eq!(noderef.get(), 1); } } } } /// Trivial spinlock-based implementation of `sync::Mutex`. // FIXME: Perhaps use Intel TSX to avoid locking? mod spin_mutex { use crate::cell::UnsafeCell; use crate::ops::{Deref, DerefMut}; use crate::sync::atomic::{spin_loop_hint, AtomicBool, Ordering}; #[derive(Default)] pub struct SpinMutex { value: UnsafeCell, lock: AtomicBool, } unsafe impl Send for SpinMutex {} unsafe impl Sync for SpinMutex {} pub struct SpinMutexGuard<'a, T: 'a> { mutex: &'a SpinMutex, } impl<'a, T> !Send for SpinMutexGuard<'a, T> {} unsafe impl<'a, T: Sync> Sync for SpinMutexGuard<'a, T> {} impl SpinMutex { pub const fn new(value: T) -> Self { SpinMutex { value: UnsafeCell::new(value), lock: AtomicBool::new(false) } } #[inline(always)] pub fn lock(&self) -> SpinMutexGuard<'_, T> { loop { match self.try_lock() { None => { while self.lock.load(Ordering::Relaxed) { spin_loop_hint() } } Some(guard) => return guard, } } } #[inline(always)] pub fn try_lock(&self) -> Option> { if !self.lock.compare_and_swap(false, true, Ordering::Acquire) { Some(SpinMutexGuard { mutex: self }) } else { None } } } /// Lock the Mutex or return false. pub macro try_lock_or_false($e:expr) { if let Some(v) = $e.try_lock() { v } else { return false } } impl<'a, T> Deref for SpinMutexGuard<'a, T> { type Target = T; fn deref(&self) -> &T { unsafe { &*self.mutex.value.get() } } } impl<'a, T> DerefMut for SpinMutexGuard<'a, T> { fn deref_mut(&mut self) -> &mut T { unsafe { &mut *self.mutex.value.get() } } } impl<'a, T> Drop for SpinMutexGuard<'a, T> { fn drop(&mut self) { self.mutex.lock.store(false, Ordering::Release) } } #[cfg(test)] mod tests { #![allow(deprecated)] use super::*; use crate::sync::Arc; use crate::thread; use crate::time::Duration; #[test] fn sleep() { let mutex = Arc::new(SpinMutex::::default()); let mutex2 = mutex.clone(); let guard = mutex.lock(); let t1 = thread::spawn(move || { *mutex2.lock() = 1; }); thread::sleep(Duration::from_millis(50)); assert_eq!(*guard, 0); drop(guard); t1.join().unwrap(); assert_eq!(*mutex.lock(), 1); } } } #[cfg(test)] mod tests { use super::*; use crate::sync::Arc; use crate::thread; #[test] fn queue() { let wq = Arc::new(SpinMutex::>::default()); let wq2 = wq.clone(); let locked = wq.lock(); let t1 = thread::spawn(move || { // if we obtain the lock, the main thread should be waiting assert!(WaitQueue::notify_one(wq2.lock()).is_ok()); }); WaitQueue::wait(locked, || {}); t1.join().unwrap(); } }