summary refs log tree commit diff
path: root/library/std/src
diff options
context:
space:
mode:
Diffstat (limited to 'library/std/src')
-rw-r--r--library/std/src/sync/reentrant_lock.rs138
-rw-r--r--library/std/src/thread/mod.rs32
2 files changed, 144 insertions, 26 deletions
diff --git a/library/std/src/sync/reentrant_lock.rs b/library/std/src/sync/reentrant_lock.rs
index 042c439394e..84a0b36db17 100644
--- a/library/std/src/sync/reentrant_lock.rs
+++ b/library/std/src/sync/reentrant_lock.rs
@@ -1,12 +1,14 @@
 #[cfg(all(test, not(target_os = "emscripten")))]
 mod tests;
 
+use cfg_if::cfg_if;
+
 use crate::cell::UnsafeCell;
 use crate::fmt;
 use crate::ops::Deref;
 use crate::panic::{RefUnwindSafe, UnwindSafe};
-use crate::sync::atomic::{AtomicUsize, Ordering::Relaxed};
 use crate::sys::sync as sys;
+use crate::thread::{current_id, ThreadId};
 
 /// A re-entrant mutual exclusion lock
 ///
@@ -53,8 +55,8 @@ use crate::sys::sync as sys;
 //
 // The 'owner' field tracks which thread has locked the mutex.
 //
-// We use current_thread_unique_ptr() as the thread identifier,
-// which is just the address of a thread local variable.
+// We use thread::current_id() as the thread identifier, which is just the
+// current thread's ThreadId, so it's unique across the process lifetime.
 //
 // If `owner` is set to the identifier of the current thread,
 // we assume the mutex is already locked and instead of locking it again,
@@ -72,14 +74,109 @@ use crate::sys::sync as sys;
 // since we're not dealing with multiple threads. If it's not equal,
 // synchronization is left to the mutex, making relaxed memory ordering for
 // the `owner` field fine in all cases.
+//
+// On systems without 64 bit atomics we also store the address of a TLS variable
+// along the 64-bit TID. We then first check that address against the address
+// of that variable on the current thread, and only if they compare equal do we
+// compare the actual TIDs. Because we only ever read the TID on the same thread
+// that it was written on (or a thread sharing the TLS block with that writer thread),
+// we don't need to further synchronize the TID accesses, so they can be regular 64-bit
+// non-atomic accesses.
 #[unstable(feature = "reentrant_lock", issue = "121440")]
 pub struct ReentrantLock<T: ?Sized> {
     mutex: sys::Mutex,
-    owner: AtomicUsize,
+    owner: Tid,
     lock_count: UnsafeCell<u32>,
     data: T,
 }
 
