about summary refs log tree commit diff
path: root/library
diff options
context:
space:
mode:
Diffstat (limited to 'library')
-rw-r--r--library/alloc/src/collections/binary_heap/mod.rs55
1 files changed, 46 insertions, 9 deletions
diff --git a/library/alloc/src/collections/binary_heap/mod.rs b/library/alloc/src/collections/binary_heap/mod.rs
index 4583bc9a158..e3fef96613b 100644
--- a/library/alloc/src/collections/binary_heap/mod.rs
+++ b/library/alloc/src/collections/binary_heap/mod.rs
@@ -146,6 +146,7 @@
 use core::fmt;
 use core::iter::{FromIterator, FusedIterator, InPlaceIterable, SourceIter, TrustedLen};
 use core::mem::{self, swap, ManuallyDrop};
+use core::num::NonZeroUsize;
 use core::ops::{Deref, DerefMut};
 use core::ptr;
 
@@ -279,7 +280,9 @@ pub struct BinaryHeap<T> {
 #[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
 pub struct PeekMut<'a, T: 'a + Ord> {
     heap: &'a mut BinaryHeap<T>,
-    sift: bool,
+    // If a set_len + sift_down are required, this is Some. If a &mut T has not
+    // yet been exposed to peek_mut()'s caller, it's None.
+    original_len: Option<NonZeroUsize>,
 }
 
 #[stable(feature = "collection_debug", since = "1.17.0")]
@@ -292,7 +295,14 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
 #[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
 impl<T: Ord> Drop for PeekMut<'_, T> {
     fn drop(&mut self) {
-        if self.sift {
+        if let Some(original_len) = self.original_len {
+            // SAFETY: That's how many elements were in the Vec at the time of
+            // the PeekMut::deref_mut call, and therefore also at the time of
+            // the BinaryHeap::peek_mut call. Since the PeekMut did not end up
+            // getting leaked, we are now undoing the leak amplification that
+            // the DerefMut prepared for.
+            unsafe { self.heap.data.set_len(original_len.get()) };
+
             // SAFETY: PeekMut is only instantiated for non-empty heaps.
             unsafe { self.heap.sift_down(0) };
         }
@@ -313,7 +323,26 @@ impl<T: Ord> Deref for PeekMut<'_, T> {
 impl<T: Ord> DerefMut for PeekMut<'_, T> {
     fn deref_mut(&mut self) -> &mut T {
         debug_assert!(!self.heap.is_empty());
-        self.sift = true;
+
+        let len = self.heap.len();
+        if len > 1 {
+            // Here we preemptively leak all the rest of the underlying vector
+            // after the currently max element. If the caller mutates the &mut T
+            // we're about to give them, and then leaks the PeekMut, all these
+            // elements will remain leaked. If they don't leak the PeekMut, then
+            // either Drop or PeekMut::pop will un-leak the vector elements.
+            //
+            // This is technique is described throughout several other places in
+            // the standard library as "leak amplification".
+            unsafe {
+                // SAFETY: len > 1 so len != 0.
+                self.original_len = Some(NonZeroUsize::new_unchecked(len));
+                // SAFETY: len > 1 so all this does for now is leak elements,
+                // which is safe.
+                self.heap.data.set_len(1);
+            }
+        }
+
         // SAFE: PeekMut is only instantiated for non-empty heaps
         unsafe { self.heap.data.get_unchecked_mut(0) }
     }
@@ -323,9 +352,16 @@ impl<'a, T: Ord> PeekMut<'a, T> {
     /// Removes the peeked value from the heap and returns it.
     #[stable(feature = "binary_heap_peek_mut_pop", since = "1.18.0")]
     pub fn pop(mut this: PeekMut<'a, T>) -> T {
-        let value = this.heap.pop().unwrap();
-        this.sift = false;
-        value
+        if let Some(original_len) = this.original_len.take() {
+            // SAFETY: This is how many elements were in the Vec at the time of
+            // the BinaryHeap::peek_mut call.
+            unsafe { this.heap.data.set_len(original_len.get()) };
+
+            // Unlike in Drop, here we don't also need to do a sift_down even if
+            // the caller could've mutated the element. It is removed from the
+            // heap on the next line and pop() is not sensitive to its value.
+        }
+        this.heap.pop().unwrap()
     }
 }
 
@@ -398,8 +434,9 @@ impl<T: Ord> BinaryHeap<T> {
     /// Returns a mutable reference to the greatest item in the binary heap, or
     /// `None` if it is empty.
     ///
-    /// Note: If the `PeekMut` value is leaked, the heap may be in an
-    /// inconsistent state.
+    /// Note: If the `PeekMut` value is leaked, some heap elements might get
+    /// leaked along with it, but the remaining elements will remain a valid
+    /// heap.
     ///
     /// # Examples
     ///
@@ -426,7 +463,7 @@ impl<T: Ord> BinaryHeap<T> {
     /// otherwise it's *O*(1).
     #[stable(feature = "binary_heap_peek_mut", since = "1.12.0")]
     pub fn peek_mut(&mut self) -> Option<PeekMut<'_, T>> {
-        if self.is_empty() { None } else { Some(PeekMut { heap: self, sift: false }) }
+        if self.is_empty() { None } else { Some(PeekMut { heap: self, original_len: None }) }
     }
 
     /// Removes the greatest item from the binary heap and returns it, or `None` if it