about summary refs log tree commit diff
path: root/library/std/src/sync
diff options
context:
space:
mode:
authorMatthias Krüger <matthias.krueger@famsik.de>2023-03-21 19:00:11 +0100
committerGitHub <noreply@github.com>2023-03-21 19:00:11 +0100
commit93a82a44a192a8a3f3c3610853878a57a0a37ee3 (patch)
treecd8f577063f6c3d5c81b9eac2b41da9181d939b6 /library/std/src/sync
parent1a43859a747a8916dbec98b8847a237e6caaf994 (diff)
parent34aa87292c5cd45c88a72235dad6e973a9f2b62f (diff)
downloadrust-93a82a44a192a8a3f3c3610853878a57a0a37ee3.tar.gz
rust-93a82a44a192a8a3f3c3610853878a57a0a37ee3.zip
Rollup merge of #108164 - joboet:discard_messages_mpmc_array, r=Amanieu
Drop all messages in bounded channel when destroying the last receiver

Fixes #107466 by splitting the `disconnect` function for receivers/transmitters and dropping all messages in `disconnect_receivers` like the unbounded channel does. Since all receivers must be dropped before the channel is, the messages will already be discarded at that point, so the `Drop` implementation for the channel can be removed.

``@rustbot`` label +T-libs +A-concurrency
Diffstat (limited to 'library/std/src/sync')
-rw-r--r--library/std/src/sync/mpmc/array.rs107
-rw-r--r--library/std/src/sync/mpmc/mod.rs4
-rw-r--r--library/std/src/sync/mpsc/sync_tests.rs13
3 files changed, 98 insertions, 26 deletions
diff --git a/library/std/src/sync/mpmc/array.rs b/library/std/src/sync/mpmc/array.rs
index c6bb09b0417..492e21d9bdb 100644
--- a/library/std/src/sync/mpmc/array.rs
+++ b/library/std/src/sync/mpmc/array.rs
@@ -25,7 +25,8 @@ struct Slot<T> {
     /// The current stamp.
     stamp: AtomicUsize,
 
-    /// The message in this slot.
+    /// The message in this slot. Either read out in `read` or dropped through
+    /// `discard_all_messages`.
     msg: UnsafeCell<MaybeUninit<T>>,
 }
 
@@ -439,14 +440,13 @@ impl<T> Channel<T> {
         Some(self.cap)
     }
 
-    /// Disconnects the channel and wakes up all blocked senders and receivers.
+    /// Disconnects senders and wakes up all blocked receivers.
     ///
     /// Returns `true` if this call disconnected the channel.
