about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/std/src/sys/exit_guard.rs43
-rw-r--r--library/std/src/sys/pal/unix/os.rs2
2 files changed, 21 insertions, 24 deletions
diff --git a/library/std/src/sys/exit_guard.rs b/library/std/src/sys/exit_guard.rs
index 5a090f50666..bd70d178244 100644
--- a/library/std/src/sys/exit_guard.rs
+++ b/library/std/src/sys/exit_guard.rs
@@ -1,14 +1,5 @@
 cfg_if::cfg_if! {
     if #[cfg(target_os = "linux")] {
-        /// pthread_t is a pointer on some platforms,
-        /// so we wrap it in this to impl Send + Sync.
-        #[derive(Clone, Copy)]
-        #[repr(transparent)]
-        struct PThread(libc::pthread_t);
-        // Safety: pthread_t is safe to send between threads
-        unsafe impl Send for PThread {}
-        // Safety: pthread_t is safe to share between threads
-        unsafe impl Sync for PThread {}
         /// Mitigation for <https://github.com/rust-lang/rust/issues/126600>
         ///
         /// On glibc, `libc::exit` has been observed to not always be thread-safe.
@@ -30,28 +21,34 @@ cfg_if::cfg_if! {
         ///   (waiting for the process to exit).
         #[cfg_attr(any(test, doctest), allow(dead_code))]
         pub(crate) fn unique_thread_exit() {
-            let this_thread_id = unsafe { libc::pthread_self() };
-            use crate::sync::{Mutex, PoisonError};
-            static EXITING_THREAD_ID: Mutex<Option<PThread>> = Mutex::new(None);
-            let mut exiting_thread_id =
-                EXITING_THREAD_ID.lock().unwrap_or_else(PoisonError::into_inner);
-            match *exiting_thread_id {
-                None => {
+            use crate::ffi::c_int;
+            use crate::ptr;
+            use crate::sync::atomic::AtomicPtr;
+            use crate::sync::atomic::Ordering::{Acquire, Relaxed};
+
+            static EXITING_THREAD_ID: AtomicPtr<c_int> = AtomicPtr::new(ptr::null_mut());
+
+            // We use the address of `errno` as a cheap and safe way to identify
+            // threads. As the C standard mandates that `errno` must have thread
+            // storage duration, we can rely on its address not changing over the
+            // lifetime of the thread. Additionally, accesses to `errno` are
+            // async-signal-safe, so this function is available in all imaginable
+            // circumstances.
+            let this_thread_id = crate::sys::os::errno_location();
+            match EXITING_THREAD_ID.compare_exchange(ptr::null_mut(), this_thread_id, Acquire, Relaxed) {
+                Ok(_) => {
                     // This is the first thread to call `unique_thread_exit`,
-                    // and this is the first time it is called.
-                    // Set EXITING_THREAD_ID to this thread's ID and return.
-                    *exiting_thread_id = Some(PThread(this_thread_id));
-                },
-                Some(exiting_thread_id) if exiting_thread_id.0 == this_thread_id => {
+                    // and this is the first time it is called. Continue exiting.
+                }
+                Err(exiting_thread_id) if exiting_thread_id == this_thread_id => {
                     // This is the first thread to call `unique_thread_exit`,
                     // but this is the second time it is called.
                     // Abort the process.
                     core::panicking::panic_nounwind("std::process::exit called re-entrantly")
                 }
-                Some(_) => {
+                Err(_) => {
                     // This is not the first thread to call `unique_thread_exit`.
                     // Pause until the process exits.
-                    drop(exiting_thread_id);
                     loop {
                         // Safety: libc::pause is safe to call.
                         unsafe { libc::pause(); }
diff --git a/library/std/src/sys/pal/unix/os.rs b/library/std/src/sys/pal/unix/os.rs
index 4883303b88e..48609030aed 100644
--- a/library/std/src/sys/pal/unix/os.rs
+++ b/library/std/src/sys/pal/unix/os.rs
@@ -56,7 +56,7 @@ unsafe extern "C" {
     #[cfg_attr(target_os = "aix", link_name = "_Errno")]
     // SAFETY: this will always return the same pointer on a given thread.
     #[unsafe(ffi_const)]
-    fn errno_location() -> *mut c_int;
+    pub safe fn errno_location() -> *mut c_int;
 }
 
 /// Returns the platform-specific value of errno