about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMohsen Zohrevandi <mohsen.zohrevandi@fortanix.com>2021-04-21 13:45:57 -0700
committerMohsen Zohrevandi <mohsen.zohrevandi@fortanix.com>2021-04-21 14:45:45 -0700
commit5d9eeff062053f87ab900fc7ebde6eb13a2a1645 (patch)
treeed2ea8ae02841a34e82ac18f706bb1de8386e6de
parent6df26f897cffb2d86880544bb451c6b5f8509b2d (diff)
downloadrust-5d9eeff062053f87ab900fc7ebde6eb13a2a1645.tar.gz
rust-5d9eeff062053f87ab900fc7ebde6eb13a2a1645.zip
Ensure TLS destructors run before thread joins in SGX
-rw-r--r--library/std/src/sys/sgx/abi/mod.rs6
-rw-r--r--library/std/src/sys/sgx/thread.rs69
-rw-r--r--library/std/src/thread/local/tests.rs55
3 files changed, 119 insertions, 11 deletions
diff --git a/library/std/src/sys/sgx/abi/mod.rs b/library/std/src/sys/sgx/abi/mod.rs
index a5e45303476..f9536c4203d 100644
--- a/library/std/src/sys/sgx/abi/mod.rs
+++ b/library/std/src/sys/sgx/abi/mod.rs
@@ -62,10 +62,12 @@ unsafe extern "C" fn tcs_init(secondary: bool) {
 extern "C" fn entry(p1: u64, p2: u64, p3: u64, secondary: bool, p4: u64, p5: u64) -> EntryReturn {
     // FIXME: how to support TLS in library mode?
     let tls = Box::new(tls::Tls::new());
-    let _tls_guard = unsafe { tls.activate() };
+    let tls_guard = unsafe { tls.activate() };
 
     if secondary {
-        super::thread::Thread::entry();
+        let join_notifier = super::thread::Thread::entry();
+        drop(tls_guard);
+        drop(join_notifier);
 
         EntryReturn(0, 0)
     } else {
diff --git a/library/std/src/sys/sgx/thread.rs b/library/std/src/sys/sgx/thread.rs
index 55ef460cc90..67e2e8b59d3 100644
--- a/library/std/src/sys/sgx/thread.rs
+++ b/library/std/src/sys/sgx/thread.rs
@@ -9,26 +9,37 @@ pub struct Thread(task_queue::JoinHandle);
 
 pub const DEFAULT_MIN_STACK_SIZE: usize = 4096;
 
+pub use self::task_queue::JoinNotifier;
+
 mod task_queue {
-    use crate::sync::mpsc;
+    use super::wait_notify;
     use crate::sync::{Mutex, MutexGuard, Once};
 
-    pub type JoinHandle = mpsc::Receiver<()>;
+    pub type JoinHandle = wait_notify::Waiter;
+
+    pub struct JoinNotifier(Option<wait_notify::Notifier>);
+
+    impl Drop for JoinNotifier {
+        fn drop(&mut self) {
+            self.0.take().unwrap().notify();
+        }
+    }
 
     pub(super) struct Task {
         p: Box<dyn FnOnce()>,
-        done: mpsc::Sender<()>,
+        done: JoinNotifier,
     }
 
     impl Task {
         pub(super) fn new(p: Box<dyn FnOnce()>) -> (Task, JoinHandle) {
-            let (done, recv) = mpsc::channel();
+            let (done, recv) = wait_notify::new();
+            let done = JoinNotifier(Some(done));
             (Task { p, done }, recv)
         }
 
-        pub(super) fn run(self) {
+        pub(super) fn run(self) -> JoinNotifier {
             (self.p)();
-            let _ = self.done.send(());
+            self.done
         }
     }
 
@@ -47,6 +58,48 @@ mod task_queue {
     }
 }
 
+/// This module provides a synchronization primitive that does not use thread
+/// local variables. This is needed for signaling that a thread has finished
+/// execution. The signal is sent once all TLS destructors have finished at
+/// which point no new thread locals should be created.
+pub mod wait_notify {
+    use super::super::waitqueue::{SpinMutex, WaitQueue, WaitVariable};
+    use crate::sync::Arc;
+
+    pub struct Notifier(Arc<SpinMutex<WaitVariable<bool>>>);
+
+    impl Notifier {
+        /// Notify the waiter. The waiter is either notified right away (if
+        /// currently blocked in `Waiter::wait()`) or later when it calls the
+        /// `Waiter::wait()` method.
+        pub fn notify(self) {
+            let mut guard = self.0.lock();
+            *guard.lock_var_mut() = true;
+            let _ = WaitQueue::notify_one(guard);
+        }
+    }
+
+    pub struct Waiter(Arc<SpinMutex<WaitVariable<bool>>>);
+
+    impl Waiter {
+        /// Wait for a notification. If `Notifier::notify()` has already been
+        /// called, this will return immediately, otherwise the current thread
+        /// is blocked until notified.
+        pub fn wait(self) {
+            let guard = self.0.lock();
+            if *guard.lock_var() {
+                return;
+            }
+            WaitQueue::wait(guard, || {});
+        }
+    }
+
+    pub fn new() -> (Notifier, Waiter) {
+        let inner = Arc::new(SpinMutex::new(WaitVariable::new(false)));
+        (Notifier(inner.clone()), Waiter(inner))
+    }
+}
+
 impl Thread {
     // unsafe: see thread::Builder::spawn_unchecked for safety requirements
     pub unsafe fn new(_stack: usize, p: Box<dyn FnOnce()>) -> io::Result<Thread> {
@@ -57,7 +110,7 @@ impl Thread {
         Ok(Thread(handle))
     }
 
-    pub(super) fn entry() {
+    pub(super) fn entry() -> JoinNotifier {
         let mut pending_tasks = task_queue::lock();
         let task = rtunwrap!(Some, pending_tasks.pop());
         drop(pending_tasks); // make sure to not hold the task queue lock longer than necessary
@@ -78,7 +131,7 @@ impl Thread {
     }
 
     pub fn join(self) {
-        let _ = self.0.recv();
+        self.0.wait();
     }
 }
 
diff --git a/library/std/src/thread/local/tests.rs b/library/std/src/thread/local/tests.rs
index 80e6798d847..98f525eafb0 100644
--- a/library/std/src/thread/local/tests.rs
+++ b/library/std/src/thread/local/tests.rs
@@ -1,5 +1,6 @@
 use crate::cell::{Cell, UnsafeCell};
-use crate::sync::mpsc::{channel, Sender};
+use crate::sync::atomic::{AtomicBool, Ordering};
+use crate::sync::mpsc::{self, channel, Sender};
 use crate::thread::{self, LocalKey};
 use crate::thread_local;
 
@@ -207,3 +208,55 @@ fn dtors_in_dtors_in_dtors_const_init() {
     });
     rx.recv().unwrap();
 }
+
+// This test tests that TLS destructors have run before the thread joins. The
+// test has no false positives (meaning: if the test fails, there's actually
+// an ordering problem). It may have false negatives, where the test passes but
+// join is not guaranteed to be after the TLS destructors. However, false
+// negatives should be exceedingly rare due to judicious use of
+// thread::yield_now and running the test several times.
+#[test]
+fn join_orders_after_tls_destructors() {
+    static THREAD2_LAUNCHED: AtomicBool = AtomicBool::new(false);
+
+    for _ in 0..10 {
+        let (tx, rx) = mpsc::sync_channel(0);
+        THREAD2_LAUNCHED.store(false, Ordering::SeqCst);
+
+        let jh = thread::spawn(move || {
+            struct RecvOnDrop(Cell<Option<mpsc::Receiver<()>>>);
+
+            impl Drop for RecvOnDrop {
+                fn drop(&mut self) {
+                    let rx = self.0.take().unwrap();
+                    while !THREAD2_LAUNCHED.load(Ordering::SeqCst) {
+                        thread::yield_now();
+                    }
+                    rx.recv().unwrap();
+                }
+            }
+
+            thread_local! {
+                static TL_RX: RecvOnDrop = RecvOnDrop(Cell::new(None));
+            }
+
+            TL_RX.with(|v| v.0.set(Some(rx)))
+        });
+
+        let tx_clone = tx.clone();
+        let jh2 = thread::spawn(move || {
+            THREAD2_LAUNCHED.store(true, Ordering::SeqCst);
+            jh.join().unwrap();
+            tx_clone.send(()).expect_err(
+                "Expecting channel to be closed because thread 1 TLS destructors must've run",
+            );
+        });
+
+        while !THREAD2_LAUNCHED.load(Ordering::SeqCst) {
+            thread::yield_now();
+        }
+        thread::yield_now();
+        tx.send(()).expect("Expecting channel to be live because thread 2 must block on join");
+        jh2.join().unwrap();
+    }
+}