+cfg_if!(
+    if #[cfg(target_has_atomic = "64")] {
+        use crate::sync::atomic::{AtomicU64, Ordering::Relaxed};
+
+        struct Tid(AtomicU64);
+
+        impl Tid {
+            const fn new() -> Self {
+                Self(AtomicU64::new(0))
+            }
+
+            #[inline]
+            fn contains(&self, owner: ThreadId) -> bool {
+                owner.as_u64().get() == self.0.load(Relaxed)
+            }
+
+            #[inline]
+            // This is just unsafe to match the API of the Tid type below.
+            unsafe fn set(&self, tid: Option<ThreadId>) {
+                let value = tid.map_or(0, |tid| tid.as_u64().get());
+                self.0.store(value, Relaxed);
+            }
+        }
+    } else {
+        /// Returns the address of a TLS variable. This is guaranteed to
+        /// be unique across all currently alive threads.
+        fn tls_addr() -> usize {
+            thread_local! { static X: u8 = const { 0u8 } };
+
+            X.with(|p| <*const u8>::addr(p))
+        }
+
+        use crate::sync::atomic::{
+            AtomicUsize,
+            Ordering,
+        };
+
+        struct Tid {
+            // When a thread calls `set()`, this value gets updated to
+            // the address of a thread local on that thread. This is
+            // used as a first check in `contains()`; if the `tls_addr`
+            // doesn't match the TLS address of the current thread, then
+            // the ThreadId also can't match. Only if the TLS addresses do
+            // match do we read out the actual TID.
+            // Note also that we can use relaxed atomic operations here, because
+            // we only ever read from the tid if `tls_addr` matches the current
+            // TLS address. In that case, either the the tid has been set by
+            // the current thread, or by a thread that has terminated before
+            // the current thread was created. In either case, no further
+            // synchronization is needed (as per <https://github.com/rust-lang/miri/issues/3450>)
+            tls_addr: AtomicUsize,
+            tid: UnsafeCell<u64>,
+        }
+
+        unsafe impl Send for Tid {}
+        unsafe impl Sync for Tid {}
+
+        impl Tid {
+            const fn new() -> Self {
+                Self { tls_addr: AtomicUsize::new(0), tid: UnsafeCell::new(0) }
+            }
+
+            #[inline]
+            // NOTE: This assumes that `owner` is the ID of the current
+            // thread, and may spuriously return `false` if that's not the case.
+            fn contains(&self, owner: ThreadId) -> bool {
+                // SAFETY: See the comments in the struct definition.
+                self.tls_addr.load(Ordering::Relaxed) == tls_addr()
+                    && unsafe { *self.tid.get() } == owner.as_u64().get()
+            }
+
+            #[inline]
+            // This may only be called by one thread at a time, and can lead to
+            // race conditions otherwise.
+            unsafe fn set(&self, tid: Option<ThreadId>) {
+                // It's important that we set `self.tls_addr` to 0 if the tid is
+                // cleared. Otherwise, there might be race conditions between
+                // `set()` and `get()`.
+                let tls_addr = if tid.is_some() { tls_addr() } else { 0 };
+                let value = tid.map_or(0, |tid| tid.as_u64().get());
+                self.tls_addr.store(tls_addr, Ordering::Relaxed);
+                unsafe { *self.tid.get() = value };
+            }
+        }
+    }
+);
+
 #[unstable(feature = "reentrant_lock", issue = "121440")]
 unsafe impl<T: Send + ?Sized> Send for ReentrantLock<T> {}
 #[unstable(feature = "reentrant_lock", issue = "121440")]
