diff options
| author | Mara Bos <m-ou.se@m-ou.se> | 2022-01-04 14:51:39 +0100 |
|---|---|---|
| committer | Mara Bos <m-ou.se@m-ou.se> | 2022-01-04 14:51:39 +0100 |
| commit | 0e24ad537be4d47686f3b9e3e6623664bce7cbc2 (patch) | |
| tree | aae619d18b17ff9d6867e9e584660fd92a70299c | |
| parent | a45b3ac1836b8d29a2a7b199aed169402aa01805 (diff) | |
| download | rust-0e24ad537be4d47686f3b9e3e6623664bce7cbc2.tar.gz rust-0e24ad537be4d47686f3b9e3e6623664bce7cbc2.zip | |
Implement RFC 3151: Scoped threads.
| -rw-r--r-- | library/std/src/thread/mod.rs | 96 | ||||
| -rw-r--r-- | library/std/src/thread/scoped.rs | 132 |
2 files changed, 202 insertions, 26 deletions
diff --git a/library/std/src/thread/mod.rs b/library/std/src/thread/mod.rs index c799b64c05b..0125545a3db 100644 --- a/library/std/src/thread/mod.rs +++ b/library/std/src/thread/mod.rs @@ -180,6 +180,12 @@ use crate::time::Duration; #[macro_use] mod local; +#[unstable(feature = "scoped_threads", issue = "none")] +mod scoped; + +#[unstable(feature = "scoped_threads", issue = "none")] +pub use scoped::{scope, Scope, ScopedJoinHandle}; + #[stable(feature = "rust1", since = "1.0.0")] pub use self::local::{AccessError, LocalKey}; @@ -447,6 +453,20 @@ impl Builder { F: Send + 'a, T: Send + 'a, { + Ok(JoinHandle(unsafe { self.spawn_unchecked_(f, None) }?)) + } + + unsafe fn spawn_unchecked_<'a, 'scope, F, T>( + self, + f: F, + scope_data: Option<&'scope scoped::ScopeData>, + ) -> io::Result<JoinInner<'scope, T>> + where + F: FnOnce() -> T, + F: Send + 'a, + T: Send + 'a, + 'scope: 'a, + { let Builder { name, stack_size } = self; let stack_size = stack_size.unwrap_or_else(thread::min_stack); @@ -456,7 +476,8 @@ impl Builder { })); let their_thread = my_thread.clone(); - let my_packet: Arc<UnsafeCell<Option<Result<T>>>> = Arc::new(UnsafeCell::new(None)); + let my_packet: Arc<Packet<'scope, T>> = + Arc::new(Packet { scope: scope_data, result: UnsafeCell::new(None) }); let their_packet = my_packet.clone(); let output_capture = crate::io::set_output_capture(None); @@ -480,10 +501,14 @@ impl Builder { // closure (it is an Arc<...>) and `my_packet` will be stored in the // same `JoinInner` as this closure meaning the mutation will be // safe (not modify it and affect a value far away). - unsafe { *their_packet.get() = Some(try_result) }; + unsafe { *their_packet.result.get() = Some(try_result) }; }; - Ok(JoinHandle(JoinInner { + if let Some(scope_data) = scope_data { + scope_data.increment_n_running_threads(); + } + + Ok(JoinInner { // SAFETY: // // `imp::Thread::new` takes a closure with a `'static` lifetime, since it's passed @@ -506,8 +531,8 @@ impl Builder { )? }, thread: my_thread, - packet: Packet(my_packet), - })) + packet: my_packet, + }) } } @@ -1239,34 +1264,53 @@ impl fmt::Debug for Thread { #[stable(feature = "rust1", since = "1.0.0")] pub type Result<T> = crate::result::Result<T, Box<dyn Any + Send + 'static>>; -// This packet is used to communicate the return value between the spawned thread -// and the rest of the program. Memory is shared through the `Arc` within and there's -// no need for a mutex here because synchronization happens with `join()` (the -// caller will never read this packet until the thread has exited). +// This packet is used to communicate the return value between the spawned +// thread and the rest of the program. It is shared through an `Arc` and +// there's no need for a mutex here because synchronization happens with `join()` +// (the caller will never read this packet until the thread has exited). // -// This packet itself is then stored into a `JoinInner` which in turns is placed -// in `JoinHandle` and `JoinGuard`. Due to the usage of `UnsafeCell` we need to -// manually worry about impls like Send and Sync. The type `T` should -// already always be Send (otherwise the thread could not have been created) and -// this type is inherently Sync because no methods take &self. Regardless, -// however, we add inheriting impls for Send/Sync to this type to ensure it's -// Send/Sync and that future modifications will still appropriately classify it. -struct Packet<T>(Arc<UnsafeCell<Option<Result<T>>>>); - -unsafe impl<T: Send> Send for Packet<T> {} -unsafe impl<T: Sync> Sync for Packet<T> {} +// An Arc to the packet is stored into a `JoinInner` which in turns is placed +// in `JoinHandle`. Due to the usage of `UnsafeCell` we need to manually worry +// about impls like Send and Sync. The type `T` should already always be Send +// (otherwise the thread could not have been created) and this type is +// inherently Sync because no methods take &self. Regardless, however, we add +// inheriting impls for Send/Sync to this type to ensure it's Send/Sync and +// that future modifications will still appropriately classify it. +struct Packet<'scope, T> { + scope: Option<&'scope scoped::ScopeData>, + result: UnsafeCell<Option<Result<T>>>, +} + +unsafe impl<'scope, T: Send> Send for Packet<'scope, T> {} +unsafe impl<'scope, T: Sync> Sync for Packet<'scope, T> {} + +impl<'scope, T> Drop for Packet<'scope, T> { + fn drop(&mut self) { + if let Some(scope) = self.scope { + // If this packet was for a thread that ran in a scope, the thread + // panicked, and nobody consumed the panic payload, we put the + // panic payload in the scope so it can re-throw it, if it didn't + // already capture any panic yet. + if let Some(Err(e)) = self.result.get_mut().take() { + scope.panic_payload.lock().unwrap().get_or_insert(e); + } + // Book-keeping so the scope knows when it's done. + scope.decrement_n_running_threads(); + } + } +} /// Inner representation for JoinHandle -struct JoinInner<T> { +struct JoinInner<'scope, T> { native: imp::Thread, thread: Thread, - packet: Packet<T>, + packet: Arc<Packet<'scope, T>>, } -impl<T> JoinInner<T> { +impl<'scope, T> JoinInner<'scope, T> { fn join(mut self) -> Result<T> { self.native.join(); - Arc::get_mut(&mut self.packet.0).unwrap().get_mut().take().unwrap() + Arc::get_mut(&mut self.packet).unwrap().result.get_mut().take().unwrap() } } @@ -1333,7 +1377,7 @@ impl<T> JoinInner<T> { /// [`thread::Builder::spawn`]: Builder::spawn /// [`thread::spawn`]: spawn #[stable(feature = "rust1", since = "1.0.0")] -pub struct JoinHandle<T>(JoinInner<T>); +pub struct JoinHandle<T>(JoinInner<'static, T>); #[stable(feature = "joinhandle_impl_send_sync", since = "1.29.0")] unsafe impl<T> Send for JoinHandle<T> {} @@ -1407,7 +1451,7 @@ impl<T> JoinHandle<T> { /// function has returned, but before the thread itself has stopped running. #[unstable(feature = "thread_is_running", issue = "90470")] pub fn is_running(&self) -> bool { - Arc::strong_count(&self.0.packet.0) > 1 + Arc::strong_count(&self.0.packet) > 1 } } diff --git a/library/std/src/thread/scoped.rs b/library/std/src/thread/scoped.rs new file mode 100644 index 00000000000..8e9a43e05be --- /dev/null +++ b/library/std/src/thread/scoped.rs @@ -0,0 +1,132 @@ +use super::{current, park, Builder, JoinInner, Result, Thread}; +use crate::any::Any; +use crate::fmt; +use crate::io; +use crate::marker::PhantomData; +use crate::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; +use crate::sync::atomic::{AtomicUsize, Ordering}; +use crate::sync::Mutex; + +/// TODO: documentation +pub struct Scope<'env> { + data: ScopeData, + env: PhantomData<&'env ()>, +} + +/// TODO: documentation +pub struct ScopedJoinHandle<'scope, T>(JoinInner<'scope, T>); + +pub(super) struct ScopeData { + n_running_threads: AtomicUsize, + main_thread: Thread, + pub(super) panic_payload: Mutex<Option<Box<dyn Any + Send>>>, +} + +impl ScopeData { + pub(super) fn increment_n_running_threads(&self) { + // We check for 'overflow' with usize::MAX / 2, to make sure there's no + // chance it overflows to 0, which would result in unsoundness. + if self.n_running_threads.fetch_add(1, Ordering::Relaxed) == usize::MAX / 2 { + // This can only reasonably happen by mem::forget()'ing many many ScopedJoinHandles. + self.decrement_n_running_threads(); + panic!("too many running threads in thread scope"); + } + } + pub(super) fn decrement_n_running_threads(&self) { + if self.n_running_threads.fetch_sub(1, Ordering::Release) == 1 { + self.main_thread.unpark(); + } + } +} + +/// TODO: documentation +pub fn scope<'env, F, T>(f: F) -> T +where + F: FnOnce(&Scope<'env>) -> T, +{ + let mut scope = Scope { + data: ScopeData { + n_running_threads: AtomicUsize::new(0), + main_thread: current(), + panic_payload: Mutex::new(None), + }, + env: PhantomData, + }; + + // Run `f`, but catch panics so we can make sure to wait for all the threads to join. + let result = catch_unwind(AssertUnwindSafe(|| f(&scope))); + + // Wait until all the threads are finished. + while scope.data.n_running_threads.load(Ordering::Acquire) != 0 { + park(); + } + + // Throw any panic from `f` or from any panicked thread, or the return value of `f` otherwise. + match result { + Err(e) => { + // `f` itself panicked. + resume_unwind(e); + } + Ok(result) => { + if let Some(panic_payload) = scope.data.panic_payload.get_mut().unwrap().take() { + // A thread panicked. + resume_unwind(panic_payload); + } else { + // Nothing panicked. + result + } + } + } +} + +impl<'env> Scope<'env> { + /// TODO: documentation + pub fn spawn<'scope, F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T> + where + F: FnOnce(&Scope<'env>) -> T + Send + 'env, + T: Send + 'env, + { + Builder::new().spawn_scoped(self, f).expect("failed to spawn thread") + } +} + +impl Builder { + fn spawn_scoped<'scope, 'env, F, T>( + self, + scope: &'scope Scope<'env>, + f: F, + ) -> io::Result<ScopedJoinHandle<'scope, T>> + where + F: FnOnce(&Scope<'env>) -> T + Send + 'env, + T: Send + 'env, + { + Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(|| f(scope), Some(&scope.data)) }?)) + } +} + +impl<'scope, T> ScopedJoinHandle<'scope, T> { + /// TODO + pub fn join(self) -> Result<T> { + self.0.join() + } + + /// TODO + pub fn thread(&self) -> &Thread { + &self.0.thread + } +} + +impl<'env> fmt::Debug for Scope<'env> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Scope") + .field("n_running_threads", &self.data.n_running_threads.load(Ordering::Relaxed)) + .field("panic_payload", &self.data.panic_payload) + .finish_non_exhaustive() + } +} + +impl<'scope, T> fmt::Debug for ScopedJoinHandle<'scope, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ScopedJoinHandle").finish_non_exhaustive() + } +} |