-    pub(crate) fn disconnect(&self) -> bool {
+    pub(crate) fn disconnect_senders(&self) -> bool {
         let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
 
         if tail & self.mark_bit == 0 {
-            self.senders.disconnect();
             self.receivers.disconnect();
             true
         } else {
@@ -454,6 +454,85 @@ impl<T> Channel<T> {
         }
     }
 
+    /// Disconnects receivers and wakes up all blocked senders.
+    ///
+    /// Returns `true` if this call disconnected the channel.
+    ///
+    /// # Safety
+    /// May only be called once upon dropping the last receiver. The
+    /// destruction of all other receivers must have been observed with acquire
+    /// ordering or stronger.
+    pub(crate) unsafe fn disconnect_receivers(&self) -> bool {
+        let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
+        let disconnected = if tail & self.mark_bit == 0 {
+            self.senders.disconnect();
+            true
+        } else {
+            false
+        };
+
+        self.discard_all_messages(tail);
+        disconnected
+    }
+
+    /// Discards all messages.
+    ///
+    /// `tail` should be the current (and therefore last) value of `tail`.
+    ///
+    /// # Panicking
+    /// If a destructor panics, the remaining messages are leaked, matching the
+    /// behaviour of the unbounded channel.
+    ///
+    /// # Safety
+    /// This method must only be called when dropping the last receiver. The
+    /// destruction of all other receivers must have been observed with acquire
+    /// ordering or stronger.
+    unsafe fn discard_all_messages(&self, tail: usize) {
+        debug_assert!(self.is_disconnected());
+
+        // Only receivers modify `head`, so since we are the last one,
+        // this value will not change and will not be observed (since
+        // no new messages can be sent after disconnection).
+        let mut head = self.head.load(Ordering::Relaxed);
+        let tail = tail & !self.mark_bit;
+
+        let backoff = Backoff::new();
+        loop {
+            // Deconstruct the head.
+            let index = head & (self.mark_bit - 1);
+            let lap = head & !(self.one_lap - 1);
+
+            // Inspect the corresponding slot.
+            debug_assert!(index < self.buffer.len());
+            let slot = unsafe { self.buffer.get_unchecked(index) };
+            let stamp = slot.stamp.load(Ordering::Acquire);
+
+            // If the stamp is ahead of the head by 1, we may drop the message.
+            if head + 1 == stamp {
+                head = if index + 1 < self.cap {
+                    // Same lap, incremented index.
+                    // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
+                    head + 1
+                } else {
+                    // One lap forward, index wraps around to zero.
+                    // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
+                    lap.wrapping_add(self.one_lap)
+                };
+
+                unsafe {
+                    (*slot.msg.get()).assume_init_drop();
+                }
+            // If the tail equals the head, that means the channel is empty.
+            } else if tail == head {
+                return;
+            // Otherwise, a sender is about to write into the slot, so we need
+            // to wait for it to update the stamp.
+            } else {
+                backoff.spin_heavy();
+            }
+        }
+    }
+
     /// Returns `true` if the channel is disconnected.
     pub(crate) fn is_disconnected(&self) -> bool {
         self.tail.load(Ordering::SeqCst) & self.mark_bit != 0
@@ -483,23 +562,3 @@ impl<T> Channel<T> {
         head.wrapping_add(self.one_lap) == tail & !self.mark_bit
     }
 }
-
-impl<T> Drop for Channel<T> {
-    fn drop(&mut self) {
-        // Get the index of the head.
-        let hix = self.head.load(Ordering::Relaxed) & (self.mark_bit - 1);
-
-        // Loop over all slots that hold a message and drop them.
-        for i in 0..self.len() {
-            // Compute the index of the next slot holding a message.
-            let index = if hix + i < self.cap { hix + i } else { hix + i - self.cap };
-
-            unsafe {
-                debug_assert!(index < self.buffer.len());
-                let slot = self.buffer.get_unchecked_mut(index);
-                let msg = &mut *slot.msg.get();
-                msg.as_mut_ptr().drop_in_place();
-            }
-        }
-    }
-}
diff --git a/library/std/src/sync/mpmc/mod.rs b/library/std/src/sync/mpmc/mod.rs
index 7a602cecd3b..2068dda393a 100644
--- a/library/std/src/sync/mpmc/mod.rs
+++ b/library/std/src/sync/mpmc/mod.rs
@@ -227,7 +227,7 @@ impl<T> Drop for Sender<T> {
     fn drop(&mut self) {
         unsafe {
             match &self.flavor {
-                SenderFlavor::Array(chan) => chan.release(|c| c.disconnect()),
+                SenderFlavor::Array(chan) => chan.release(|c| c.disconnect_senders()),
                 SenderFlavor::List(chan) => chan.release(|c| c.disconnect_senders()),
                 SenderFlavor::Zero(chan) => chan.release(|c| c.disconnect()),
             }
@@ -403,7 +403,7 @@ impl<T> Drop for Receiver<T> {
     fn drop(&mut self) {
         unsafe {
             match &self.flavor {
-                ReceiverFlavor::Array(chan) => chan.release(|c| c.disconnect()),
+                ReceiverFlavor::Array(chan) => chan.release(|c| c.disconnect_receivers()),
                 ReceiverFlavor::List(chan) => chan.release(|c| c.disconnect_receivers()),
                 ReceiverFlavor::Zero(chan) => chan.release(|c| c.disconnect()),
             }
diff --git a/library/std/src/sync/mpsc/sync_tests.rs b/library/std/src/sync/mpsc/sync_tests.rs
index 9d2f92ffc9b..632709fd98d 100644
--- a/library/std/src/sync/mpsc/sync_tests.rs
+++ b/library/std/src/sync/mpsc/sync_tests.rs
@@ -1,5 +1,6 @@
 use super::*;
 use crate::env;
+use crate::rc::Rc;
 use crate::sync::mpmc::SendTimeoutError;
 use crate::thread;
 use crate::time::Duration;
@@ -656,3 +657,15 @@ fn issue_15761() {
         repro()
     }
 }
+
+#[test]
+fn drop_unreceived() {
+    let (tx, rx) = sync_channel::<Rc<()>>(1);
+    let msg = Rc::new(());
+    let weak = Rc::downgrade(&msg);
+    assert!(tx.send(msg).is_ok());
+    drop(rx);
+    // Messages should be dropped immediately when the last receiver is destroyed.
+    assert!(weak.upgrade().is_none());
+    drop(tx);
+}