author     bors <bors@rust-lang.org>  2023-04-27 17:43:09 +0000
committer  bors <bors@rust-lang.org>  2023-04-27 17:43:09 +0000
commit     c14882f74e8feb3f76ae85ed5cd66afaccd1da67 (patch)
tree       e893ce45bb93289936dfb2a657337cae82d7f090
parent     901fdb3b04375e3456b5cf771f86ecca8d6c1917 (diff)
parent     efe7cf468f39e810b22e2779f4fafe976e74e563 (diff)
Auto merge of #107782 - Zoxc:worker-local, r=cjgillot
Move the WorkerLocal type from the rustc-rayon fork into rustc_data_structures

This PR moves the definition of the `WorkerLocal` type from `rustc-rayon` into `rustc_data_structures`. This is enabled by the introduction of the `Registry` type, which groups the threads that may use `WorkerLocal`; a `WorkerLocal` is essentially an array with one slot per thread, indexed by a per-thread index. The `Registry` type mirrors the one in Rayon, and each Rayon worker thread is also registered with the new `Registry`. Safety for `WorkerLocal` is ensured by keeping a reference to the registry and checking on each access that the current thread still belongs to the group of threads associated with the registry used to construct it.
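
As a rough sketch of how the new pieces fit together (illustrative only, not code from this PR; the Rayon thread-pool wiring is elided), a thread first registers itself with a `Registry`, and any `WorkerLocal` constructed afterwards hands each registered thread its own slot:

```rust
use rustc_data_structures::sync::{Registry, WorkerLocal};
use std::cell::Cell;

fn main() {
    // Group the threads that may use worker locals; here just the current one.
    let registry = Registry::new(1);
    registry.register();

    // One slot per possible thread; the closure receives each thread's index.
    let counter: WorkerLocal<Cell<usize>> = WorkerLocal::new(|_index| Cell::new(0));

    // `Deref` checks the thread's registry association and yields its own slot.
    counter.set(counter.get() + 1);

    // Collect the per-thread values back out.
    let total: usize = counter.into_inner().map(|cell| cell.get()).sum();
    assert_eq!(total, 1);
}
```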

Access to a `WorkerLocal` is micro-optimized because it is on a hot path: it is used for most arena allocations.
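
For illustration, here is a standalone miniature of that access pattern (hypothetical names, not the actual implementation in the diff below): the hot lookup is one thread-local read, one registry-identity check, and one array index.

```rust
use std::cell::Cell;

// Each registered thread caches (registry id, slot index) in a thread local.
thread_local! {
    static THREAD_SLOT: Cell<(usize, usize)> = const { Cell::new((0, 0)) };
}

struct PerThread<T> {
    registry_id: usize,
    slots: Vec<T>,
}

impl<T> PerThread<T> {
    fn get(&self) -> &T {
        let (id, index) = THREAD_SLOT.with(|slot| slot.get());
        // Panic if this thread was registered with a different registry.
        assert_eq!(id, self.registry_id, "thread not associated with this registry");
        &self.slots[index]
    }
}

fn main() {
    THREAD_SLOT.with(|slot| slot.set((1, 0)));
    let values = PerThread { registry_id: 1, slots: vec![42u32] };
    assert_eq!(*values.get(), 42);
}
```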

Performance is slightly improved for the parallel compiler:
| Benchmark | Before | After | Change |
|---|---:|---:|---:|
| 🟣 **clap**:check | 1.9992s | 1.9949s | -0.21% |
| 🟣 **hyper**:check | 0.2977s | 0.2970s | -0.22% |
| 🟣 **regex**:check | 1.1335s | 1.1315s | -0.18% |
| 🟣 **syn**:check | 1.8235s | 1.8171s | -0.35% |
| 🟣 **syntex_syntax**:check | 6.9047s | 6.8930s | -0.17% |
| Total | 12.1586s | 12.1336s | -0.21% |
| Summary | 1.0000s | 0.9977s | -0.23% |

cc `@SparrowLii`
-rw-r--r--  compiler/rustc_ast/src/attr/mod.rs                        36
-rw-r--r--  compiler/rustc_data_structures/src/sharded.rs              6
-rw-r--r--  compiler/rustc_data_structures/src/sync.rs                36
-rw-r--r--  compiler/rustc_data_structures/src/sync/worker_local.rs  180
-rw-r--r--  compiler/rustc_interface/src/util.rs                       6
5 files changed, 198 insertions(+), 66 deletions(-)
diff --git a/compiler/rustc_ast/src/attr/mod.rs b/compiler/rustc_ast/src/attr/mod.rs
index c4771115cac..e6c4db9e2ae 100644
--- a/compiler/rustc_ast/src/attr/mod.rs
+++ b/compiler/rustc_ast/src/attr/mod.rs
@@ -10,15 +10,10 @@ use crate::tokenstream::{DelimSpan, Spacing, TokenTree};
 use crate::tokenstream::{LazyAttrTokenStream, TokenStream};
 use crate::util::comments;
 use crate::util::literal::escape_string_symbol;
-use rustc_data_structures::sync::WorkerLocal;
 use rustc_index::bit_set::GrowableBitSet;
 use rustc_span::symbol::{sym, Ident, Symbol};
 use rustc_span::Span;
-use std::cell::Cell;
 use std::iter;
-#[cfg(debug_assertions)]
-use std::ops::BitXor;
-#[cfg(debug_assertions)]
 use std::sync::atomic::{AtomicU32, Ordering};
 use thin_vec::{thin_vec, ThinVec};
 
@@ -40,39 +35,16 @@ impl MarkedAttrs {
     }
 }
 
-pub struct AttrIdGenerator(WorkerLocal<Cell<u32>>);
-
-#[cfg(debug_assertions)]
-static MAX_ATTR_ID: AtomicU32 = AtomicU32::new(u32::MAX);
+pub struct AttrIdGenerator(AtomicU32);
 
 impl AttrIdGenerator {
     pub fn new() -> Self {
-        // We use `(index as u32).reverse_bits()` to initialize the
-        // starting value of AttrId in each worker thread.
-        // The `index` is the index of the worker thread.
-        // This ensures that the AttrId generated in each thread is unique.
-        AttrIdGenerator(WorkerLocal::new(|index| {
-            let index: u32 = index.try_into().unwrap();
-
-            #[cfg(debug_assertions)]
-            {
-                let max_id = ((index + 1).next_power_of_two() - 1).bitxor(u32::MAX).reverse_bits();
-                MAX_ATTR_ID.fetch_min(max_id, Ordering::Release);
-            }
-
-            Cell::new(index.reverse_bits())
-        }))
+        AttrIdGenerator(AtomicU32::new(0))
     }
 
     pub fn mk_attr_id(&self) -> AttrId {
-        let id = self.0.get();
-
-        // Ensure the assigned attr_id does not overlap the bits
-        // representing the number of threads.
-        #[cfg(debug_assertions)]
-        assert!(id <= MAX_ATTR_ID.load(Ordering::Acquire));
-
-        self.0.set(id + 1);
+        let id = self.0.fetch_add(1, Ordering::Relaxed);
+        assert!(id != u32::MAX);
         AttrId::from_u32(id)
     }
 }
diff --git a/compiler/rustc_data_structures/src/sharded.rs b/compiler/rustc_data_structures/src/sharded.rs
index bd7a86f6780..7ed70ba1e0f 100644
--- a/compiler/rustc_data_structures/src/sharded.rs
+++ b/compiler/rustc_data_structures/src/sharded.rs
@@ -1,14 +1,10 @@
 use crate::fx::{FxHashMap, FxHasher};
-use crate::sync::{Lock, LockGuard};
+use crate::sync::{CacheAligned, Lock, LockGuard};
 use std::borrow::Borrow;
 use std::collections::hash_map::RawEntryMut;
 use std::hash::{Hash, Hasher};
 use std::mem;
 
-#[derive(Default)]
-#[cfg_attr(parallel_compiler, repr(align(64)))]
-struct CacheAligned<T>(T);
-
 #[cfg(parallel_compiler)]
 // 32 shards is sufficient to reduce contention on an 8-core Ryzen 7 1700,
 // but this should be tested on higher core count CPUs. How the `Sharded` type gets used
diff --git a/compiler/rustc_data_structures/src/sync.rs b/compiler/rustc_data_structures/src/sync.rs
index ef1da85198f..e73ca56efa0 100644
--- a/compiler/rustc_data_structures/src/sync.rs
+++ b/compiler/rustc_data_structures/src/sync.rs
@@ -45,6 +45,9 @@ use std::hash::{BuildHasher, Hash};
 use std::ops::{Deref, DerefMut};
 use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
 
+mod worker_local;
+pub use worker_local::{Registry, WorkerLocal};
+
 pub use std::sync::atomic::Ordering;
 pub use std::sync::atomic::Ordering::SeqCst;
 
@@ -205,33 +208,6 @@ cfg_if! {
 
         use std::cell::Cell;
 
-        #[derive(Debug)]
-        pub struct WorkerLocal<T>(OneThread<T>);
-
-        impl<T> WorkerLocal<T> {
-            /// Creates a new worker local where the `initial` closure computes the
-            /// value this worker local should take for each thread in the thread pool.
-            #[inline]
-            pub fn new<F: FnMut(usize) -> T>(mut f: F) -> WorkerLocal<T> {
-                WorkerLocal(OneThread::new(f(0)))
-            }
-
-            /// Returns the worker-local value for each thread
-            #[inline]
-            pub fn into_inner(self) -> Vec<T> {
-                vec![OneThread::into_inner(self.0)]
-            }
-        }
-
-        impl<T> Deref for WorkerLocal<T> {
-            type Target = T;
-
-            #[inline(always)]
-            fn deref(&self) -> &T {
-                &self.0
-            }
-        }
-
         pub type MTLockRef<'a, T> = &'a mut MTLock<T>;
 
         #[derive(Debug, Default)]
@@ -351,8 +327,6 @@ cfg_if! {
             };
         }
 
-        pub use rayon_core::WorkerLocal;
-
         pub use rayon::iter::ParallelIterator;
         use rayon::iter::IntoParallelIterator;
 
@@ -383,6 +357,10 @@ pub fn assert_send<T: ?Sized + Send>() {}
 pub fn assert_send_val<T: ?Sized + Send>(_t: &T) {}
 pub fn assert_send_sync_val<T: ?Sized + Sync + Send>(_t: &T) {}
 
+#[derive(Default)]
+#[cfg_attr(parallel_compiler, repr(align(64)))]
+pub struct CacheAligned<T>(pub T);
+
 pub trait HashMapExt<K, V> {
     /// Same as HashMap::insert, but it may panic if there's already an
     /// entry for `key` with a value not equal to `value`
diff --git a/compiler/rustc_data_structures/src/sync/worker_local.rs b/compiler/rustc_data_structures/src/sync/worker_local.rs
new file mode 100644
index 00000000000..bfb04ba8a73
--- /dev/null
+++ b/compiler/rustc_data_structures/src/sync/worker_local.rs
@@ -0,0 +1,180 @@
+use crate::sync::Lock;
+use std::cell::Cell;
+use std::cell::OnceCell;
+use std::ops::Deref;
+use std::ptr;
+use std::sync::Arc;
+
+#[cfg(parallel_compiler)]
+use {crate::cold_path, crate::sync::CacheAligned};
+
+/// A pointer to the `RegistryData` which uniquely identifies a registry.
+/// This identifier can be reused if the registry gets freed.
+#[derive(Clone, Copy, PartialEq)]
+struct RegistryId(*const RegistryData);
+
+impl RegistryId {
+    #[inline(always)]
+    /// Verifies that the current thread is associated with the registry and returns its unique
+    /// index within the registry. This panics if the current thread is not associated with this
+    /// registry.
+    ///
+/// Note that there is a possible race where the identifier in `THREAD_DATA` could be reused,
+    /// so this can succeed from a different registry.
+    #[cfg(parallel_compiler)]
+    fn verify(self) -> usize {
+        let (id, index) = THREAD_DATA.with(|data| (data.registry_id.get(), data.index.get()));
+
+        if id == self {
+            index
+        } else {
+            cold_path(|| panic!("Unable to verify registry association"))
+        }
+    }
+}
+
+struct RegistryData {
+    thread_limit: usize,
+    threads: Lock<usize>,
+}
+
+/// Represents a list of threads which can access worker locals.
+#[derive(Clone)]
+pub struct Registry(Arc<RegistryData>);
+
+thread_local! {
+    /// The registry associated with the thread.
+    /// This allows the `WorkerLocal` type to clone the registry in its constructor.
+    static REGISTRY: OnceCell<Registry> = OnceCell::new();
+}
+
+struct ThreadData {
+    registry_id: Cell<RegistryId>,
+    index: Cell<usize>,
+}
+
+thread_local! {
+    /// A thread local which contains the identifier of `REGISTRY` but allows for faster access.
+    /// It also holds the index of the current thread.
+    static THREAD_DATA: ThreadData = const { ThreadData {
+        registry_id: Cell::new(RegistryId(ptr::null())),
+        index: Cell::new(0),
+    }};
+}
+
+impl Registry {
+    /// Creates a registry which can hold up to `thread_limit` threads.
+    pub fn new(thread_limit: usize) -> Self {
+        Registry(Arc::new(RegistryData { thread_limit, threads: Lock::new(0) }))
+    }
+
+    /// Gets the registry associated with the current thread. Panics if there's no such registry.
+    pub fn current() -> Self {
+        REGISTRY.with(|registry| registry.get().cloned().expect("No associated registry"))
+    }
+
+    /// Registers the current thread with the registry so worker locals can be used on it.
+    /// Panics if the thread limit is hit or if the thread already has an associated registry.
+    pub fn register(&self) {
+        let mut threads = self.0.threads.lock();
+        if *threads < self.0.thread_limit {
+            REGISTRY.with(|registry| {
+                if registry.get().is_some() {
+                    drop(threads);
+                    panic!("Thread already has a registry");
+                }
+                registry.set(self.clone()).ok();
+                THREAD_DATA.with(|data| {
+                    data.registry_id.set(self.id());
+                    data.index.set(*threads);
+                });
+                *threads += 1;
+            });
+        } else {
+            drop(threads);
+            panic!("Thread limit reached");
+        }
+    }
+
+    /// Gets the identifier of this registry.
+    fn id(&self) -> RegistryId {
+        RegistryId(&*self.0)
+    }
+}
+
+/// Holds worker local values for each possible thread in a registry. You can only access the
+/// worker local value through the `Deref` impl on the registry associated with the thread it was
+/// created on. It will panic otherwise.
+pub struct WorkerLocal<T> {
+    #[cfg(not(parallel_compiler))]
+    local: T,
+    #[cfg(parallel_compiler)]
+    locals: Box<[CacheAligned<T>]>,
+    #[cfg(parallel_compiler)]
+    registry: Registry,
+}
+
+// This is safe because the `deref` call will return a reference to a `T` unique to each thread
+// or it will panic for threads without an associated local. So there isn't a need for `T` to do
+// its own synchronization. The `verify` method on `RegistryId` has an issue where the id
+// can be reused, but `WorkerLocal` has a reference to `Registry` which will prevent any reuse.
+#[cfg(parallel_compiler)]
+unsafe impl<T: Send> Sync for WorkerLocal<T> {}
+
+impl<T> WorkerLocal<T> {
+    /// Creates a new worker local where the `initial` closure computes the
+    /// value this worker local should take for each thread in the registry.
+    #[inline]
+    pub fn new<F: FnMut(usize) -> T>(mut initial: F) -> WorkerLocal<T> {
+        #[cfg(parallel_compiler)]
+        {
+            let registry = Registry::current();
+            WorkerLocal {
+                locals: (0..registry.0.thread_limit).map(|i| CacheAligned(initial(i))).collect(),
+                registry,
+            }
+        }
+        #[cfg(not(parallel_compiler))]
+        {
+            WorkerLocal { local: initial(0) }
+        }
+    }
+
+    /// Returns the worker-local values for each thread
+    #[inline]
+    pub fn into_inner(self) -> impl Iterator<Item = T> {
+        #[cfg(parallel_compiler)]
+        {
+            self.locals.into_vec().into_iter().map(|local| local.0)
+        }
+        #[cfg(not(parallel_compiler))]
+        {
+            std::iter::once(self.local)
+        }
+    }
+}
+
+impl<T> WorkerLocal<Vec<T>> {
+    /// Joins the elements of all the worker locals into one Vec
+    pub fn join(self) -> Vec<T> {
+        self.into_inner().into_iter().flat_map(|v| v).collect()
+    }
+}
+
+impl<T> Deref for WorkerLocal<T> {
+    type Target = T;
+
+    #[inline(always)]
+    #[cfg(not(parallel_compiler))]
+    fn deref(&self) -> &T {
+        &self.local
+    }
+
+    #[inline(always)]
+    #[cfg(parallel_compiler)]
+    fn deref(&self) -> &T {
+        // This is safe because `verify` will only return values less than
+        // `self.registry.thread_limit` which is the size of the `self.locals` array.
+        unsafe { &self.locals.get_unchecked(self.registry.id().verify()).0 }
+    }
+}
diff --git a/compiler/rustc_interface/src/util.rs b/compiler/rustc_interface/src/util.rs
index 612903810d2..a27a1e2978a 100644
--- a/compiler/rustc_interface/src/util.rs
+++ b/compiler/rustc_interface/src/util.rs
@@ -4,6 +4,8 @@ use libloading::Library;
 use rustc_ast as ast;
 use rustc_codegen_ssa::traits::CodegenBackend;
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+#[cfg(parallel_compiler)]
+use rustc_data_structures::sync;
 use rustc_errors::registry::Registry;
 use rustc_parse::validate_attr;
 use rustc_session as session;
@@ -170,6 +172,7 @@ pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce() -> R + Send, R: Send>(
     use rustc_middle::ty::tls;
     use rustc_query_impl::{deadlock, QueryContext, QueryCtxt};
 
+    let registry = sync::Registry::new(threads);
     let mut builder = rayon::ThreadPoolBuilder::new()
         .thread_name(|_| "rustc".to_string())
         .acquire_thread_handler(jobserver::acquire_thread)
@@ -200,6 +203,9 @@ pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce() -> R + Send, R: Send>(
                 .build_scoped(
                     // Initialize each new worker thread when created.
                     move |thread: rayon::ThreadBuilder| {
+                        // Register the thread for use with the `WorkerLocal` type.
+                        registry.register();
+
                         rustc_span::set_session_globals_then(session_globals, || thread.run())
                     },
                     // Run `f` on the first thread in the thread pool.