about summary refs log tree commit diff
path: root/library/std/src/sync/barrier.rs
diff options
context:
space:
mode:
Diffstat (limited to 'library/std/src/sync/barrier.rs')
-rw-r--r--library/std/src/sync/barrier.rs175
1 files changed, 175 insertions, 0 deletions
diff --git a/library/std/src/sync/barrier.rs b/library/std/src/sync/barrier.rs
new file mode 100644
index 00000000000..82cc13a74b7
--- /dev/null
+++ b/library/std/src/sync/barrier.rs
@@ -0,0 +1,175 @@
+#[cfg(test)]
+mod tests;
+
+use crate::fmt;
+use crate::sync::{Condvar, Mutex};
+
+/// A barrier enables multiple threads to synchronize the beginning
+/// of some computation.
+///
+/// # Examples
+///
+/// ```
+/// use std::sync::{Arc, Barrier};
+/// use std::thread;
+///
+/// let n = 10;
+/// let mut handles = Vec::with_capacity(n);
+/// let barrier = Arc::new(Barrier::new(n));
+/// for _ in 0..n {
+///     let c = Arc::clone(&barrier);
+///     // The same messages will be printed together.
+///     // You will NOT see any interleaving.
+///     handles.push(thread::spawn(move || {
+///         println!("before wait");
+///         c.wait();
+///         println!("after wait");
+///     }));
+/// }
+/// // Wait for other threads to finish.
+/// for handle in handles {
+///     handle.join().unwrap();
+/// }
+/// ```
+#[stable(feature = "rust1", since = "1.0.0")]
+pub struct Barrier {
+    lock: Mutex<BarrierState>,
+    cvar: Condvar,
+    num_threads: usize,
+}
+
+// The inner state of a double barrier
+struct BarrierState {
+    count: usize,
+    generation_id: usize,
+}
+
+/// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all threads
+/// in the [`Barrier`] have rendezvoused.
+///
+/// # Examples
+///
+/// ```
+/// use std::sync::Barrier;
+///
+/// let barrier = Barrier::new(1);
+/// let barrier_wait_result = barrier.wait();
+/// ```
+#[stable(feature = "rust1", since = "1.0.0")]
+pub struct BarrierWaitResult(bool);
+
+#[stable(feature = "std_debug", since = "1.16.0")]
+impl fmt::Debug for Barrier {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("Barrier").finish_non_exhaustive()
+    }
+}
+
+impl Barrier {
+    /// Creates a new barrier that can block a given number of threads.
+    ///
+    /// A barrier will block `n`-1 threads which call [`wait()`] and then wake
+    /// up all threads at once when the `n`th thread calls [`wait()`].
+    ///
+    /// [`wait()`]: Barrier::wait
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use std::sync::Barrier;
+    ///
+    /// let barrier = Barrier::new(10);
+    /// ```
+    #[stable(feature = "rust1", since = "1.0.0")]
+    #[rustc_const_stable(feature = "const_barrier", since = "1.78.0")]
+    #[must_use]
+    #[inline]
+    pub const fn new(n: usize) -> Barrier {
+        Barrier {
+            lock: Mutex::new(BarrierState { count: 0, generation_id: 0 }),
+            cvar: Condvar::new(),
+            num_threads: n,
+        }
+    }
+
+    /// Blocks the current thread until all threads have rendezvoused here.
+    ///
+    /// Barriers are re-usable after all threads have rendezvoused once, and can
+    /// be used continuously.
+    ///
+    /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
+    /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
+    /// from this function, and all other threads will receive a result that
+    /// will return `false` from [`BarrierWaitResult::is_leader()`].
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use std::sync::{Arc, Barrier};
+    /// use std::thread;
+    ///
+    /// let n = 10;
+    /// let mut handles = Vec::with_capacity(n);
+    /// let barrier = Arc::new(Barrier::new(n));
+    /// for _ in 0..n {
+    ///     let c = Arc::clone(&barrier);
+    ///     // The same messages will be printed together.
+    ///     // You will NOT see any interleaving.
+    ///     handles.push(thread::spawn(move || {
+    ///         println!("before wait");
+    ///         c.wait();
+    ///         println!("after wait");
+    ///     }));
+    /// }
+    /// // Wait for other threads to finish.
+    /// for handle in handles {
+    ///     handle.join().unwrap();
+    /// }
+    /// ```
+    #[stable(feature = "rust1", since = "1.0.0")]
+    pub fn wait(&self) -> BarrierWaitResult {
+        let mut lock = self.lock.lock().unwrap();
+        let local_gen = lock.generation_id;
+        lock.count += 1;
+        if lock.count < self.num_threads {
+            let _guard =
+                self.cvar.wait_while(lock, |state| local_gen == state.generation_id).unwrap();
+            BarrierWaitResult(false)
+        } else {
+            lock.count = 0;
+            lock.generation_id = lock.generation_id.wrapping_add(1);
+            self.cvar.notify_all();
+            BarrierWaitResult(true)
+        }
+    }
+}
+
+#[stable(feature = "std_debug", since = "1.16.0")]
+impl fmt::Debug for BarrierWaitResult {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("BarrierWaitResult").field("is_leader", &self.is_leader()).finish()
+    }
+}
+
+impl BarrierWaitResult {
+    /// Returns `true` if this thread is the "leader thread" for the call to
+    /// [`Barrier::wait()`].
+    ///
+    /// Only one thread will have `true` returned from their result, all other
+    /// threads will have `false` returned.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use std::sync::Barrier;
+    ///
+    /// let barrier = Barrier::new(1);
+    /// let barrier_wait_result = barrier.wait();
+    /// println!("{:?}", barrier_wait_result.is_leader());
+    /// ```
+    #[stable(feature = "rust1", since = "1.0.0")]
+    #[must_use]
+    pub fn is_leader(&self) -> bool {
+        self.0
+    }
+}