@@ -134,7 +231,7 @@ impl<T> ReentrantLock<T> {
     pub const fn new(t: T) -> ReentrantLock<T> {
         ReentrantLock {
             mutex: sys::Mutex::new(),
-            owner: AtomicUsize::new(0),
+            owner: Tid::new(),
             lock_count: UnsafeCell::new(0),
             data: t,
         }
@@ -184,14 +281,16 @@ impl<T: ?Sized> ReentrantLock<T> {
     /// assert_eq!(lock.lock().get(), 10);
     /// ```
     pub fn lock(&self) -> ReentrantLockGuard<'_, T> {
-        let this_thread = current_thread_unique_ptr();
-        // Safety: We only touch lock_count when we own the lock.
+        let this_thread = current_id();
+        // Safety: We only touch lock_count when we own the inner mutex.
+        // Additionally, we only call `self.owner.set()` while holding
+        // the inner mutex, so no two threads can call it concurrently.
         unsafe {
-            if self.owner.load(Relaxed) == this_thread {
+            if self.owner.contains(this_thread) {
                 self.increment_lock_count().expect("lock count overflow in reentrant mutex");
             } else {
                 self.mutex.lock();
-                self.owner.store(this_thread, Relaxed);
+                self.owner.set(Some(this_thread));
                 debug_assert_eq!(*self.lock_count.get(), 0);
                 *self.lock_count.get() = 1;
             }
@@ -226,14 +325,16 @@ impl<T: ?Sized> ReentrantLock<T> {
     ///
     /// This function does not block.
     pub(crate) fn try_lock(&self) -> Option<ReentrantLockGuard<'_, T>> {
-        let this_thread = current_thread_unique_ptr();
-        // Safety: We only touch lock_count when we own the lock.
+        let this_thread = current_id();
+        // Safety: We only touch lock_count when we own the inner mutex.
+        // Additionally, we only call `self.owner.set()` while holding
+        // the inner mutex, so no two threads can call it concurrently.
         unsafe {
-            if self.owner.load(Relaxed) == this_thread {
+            if self.owner.contains(this_thread) {
                 self.increment_lock_count()?;
                 Some(ReentrantLockGuard { lock: self })
             } else if self.mutex.try_lock() {
-                self.owner.store(this_thread, Relaxed);
+                self.owner.set(Some(this_thread));
                 debug_assert_eq!(*self.lock_count.get(), 0);
                 *self.lock_count.get() = 1;
                 Some(ReentrantLockGuard { lock: self })
@@ -308,18 +409,9 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
         unsafe {
             *self.lock.lock_count.get() -= 1;
             if *self.lock.lock_count.get() == 0 {
-                self.lock.owner.store(0, Relaxed);
+                self.lock.owner.set(None);
                 self.lock.mutex.unlock();
             }
         }
     }
 }
-
-/// Get an address that is unique per running thread.
-///
-/// This can be used as a non-null usize-sized ID.
-pub(crate) fn current_thread_unique_ptr() -> usize {
-    // Use a non-drop type to make sure it's still available during thread destruction.
-    thread_local! { static X: u8 = const { 0 } }
-    X.with(|x| <*const _>::addr(x))
-}
diff --git a/library/std/src/thread/mod.rs b/library/std/src/thread/mod.rs
index c8ee365392f..7b98f2ae763 100644
--- a/library/std/src/thread/mod.rs
+++ b/library/std/src/thread/mod.rs
@@ -159,7 +159,7 @@
 mod tests;
 
 use crate::any::Any;
-use crate::cell::{OnceCell, UnsafeCell};
+use crate::cell::{Cell, OnceCell, UnsafeCell};
 use crate::env;
 use crate::ffi::{CStr, CString};
 use crate::fmt;
@@ -699,17 +699,22 @@ where
 }
 
 thread_local! {
+    // Invariant: `CURRENT` and `CURRENT_ID` will always be initialized together.
+    // If `CURRENT` is initialized, then `CURRENT_ID` will hold the same value
+    // as `CURRENT.id()`.
     static CURRENT: OnceCell<Thread> = const { OnceCell::new() };
+    static CURRENT_ID: Cell<Option<ThreadId>> = const { Cell::new(None) };
 }
 
 /// Sets the thread handle for the current thread.
 ///
 /// Aborts if the handle has been set already to reduce code size.
 pub(crate) fn set_current(thread: Thread) {
+    let tid = thread.id();
     // Using `unwrap` here can add ~3kB to the binary size. We have complete
     // control over where this is called, so just abort if there is a bug.
     CURRENT.with(|current| match current.set(thread) {
-        Ok(()) => {}
+        Ok(()) => CURRENT_ID.set(Some(tid)),
         Err(_) => rtabort!("thread::set_current should only be called once per thread"),
     });
 }
@@ -719,7 +724,28 @@ pub(crate) fn set_current(thread: Thread) {
 /// In contrast to the public `current` function, this will not panic if called
 /// from inside a TLS destructor.
 pub(crate) fn try_current() -> Option<Thread> {
-    CURRENT.try_with(|current| current.get_or_init(|| Thread::new_unnamed()).clone()).ok()
+    CURRENT
+        .try_with(|current| {
+            current
+                .get_or_init(|| {
+                    let thread = Thread::new_unnamed();
+                    CURRENT_ID.set(Some(thread.id()));
+                    thread
+                })
+                .clone()
+        })
+        .ok()
+}
+
+/// Gets the id of the thread that invokes it.
+#[inline]
+pub(crate) fn current_id() -> ThreadId {
+    CURRENT_ID.get().unwrap_or_else(|| {
+        // If `CURRENT_ID` isn't initialized yet, then `CURRENT` must also not be initialized.
+        // `current()` will initialize both `CURRENT` and `CURRENT_ID` so subsequent calls to
+        // `current_id()` will succeed immediately.
+        current().id()
+    })
 }
 
 /// Gets a handle to the thread that invokes it.