about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/alloc/src/collections/binary_heap.rs85
1 files changed, 53 insertions, 32 deletions
diff --git a/library/alloc/src/collections/binary_heap.rs b/library/alloc/src/collections/binary_heap.rs
index bf9f7432fb5..a201af01030 100644
--- a/library/alloc/src/collections/binary_heap.rs
+++ b/library/alloc/src/collections/binary_heap.rs
@@ -652,6 +652,43 @@ impl<T: Ord> BinaryHeap<T> {
         unsafe { self.sift_up(start, pos) };
     }
 
+    /// Rebuild assuming data[0..start] is still a proper heap.
+    fn rebuild_tail(&mut self, start: usize) {
+        if start == self.len() {
+            return;
+        }
+
+        let tail_len = self.len() - start;
+
+        #[inline(always)]
+        fn log2_fast(x: usize) -> usize {
+            (usize::BITS - x.leading_zeros() - 1) as usize
+        }
+
+        // `rebuild` takes O(self.len()) operations
+        // and about 2 * self.len() comparisons in the worst case
+        // while repeating `sift_up` takes O(tail_len * log(start)) operations
+        // and about 1 * tail_len * log_2(start) comparisons in the worst case,
+        // assuming start >= tail_len. For larger heaps, the crossover point
+        // no longer follows this reasoning and was determined empirically.
+        let better_to_rebuild = if start < tail_len {
+            true
+        } else if self.len() <= 2048 {
+            2 * self.len() < tail_len * log2_fast(start)
+        } else {
+            2 * self.len() < tail_len * 11
+        };
+
+        if better_to_rebuild {
+            self.rebuild();
+        } else {
+            for i in start..self.len() {
+                // SAFETY: The index `i` is always less than self.len().
+                unsafe { self.sift_up(0, i) };
+            }
+        }
+    }
+
     fn rebuild(&mut self) {
         let mut n = self.len() / 2;
         while n > 0 {
@@ -689,37 +726,11 @@ impl<T: Ord> BinaryHeap<T> {
             swap(self, other);
         }
 
-        if other.is_empty() {
-            return;
-        }
-
-        #[inline(always)]
-        fn log2_fast(x: usize) -> usize {
-            (usize::BITS - x.leading_zeros() - 1) as usize
-        }
+        let start = self.data.len();
 
-        // `rebuild` takes O(len1 + len2) operations
-        // and about 2 * (len1 + len2) comparisons in the worst case
-        // while `extend` takes O(len2 * log(len1)) operations
-        // and about 1 * len2 * log_2(len1) comparisons in the worst case,
-        // assuming len1 >= len2. For larger heaps, the crossover point
-        // no longer follows this reasoning and was determined empirically.
-        #[inline]
-        fn better_to_rebuild(len1: usize, len2: usize) -> bool {
-            let tot_len = len1 + len2;
-            if tot_len <= 2048 {
-                2 * tot_len < len2 * log2_fast(len1)
-            } else {
-                2 * tot_len < len2 * 11
-            }
-        }
+        self.data.append(&mut other.data);
 
-        if better_to_rebuild(self.len(), other.len()) {
-            self.data.append(&mut other.data);
-            self.rebuild();
-        } else {
-            self.extend(other.drain());
-        }
+        self.rebuild_tail(start);
     }
 
     /// Returns an iterator which retrieves elements in heap order.
@@ -770,12 +781,22 @@ impl<T: Ord> BinaryHeap<T> {
     /// assert_eq!(heap.into_sorted_vec(), [-10, 2, 4])
     /// ```
     #[unstable(feature = "binary_heap_retain", issue = "71503")]
-    pub fn retain<F>(&mut self, f: F)
+    pub fn retain<F>(&mut self, mut f: F)
     where
         F: FnMut(&T) -> bool,
     {
-        self.data.retain(f);
-        self.rebuild();
+        let mut first_removed = self.len();
+        let mut i = 0;
+        self.data.retain(|e| {
+            let keep = f(e);
+            if !keep && i < first_removed {
+                first_removed = i;
+            }
+            i += 1;
+            keep
+        });
+        // data[0..first_removed] is untouched, so we only need to rebuild the tail:
+        self.rebuild_tail(first_removed);
     }
 }