about summary refs log tree commit diff
path: root/library/alloc/src/vec/extract_if.rs
diff options
context:
space:
mode:
Diffstat (limited to 'library/alloc/src/vec/extract_if.rs')
-rw-r--r--library/alloc/src/vec/extract_if.rs64
1 files changed, 39 insertions, 25 deletions
diff --git a/library/alloc/src/vec/extract_if.rs b/library/alloc/src/vec/extract_if.rs
index a456d3d9e60..cb9e14f554d 100644
--- a/library/alloc/src/vec/extract_if.rs
+++ b/library/alloc/src/vec/extract_if.rs
@@ -64,27 +64,37 @@ where
     type Item = T;
 
     fn next(&mut self) -> Option<T> {
-        unsafe {
-            while self.idx < self.end {
-                let i = self.idx;
-                let v = slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len);
-                let drained = (self.pred)(&mut v[i]);
-                // Update the index *after* the predicate is called. If the index
-                // is updated prior and the predicate panics, the element at this
-                // index would be leaked.
-                self.idx += 1;
-                if drained {
-                    self.del += 1;
-                    return Some(ptr::read(&v[i]));
-                } else if self.del > 0 {
-                    let del = self.del;
-                    let src: *const T = &v[i];
-                    let dst: *mut T = &mut v[i - del];
-                    ptr::copy_nonoverlapping(src, dst, 1);
+        while self.idx < self.end {
+            let i = self.idx;
+            // SAFETY:
+            //  We know that `i < self.end` from the if guard and that `self.end <= self.old_len` from
+            //  the validity of `Self`. Therefore `i` points to an element within `vec`.
+            //
+            //  Additionally, the i-th element is valid because each element is visited at most once
+            //  and it is the first time we access vec[i].
+            //
+            //  Note: we can't use `vec.get_unchecked_mut(i)` here since the precondition for that
+            //  function is that i < vec.len(), but we've set vec's length to zero.
+            let cur = unsafe { &mut *self.vec.as_mut_ptr().add(i) };
+            let drained = (self.pred)(cur);
+            // Update the index *after* the predicate is called. If the index
+            // is updated prior and the predicate panics, the element at this
+            // index would be leaked.
+            self.idx += 1;
+            if drained {
+                self.del += 1;
+                // SAFETY: We never touch this element again after returning it.
+                return Some(unsafe { ptr::read(cur) });
+            } else if self.del > 0 {
+                // SAFETY: `self.del` > 0, so the hole slot must not overlap with current element.
+                // We use copy for move, and never touch this element again.
+                unsafe {
+                    let hole_slot = self.vec.as_mut_ptr().add(i - self.del);
+                    ptr::copy_nonoverlapping(cur, hole_slot, 1);
                 }
             }
-            None
         }
+        None
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
@@ -95,14 +105,18 @@ where
 #[stable(feature = "extract_if", since = "1.87.0")]
 impl<T, F, A: Allocator> Drop for ExtractIf<'_, T, F, A> {
     fn drop(&mut self) {
-        unsafe {
-            if self.idx < self.old_len && self.del > 0 {
-                let ptr = self.vec.as_mut_ptr();
-                let src = ptr.add(self.idx);
-                let dst = src.sub(self.del);
-                let tail_len = self.old_len - self.idx;
-                src.copy_to(dst, tail_len);
+        if self.del > 0 {
+            // SAFETY: Trailing unchecked items must be valid since we never touch them.
+            unsafe {
+                ptr::copy(
+                    self.vec.as_ptr().add(self.idx),
+                    self.vec.as_mut_ptr().add(self.idx - self.del),
+                    self.old_len - self.idx,
+                );
             }
+        }
+        // SAFETY: After filling holes, all items are in contiguous memory.
+        unsafe {
             self.vec.set_len(self.old_len - self.del);
         }
     }