about summary refs log tree commit diff
path: root/library
diff options
context:
space:
mode:
Diffstat (limited to 'library')
-rw-r--r--library/std/src/sys/thread_local/native/lazy.rs80
1 files changed, 51 insertions, 29 deletions
diff --git a/library/std/src/sys/thread_local/native/lazy.rs b/library/std/src/sys/thread_local/native/lazy.rs
index 51294285ba0..b556dd9aa25 100644
--- a/library/std/src/sys/thread_local/native/lazy.rs
+++ b/library/std/src/sys/thread_local/native/lazy.rs
@@ -1,9 +1,9 @@
-use crate::cell::UnsafeCell;
-use crate::hint::unreachable_unchecked;
+use crate::cell::{Cell, UnsafeCell};
+use crate::mem::MaybeUninit;
 use crate::ptr;
 use crate::sys::thread_local::{abort_on_dtor_unwind, destructors};
 
-pub unsafe trait DestroyedState: Sized {
+pub unsafe trait DestroyedState: Sized + Copy {
     fn register_dtor<T>(s: &Storage<T, Self>);
 }
 
@@ -19,15 +19,17 @@ unsafe impl DestroyedState for () {
     }
 }
 
-enum State<T, D> {
-    Initial,
-    Alive(T),
+#[derive(Copy, Clone)]
+enum State<D> {
+    Uninitialized,
+    Alive,
     Destroyed(D),
 }
 
 #[allow(missing_debug_implementations)]
 pub struct Storage<T, D> {
-    state: UnsafeCell<State<T, D>>,
+    state: Cell<State<D>>,
+    value: UnsafeCell<MaybeUninit<T>>,
 }
 
 impl<T, D> Storage<T, D>
@@ -35,7 +37,10 @@ where
     D: DestroyedState,
 {
     pub const fn new() -> Storage<T, D> {
-        Storage { state: UnsafeCell::new(State::Initial) }
+        Storage {
+            state: Cell::new(State::Uninitialized),
+            value: UnsafeCell::new(MaybeUninit::uninit()),
+        }
     }
 
     /// Gets a pointer to the TLS value, potentially initializing it with the
@@ -49,35 +54,49 @@ where
     /// The `self` reference must remain valid until the TLS destructor is run.
     #[inline]
     pub unsafe fn get_or_init(&self, i: Option<&mut Option<T>>, f: impl FnOnce() -> T) -> *const T {
-        let state = unsafe { &*self.state.get() };
-        match state {
-            State::Alive(v) => v,
-            State::Destroyed(_) => ptr::null(),
-            State::Initial => unsafe { self.initialize(i, f) },
+        if let State::Alive = self.state.get() {
+            self.value.get().cast()
+        } else {
+            unsafe { self.get_or_init_slow(i, f) }
         }
     }
 
+    /// # Safety
+    /// The `self` reference must remain valid until the TLS destructor is run.
     #[cold]
-    unsafe fn initialize(&self, i: Option<&mut Option<T>>, f: impl FnOnce() -> T) -> *const T {
-        // Perform initialization
+    unsafe fn get_or_init_slow(
+        &self,
+        i: Option<&mut Option<T>>,
+        f: impl FnOnce() -> T,
+    ) -> *const T {
+        match self.state.get() {
+            State::Uninitialized => {}
+            State::Alive => return self.value.get().cast(),
+            State::Destroyed(_) => return ptr::null(),
+        }
 
         let v = i.and_then(Option::take).unwrap_or_else(f);
 
-        let old = unsafe { self.state.get().replace(State::Alive(v)) };
-        match old {
+        // SAFETY: we cannot be inside a `LocalKey::with` scope, as the initializer
+        // has already returned and the next scope only starts after we return
+        // the pointer. Therefore, there can be no references to the old value,
+        // even if it was initialized. Thus because we are !Sync we have exclusive
+        // access to self.value and may replace it.
+        let mut old_value = unsafe { self.value.get().replace(MaybeUninit::new(v)) };
+        match self.state.replace(State::Alive) {
             // If the variable is not being recursively initialized, register
             // the destructor. This might be a noop if the value does not need
             // destruction.
-            State::Initial => D::register_dtor(self),
-            // Else, drop the old value. This might be changed to a panic.
-            val => drop(val),
-        }
+            State::Uninitialized => D::register_dtor(self),
 
-        // SAFETY: the state was just set to `Alive`
-        unsafe {
-            let State::Alive(v) = &*self.state.get() else { unreachable_unchecked() };
-            v
+            // Recursive initialization, we only need to drop the old value
+            // as we've already registered the destructor.
+            State::Alive => unsafe { old_value.assume_init_drop() },
+
+            State::Destroyed(_) => unreachable!(),
         }
+
+        self.value.get().cast()
     }
 }
 
@@ -92,9 +111,12 @@ unsafe extern "C" fn destroy<T>(ptr: *mut u8) {
     // Print a nice abort message if a panic occurs.
     abort_on_dtor_unwind(|| {
         let storage = unsafe { &*(ptr as *const Storage<T, ()>) };
-        // Update the state before running the destructor as it may attempt to
-        // access the variable.
-        let val = unsafe { storage.state.get().replace(State::Destroyed(())) };
-        drop(val);
+        if let State::Alive = storage.state.replace(State::Destroyed(())) {
+            // SAFETY: we ensured the state was Alive so the value was initialized.
+            // We also updated the state to Destroyed to prevent the destructor
+            // from accessing the thread-local variable, as this would violate
+            // the exclusive access provided by &mut T in Drop::drop.
+            unsafe { (*storage.value.get()).assume_init_drop() }
+        }
     })
 }