about summary refs log tree commit diff
path: root/library/std/src/sync/barrier.rs
diff options
context:
space:
mode:
Diffstat (limited to 'library/std/src/sync/barrier.rs')
-rw-r--r--library/std/src/sync/barrier.rs215
1 files changed, 215 insertions, 0 deletions
diff --git a/library/std/src/sync/barrier.rs b/library/std/src/sync/barrier.rs
new file mode 100644
index 00000000000..01314370ce3
--- /dev/null
+++ b/library/std/src/sync/barrier.rs
@@ -0,0 +1,215 @@
+use crate::fmt;
+use crate::sync::{Condvar, Mutex};
+
+/// A barrier enables multiple threads to synchronize the beginning
+/// of some computation.
+///
+/// # Examples
+///
+/// ```
+/// use std::sync::{Arc, Barrier};
+/// use std::thread;
+///
+/// let mut handles = Vec::with_capacity(10);
+/// let barrier = Arc::new(Barrier::new(10));
+/// for _ in 0..10 {
+///     let c = barrier.clone();
+///     // The same messages will be printed together.
+///     // You will NOT see any interleaving.
+///     handles.push(thread::spawn(move|| {
+///         println!("before wait");
+///         c.wait();
+///         println!("after wait");
+///     }));
+/// }
+/// // Wait for other threads to finish.
+/// for handle in handles {
+///     handle.join().unwrap();
+/// }
+/// ```
+#[stable(feature = "rust1", since = "1.0.0")]
+pub struct Barrier {
+    lock: Mutex<BarrierState>,
+    cvar: Condvar,
+    num_threads: usize,
+}
+
+// The inner state of a double barrier
+struct BarrierState {
+    count: usize,
+    generation_id: usize,
+}
+
+/// A `BarrierWaitResult` is returned by [`wait`] when all threads in the [`Barrier`]
+/// have rendezvoused.
+///
+/// [`wait`]: struct.Barrier.html#method.wait
+/// [`Barrier`]: struct.Barrier.html
+///
+/// # Examples
+///
+/// ```
+/// use std::sync::Barrier;
+///
+/// let barrier = Barrier::new(1);
+/// let barrier_wait_result = barrier.wait();
+/// ```
+#[stable(feature = "rust1", since = "1.0.0")]
+pub struct BarrierWaitResult(bool);
+
+#[stable(feature = "std_debug", since = "1.16.0")]
+impl fmt::Debug for Barrier {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.pad("Barrier { .. }")
+    }
+}
+
+impl Barrier {
+    /// Creates a new barrier that can block a given number of threads.
+    ///
+    /// A barrier will block `n`-1 threads which call [`wait`] and then wake up
+    /// all threads at once when the `n`th thread calls [`wait`].
+    ///
+    /// [`wait`]: #method.wait
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use std::sync::Barrier;
+    ///
+    /// let barrier = Barrier::new(10);
+    /// ```
+    #[stable(feature = "rust1", since = "1.0.0")]
+    pub fn new(n: usize) -> Barrier {
+        Barrier {
+            lock: Mutex::new(BarrierState { count: 0, generation_id: 0 }),
+            cvar: Condvar::new(),
+            num_threads: n,
+        }
+    }
+
+    /// Blocks the current thread until all threads have rendezvoused here.
+    ///
+    /// Barriers are re-usable after all threads have rendezvoused once, and can
+    /// be used continuously.
+    ///
+    /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
+    /// returns `true` from [`is_leader`] when returning from this function, and
+    /// all other threads will receive a result that will return `false` from
+    /// [`is_leader`].
+    ///
+    /// [`BarrierWaitResult`]: struct.BarrierWaitResult.html
+    /// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use std::sync::{Arc, Barrier};
+    /// use std::thread;
+    ///
+    /// let mut handles = Vec::with_capacity(10);
+    /// let barrier = Arc::new(Barrier::new(10));
+    /// for _ in 0..10 {
+    ///     let c = barrier.clone();
+    ///     // The same messages will be printed together.
+    ///     // You will NOT see any interleaving.
+    ///     handles.push(thread::spawn(move|| {
+    ///         println!("before wait");
+    ///         c.wait();
+    ///         println!("after wait");
+    ///     }));
+    /// }
+    /// // Wait for other threads to finish.
+    /// for handle in handles {
+    ///     handle.join().unwrap();
+    /// }
+    /// ```
+    #[stable(feature = "rust1", since = "1.0.0")]
+    pub fn wait(&self) -> BarrierWaitResult {
+        let mut lock = self.lock.lock().unwrap();
+        let local_gen = lock.generation_id;
+        lock.count += 1;
+        if lock.count < self.num_threads {
+            // We need a while loop to guard against spurious wakeups.
+            // http://en.wikipedia.org/wiki/Spurious_wakeup
+            while local_gen == lock.generation_id && lock.count < self.num_threads {
+                lock = self.cvar.wait(lock).unwrap();
+            }
+            BarrierWaitResult(false)
+        } else {
+            lock.count = 0;
+            lock.generation_id = lock.generation_id.wrapping_add(1);
+            self.cvar.notify_all();
+            BarrierWaitResult(true)
+        }
+    }
+}
+
+#[stable(feature = "std_debug", since = "1.16.0")]
+impl fmt::Debug for BarrierWaitResult {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("BarrierWaitResult").field("is_leader", &self.is_leader()).finish()
+    }
+}
+
+impl BarrierWaitResult {
+    /// Returns `true` if this thread from [`wait`] is the "leader thread".
+    ///
+    /// Only one thread will have `true` returned from their result, all other
+    /// threads will have `false` returned.
+    ///
+    /// [`wait`]: struct.Barrier.html#method.wait
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use std::sync::Barrier;
+    ///
+    /// let barrier = Barrier::new(1);
+    /// let barrier_wait_result = barrier.wait();
+    /// println!("{:?}", barrier_wait_result.is_leader());
+    /// ```
+    #[stable(feature = "rust1", since = "1.0.0")]
+    pub fn is_leader(&self) -> bool {
+        self.0
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::sync::mpsc::{channel, TryRecvError};
+    use crate::sync::{Arc, Barrier};
+    use crate::thread;
+
+    #[test]
+    #[cfg_attr(target_os = "emscripten", ignore)]
+    fn test_barrier() {
+        const N: usize = 10;
+
+        let barrier = Arc::new(Barrier::new(N));
+        let (tx, rx) = channel();
+
+        for _ in 0..N - 1 {
+            let c = barrier.clone();
+            let tx = tx.clone();
+            thread::spawn(move || {
+                tx.send(c.wait().is_leader()).unwrap();
+            });
+        }
+
+        // At this point, all spawned threads should be blocked,
+        // so we shouldn't get anything from the port
+        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
+
+        let mut leader_found = barrier.wait().is_leader();
+
+        // Now, the barrier is cleared and we should get data.
+        for _ in 0..N - 1 {
+            if rx.recv().unwrap() {
+                assert!(!leader_found);
+                leader_found = true;
+            }
+        }
+        assert!(leader_found);
+    }
+}