about summary refs log tree commit diff
path: root/library/std/src/sync/mpmc/list.rs
diff options
context:
space:
mode:
Diffstat (limited to 'library/std/src/sync/mpmc/list.rs')
-rw-r--r--library/std/src/sync/mpmc/list.rs648
1 files changed, 648 insertions, 0 deletions
diff --git a/library/std/src/sync/mpmc/list.rs b/library/std/src/sync/mpmc/list.rs
new file mode 100644
index 00000000000..5bc196995b1
--- /dev/null
+++ b/library/std/src/sync/mpmc/list.rs
@@ -0,0 +1,648 @@
+//! Unbounded channel implemented as a linked list.
+
+use super::context::Context;
+use super::error::*;
+use super::select::{Operation, Selected, Token};
+use super::utils::{Backoff, CachePadded};
+use super::waker::SyncWaker;
+
+use crate::cell::UnsafeCell;
+use crate::marker::PhantomData;
+use crate::mem::MaybeUninit;
+use crate::ptr;
+use crate::sync::atomic::{self, AtomicPtr, AtomicUsize, Ordering};
+use crate::time::Instant;
+
+// Bits indicating the state of a slot:
+// * If a message has been written into the slot, `WRITE` is set.
+// * If a message has been read from the slot, `READ` is set.
+// * If the block is being destroyed, `DESTROY` is set.
+const WRITE: usize = 1;
+const READ: usize = 2;
+const DESTROY: usize = 4;
+
+// Each block covers one "lap" of indices.
+const LAP: usize = 32;
+// The maximum number of messages a block can hold.
+const BLOCK_CAP: usize = LAP - 1;
+// How many lower bits are reserved for metadata.
+const SHIFT: usize = 1;
+// Has two different purposes:
+// * If set in head, indicates that the block is not the last one.
+// * If set in tail, indicates that the channel is disconnected.
+const MARK_BIT: usize = 1;
+
+/// A slot in a block.
+struct Slot<T> {
+    /// The message.
+    msg: UnsafeCell<MaybeUninit<T>>,
+
+    /// The state of the slot.
+    state: AtomicUsize,
+}
+
+impl<T> Slot<T> {
+    /// Waits until a message is written into the slot.
+    fn wait_write(&self) {
+        let backoff = Backoff::new();
+        while self.state.load(Ordering::Acquire) & WRITE == 0 {
+            backoff.snooze();
+        }
+    }
+}
+
+/// A block in a linked list.
+///
+/// Each block in the list can hold up to `BLOCK_CAP` messages.
+struct Block<T> {
+    /// The next block in the linked list.
+    next: AtomicPtr<Block<T>>,
+
+    /// Slots for messages.
+    slots: [Slot<T>; BLOCK_CAP],
+}
+
+impl<T> Block<T> {
+    /// Creates an empty block.
+    fn new() -> Block<T> {
+        // SAFETY: This is safe because:
+        //  [1] `Block::next` (AtomicPtr) may be safely zero initialized.
+        //  [2] `Block::slots` (Array) may be safely zero initialized because of [3, 4].
+        //  [3] `Slot::msg` (UnsafeCell) may be safely zero initialized because it
+        //       holds a MaybeUninit.
+        //  [4] `Slot::state` (AtomicUsize) may be safely zero initialized.
+        unsafe { MaybeUninit::zeroed().assume_init() }
+    }
+
+    /// Waits until the next pointer is set.
+    fn wait_next(&self) -> *mut Block<T> {
+        let backoff = Backoff::new();
+        loop {
+            let next = self.next.load(Ordering::Acquire);
+            if !next.is_null() {
+                return next;
+            }
+            backoff.snooze();
+        }
+    }
+
+    /// Sets the `DESTROY` bit in slots starting from `start` and destroys the block.
+    unsafe fn destroy(this: *mut Block<T>, start: usize) {
+        // It is not necessary to set the `DESTROY` bit in the last slot because that slot has
+        // begun destruction of the block.
+        for i in start..BLOCK_CAP - 1 {
+            let slot = (*this).slots.get_unchecked(i);
+
+            // Mark the `DESTROY` bit if a thread is still using the slot.
+            if slot.state.load(Ordering::Acquire) & READ == 0
+                && slot.state.fetch_or(DESTROY, Ordering::AcqRel) & READ == 0
+            {
+                // If a thread is still using the slot, it will continue destruction of the block.
+                return;
+            }
+        }
+
+        // No thread is using the block, now it is safe to destroy it.
+        drop(Box::from_raw(this));
+    }
+}
+
+/// A position in a channel.
+#[derive(Debug)]
+struct Position<T> {
+    /// The index in the channel.
+    index: AtomicUsize,
+
+    /// The block in the linked list.
+    block: AtomicPtr<Block<T>>,
+}
+
+/// The token type for the list flavor.
+#[derive(Debug)]
+pub(crate) struct ListToken {
+    /// The block of slots.
+    block: *const u8,
+
+    /// The offset into the block.
+    offset: usize,
+}
+
+impl Default for ListToken {
+    #[inline]
+    fn default() -> Self {
+        ListToken { block: ptr::null(), offset: 0 }
+    }
+}
+
+/// Unbounded channel implemented as a linked list.
+///
+/// Each message sent into the channel is assigned a sequence number, i.e. an index. Indices are
+/// represented as numbers of type `usize` and wrap on overflow.
+///
+/// Consecutive messages are grouped into blocks in order to put less pressure on the allocator and
+/// improve cache efficiency.
+pub(crate) struct Channel<T> {
+    /// The head of the channel.
+    head: CachePadded<Position<T>>,
+
+    /// The tail of the channel.
+    tail: CachePadded<Position<T>>,
+
+    /// Receivers waiting while the channel is empty and not disconnected.
+    receivers: SyncWaker,
+
+    /// Indicates that dropping a `Channel<T>` may drop messages of type `T`.
+    _marker: PhantomData<T>,
+}
+
+impl<T> Channel<T> {
+    /// Creates a new unbounded channel.
+    pub(crate) fn new() -> Self {
+        Channel {
+            head: CachePadded::new(Position {
+                block: AtomicPtr::new(ptr::null_mut()),
+                index: AtomicUsize::new(0),
+            }),
+            tail: CachePadded::new(Position {
+                block: AtomicPtr::new(ptr::null_mut()),
+                index: AtomicUsize::new(0),
+            }),
+            receivers: SyncWaker::new(),
+            _marker: PhantomData,
+        }
+    }
+
+    /// Attempts to reserve a slot for sending a message.
+    fn start_send(&self, token: &mut Token) -> bool {
+        let backoff = Backoff::new();
+        let mut tail = self.tail.index.load(Ordering::Acquire);
+        let mut block = self.tail.block.load(Ordering::Acquire);
+        let mut next_block = None;
+
+        loop {
+            // Check if the channel is disconnected.
+            if tail & MARK_BIT != 0 {
+                token.list.block = ptr::null();
+                return true;
+            }
+
+            // Calculate the offset of the index into the block.
+            let offset = (tail >> SHIFT) % LAP;
+
+            // If we reached the end of the block, wait until the next one is installed.
+            if offset == BLOCK_CAP {
+                backoff.snooze();
+                tail = self.tail.index.load(Ordering::Acquire);
+                block = self.tail.block.load(Ordering::Acquire);
+                continue;
+            }
+
+            // If we're going to have to install the next block, allocate it in advance in order to
+            // make the wait for other threads as short as possible.
+            if offset + 1 == BLOCK_CAP && next_block.is_none() {
+                next_block = Some(Box::new(Block::<T>::new()));
+            }
+
+            // If this is the first message to be sent into the channel, we need to allocate the
+            // first block and install it.
+            if block.is_null() {
+                let new = Box::into_raw(Box::new(Block::<T>::new()));
+
+                if self
+                    .tail
+                    .block
+                    .compare_exchange(block, new, Ordering::Release, Ordering::Relaxed)
+                    .is_ok()
+                {
+                    self.head.block.store(new, Ordering::Release);
+                    block = new;
+                } else {
+                    next_block = unsafe { Some(Box::from_raw(new)) };
+                    tail = self.tail.index.load(Ordering::Acquire);
+                    block = self.tail.block.load(Ordering::Acquire);
+                    continue;
+                }
+            }
+
+            let new_tail = tail + (1 << SHIFT);
+
+            // Try advancing the tail forward.
+            match self.tail.index.compare_exchange_weak(
+                tail,
+                new_tail,
+                Ordering::SeqCst,
+                Ordering::Acquire,
+            ) {
+                Ok(_) => unsafe {
+                    // If we've reached the end of the block, install the next one.
+                    if offset + 1 == BLOCK_CAP {
+                        let next_block = Box::into_raw(next_block.unwrap());
+                        self.tail.block.store(next_block, Ordering::Release);
+                        self.tail.index.fetch_add(1 << SHIFT, Ordering::Release);
+                        (*block).next.store(next_block, Ordering::Release);
+                    }
+
+                    token.list.block = block as *const u8;
+                    token.list.offset = offset;
+                    return true;
+                },
+                Err(t) => {
+                    tail = t;
+                    block = self.tail.block.load(Ordering::Acquire);
+                    backoff.spin();
+                }
+            }
+        }
+    }
+
+    /// Writes a message into the channel.
+    pub(crate) unsafe fn write(&self, token: &mut Token, msg: T) -> Result<(), T> {
+        // If there is no slot, the channel is disconnected.
+        if token.list.block.is_null() {
+            return Err(msg);
+        }
+
+        // Write the message into the slot.
+        let block = token.list.block as *mut Block<T>;
+        let offset = token.list.offset;
+        let slot = (*block).slots.get_unchecked(offset);
+        slot.msg.get().write(MaybeUninit::new(msg));
+        slot.state.fetch_or(WRITE, Ordering::Release);
+
+        // Wake a sleeping receiver.
+        self.receivers.notify();
+        Ok(())
+    }
+
+    /// Attempts to reserve a slot for receiving a message.
+    fn start_recv(&self, token: &mut Token) -> bool {
+        let backoff = Backoff::new();
+        let mut head = self.head.index.load(Ordering::Acquire);
+        let mut block = self.head.block.load(Ordering::Acquire);
+
+        loop {
+            // Calculate the offset of the index into the block.
+            let offset = (head >> SHIFT) % LAP;
+
+            // If we reached the end of the block, wait until the next one is installed.
+            if offset == BLOCK_CAP {
+                backoff.snooze();
+                head = self.head.index.load(Ordering::Acquire);
+                block = self.head.block.load(Ordering::Acquire);
+                continue;
+            }
+
+            let mut new_head = head + (1 << SHIFT);
+
+            if new_head & MARK_BIT == 0 {
+                atomic::fence(Ordering::SeqCst);
+                let tail = self.tail.index.load(Ordering::Relaxed);
+
+                // If the tail equals the head, that means the channel is empty.
+                if head >> SHIFT == tail >> SHIFT {
+                    // If the channel is disconnected...
+                    if tail & MARK_BIT != 0 {
+                        // ...then receive an error.
+                        token.list.block = ptr::null();
+                        return true;
+                    } else {
+                        // Otherwise, the receive operation is not ready.
+                        return false;
+                    }
+                }
+
+                // If head and tail are not in the same block, set `MARK_BIT` in head.
+                if (head >> SHIFT) / LAP != (tail >> SHIFT) / LAP {
+                    new_head |= MARK_BIT;
+                }
+            }
+
+            // The block can be null here only if the first message is being sent into the channel.
+            // In that case, just wait until it gets initialized.
+            if block.is_null() {
+                backoff.snooze();
+                head = self.head.index.load(Ordering::Acquire);
+                block = self.head.block.load(Ordering::Acquire);
+                continue;
+            }
+
+            // Try moving the head index forward.
+            match self.head.index.compare_exchange_weak(
+                head,
+                new_head,
+                Ordering::SeqCst,
+                Ordering::Acquire,
+            ) {
+                Ok(_) => unsafe {
+                    // If we've reached the end of the block, move to the next one.
+                    if offset + 1 == BLOCK_CAP {
+                        let next = (*block).wait_next();
+                        let mut next_index = (new_head & !MARK_BIT).wrapping_add(1 << SHIFT);
+                        if !(*next).next.load(Ordering::Relaxed).is_null() {
+                            next_index |= MARK_BIT;
+                        }
+
+                        self.head.block.store(next, Ordering::Release);
+                        self.head.index.store(next_index, Ordering::Release);
+                    }
+
+                    token.list.block = block as *const u8;
+                    token.list.offset = offset;
+                    return true;
+                },
+                Err(h) => {
+                    head = h;
+                    block = self.head.block.load(Ordering::Acquire);
+                    backoff.spin();
+                }
+            }
+        }
+    }
+
+    /// Reads a message from the channel.
+    pub(crate) unsafe fn read(&self, token: &mut Token) -> Result<T, ()> {
+        if token.list.block.is_null() {
+            // The channel is disconnected.
+            return Err(());
+        }
+
+        // Read the message.
+        let block = token.list.block as *mut Block<T>;
+        let offset = token.list.offset;
+        let slot = (*block).slots.get_unchecked(offset);
+        slot.wait_write();
+        let msg = slot.msg.get().read().assume_init();
+
+        // Destroy the block if we've reached the end, or if another thread wanted to destroy but
+        // couldn't because we were busy reading from the slot.
+        if offset + 1 == BLOCK_CAP {
+            Block::destroy(block, 0);
+        } else if slot.state.fetch_or(READ, Ordering::AcqRel) & DESTROY != 0 {
+            Block::destroy(block, offset + 1);
+        }
+
+        Ok(msg)
+    }
+
+    /// Attempts to send a message into the channel.
+    pub(crate) fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
+        self.send(msg, None).map_err(|err| match err {
+            SendTimeoutError::Disconnected(msg) => TrySendError::Disconnected(msg),
+            SendTimeoutError::Timeout(_) => unreachable!(),
+        })
+    }
+
+    /// Sends a message into the channel.
+    pub(crate) fn send(
+        &self,
+        msg: T,
+        _deadline: Option<Instant>,
+    ) -> Result<(), SendTimeoutError<T>> {
+        let token = &mut Token::default();
+        assert!(self.start_send(token));
+        unsafe { self.write(token, msg).map_err(SendTimeoutError::Disconnected) }
+    }
+
+    /// Attempts to receive a message without blocking.
+    pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> {
+        let token = &mut Token::default();
+
+        if self.start_recv(token) {
+            unsafe { self.read(token).map_err(|_| TryRecvError::Disconnected) }
+        } else {
+            Err(TryRecvError::Empty)
+        }
+    }
+
+    /// Receives a message from the channel.
+    pub(crate) fn recv(&self, deadline: Option<Instant>) -> Result<T, RecvTimeoutError> {
+        let token = &mut Token::default();
+        loop {
+            // Try receiving a message several times.
+            let backoff = Backoff::new();
+            loop {
+                if self.start_recv(token) {
+                    unsafe {
+                        return self.read(token).map_err(|_| RecvTimeoutError::Disconnected);
+                    }
+                }
+
+                if backoff.is_completed() {
+                    break;
+                } else {
+                    backoff.snooze();
+                }
+            }
+
+            if let Some(d) = deadline {
+                if Instant::now() >= d {
+                    return Err(RecvTimeoutError::Timeout);
+                }
+            }
+
+            // Prepare for blocking until a sender wakes us up.
+            Context::with(|cx| {
+                let oper = Operation::hook(token);
+                self.receivers.register(oper, cx);
+
+                // Has the channel become ready just now?
+                if !self.is_empty() || self.is_disconnected() {
+                    let _ = cx.try_select(Selected::Aborted);
+                }
+
+                // Block the current thread.
+                let sel = cx.wait_until(deadline);
+
+                match sel {
+                    Selected::Waiting => unreachable!(),
+                    Selected::Aborted | Selected::Disconnected => {
+                        self.receivers.unregister(oper).unwrap();
+                        // If the channel was disconnected, we still have to check for remaining
+                        // messages.
+                    }
+                    Selected::Operation(_) => {}
+                }
+            });
+        }
+    }
+
+    /// Returns the current number of messages inside the channel.
+    pub(crate) fn len(&self) -> usize {
+        loop {
+            // Load the tail index, then load the head index.
+            let mut tail = self.tail.index.load(Ordering::SeqCst);
+            let mut head = self.head.index.load(Ordering::SeqCst);
+
+            // If the tail index didn't change, we've got consistent indices to work with.
+            if self.tail.index.load(Ordering::SeqCst) == tail {
+                // Erase the lower bits.
+                tail &= !((1 << SHIFT) - 1);
+                head &= !((1 << SHIFT) - 1);
+
+                // Fix up indices if they fall onto block ends.
+                if (tail >> SHIFT) & (LAP - 1) == LAP - 1 {
+                    tail = tail.wrapping_add(1 << SHIFT);
+                }
+                if (head >> SHIFT) & (LAP - 1) == LAP - 1 {
+                    head = head.wrapping_add(1 << SHIFT);
+                }
+
+                // Rotate indices so that head falls into the first block.
+                let lap = (head >> SHIFT) / LAP;
+                tail = tail.wrapping_sub((lap * LAP) << SHIFT);
+                head = head.wrapping_sub((lap * LAP) << SHIFT);
+
+                // Remove the lower bits.
+                tail >>= SHIFT;
+                head >>= SHIFT;
+
+                // Return the difference minus the number of blocks between tail and head.
+                return tail - head - tail / LAP;
+            }
+        }
+    }
+
+    /// Returns the capacity of the channel.
+    pub(crate) fn capacity(&self) -> Option<usize> {
+        None
+    }
+
+    /// Disconnects senders and wakes up all blocked receivers.
+    ///
+    /// Returns `true` if this call disconnected the channel.
+    pub(crate) fn disconnect_senders(&self) -> bool {
+        let tail = self.tail.index.fetch_or(MARK_BIT, Ordering::SeqCst);
+
+        if tail & MARK_BIT == 0 {
+            self.receivers.disconnect();
+            true
+        } else {
+            false
+        }
+    }
+
+    /// Disconnects receivers.
+    ///
+    /// Returns `true` if this call disconnected the channel.
+    pub(crate) fn disconnect_receivers(&self) -> bool {
+        let tail = self.tail.index.fetch_or(MARK_BIT, Ordering::SeqCst);
+
+        if tail & MARK_BIT == 0 {
+            // If receivers are dropped first, discard all messages to free
+            // memory eagerly.
+            self.discard_all_messages();
+            true
+        } else {
+            false
+        }
+    }
+
+    /// Discards all messages.
+    ///
+    /// This method should only be called when all receivers are dropped.
+    fn discard_all_messages(&self) {
+        let backoff = Backoff::new();
+        let mut tail = self.tail.index.load(Ordering::Acquire);
+        loop {
+            let offset = (tail >> SHIFT) % LAP;
+            if offset != BLOCK_CAP {
+                break;
+            }
+
+            // New updates to tail will be rejected by MARK_BIT and aborted unless it's
+            // at boundary. We need to wait for the updates take affect otherwise there
+            // can be memory leaks.
+            backoff.snooze();
+            tail = self.tail.index.load(Ordering::Acquire);
+        }
+
+        let mut head = self.head.index.load(Ordering::Acquire);
+        let mut block = self.head.block.load(Ordering::Acquire);
+
+        unsafe {
+            // Drop all messages between head and tail and deallocate the heap-allocated blocks.
+            while head >> SHIFT != tail >> SHIFT {
+                let offset = (head >> SHIFT) % LAP;
+
+                if offset < BLOCK_CAP {
+                    // Drop the message in the slot.
+                    let slot = (*block).slots.get_unchecked(offset);
+                    slot.wait_write();
+                    let p = &mut *slot.msg.get();
+                    p.as_mut_ptr().drop_in_place();
+                } else {
+                    (*block).wait_next();
+                    // Deallocate the block and move to the next one.
+                    let next = (*block).next.load(Ordering::Acquire);
+                    drop(Box::from_raw(block));
+                    block = next;
+                }
+
+                head = head.wrapping_add(1 << SHIFT);
+            }
+
+            // Deallocate the last remaining block.
+            if !block.is_null() {
+                drop(Box::from_raw(block));
+            }
+        }
+        head &= !MARK_BIT;
+        self.head.block.store(ptr::null_mut(), Ordering::Release);
+        self.head.index.store(head, Ordering::Release);
+    }
+
+    /// Returns `true` if the channel is disconnected.
+    pub(crate) fn is_disconnected(&self) -> bool {
+        self.tail.index.load(Ordering::SeqCst) & MARK_BIT != 0
+    }
+
+    /// Returns `true` if the channel is empty.
+    pub(crate) fn is_empty(&self) -> bool {
+        let head = self.head.index.load(Ordering::SeqCst);
+        let tail = self.tail.index.load(Ordering::SeqCst);
+        head >> SHIFT == tail >> SHIFT
+    }
+
+    /// Returns `true` if the channel is full.
+    pub(crate) fn is_full(&self) -> bool {
+        false
+    }
+}
+
+impl<T> Drop for Channel<T> {
+    fn drop(&mut self) {
+        let mut head = self.head.index.load(Ordering::Relaxed);
+        let mut tail = self.tail.index.load(Ordering::Relaxed);
+        let mut block = self.head.block.load(Ordering::Relaxed);
+
+        // Erase the lower bits.
+        head &= !((1 << SHIFT) - 1);
+        tail &= !((1 << SHIFT) - 1);
+
+        unsafe {
+            // Drop all messages between head and tail and deallocate the heap-allocated blocks.
+            while head != tail {
+                let offset = (head >> SHIFT) % LAP;
+
+                if offset < BLOCK_CAP {
+                    // Drop the message in the slot.
+                    let slot = (*block).slots.get_unchecked(offset);
+                    let p = &mut *slot.msg.get();
+                    p.as_mut_ptr().drop_in_place();
+                } else {
+                    // Deallocate the block and move to the next one.
+                    let next = (*block).next.load(Ordering::Relaxed);
+                    drop(Box::from_raw(block));
+                    block = next;
+                }
+
+                head = head.wrapping_add(1 << SHIFT);
+            }
+
+            // Deallocate the last remaining block.
+            if !block.is_null() {
+                drop(Box::from_raw(block));
+            }
+        }
+    }
+}