about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/alloc/src/vec/extract_if.rs64
-rw-r--r--src/tools/miri/tests/pass/vec.rs10
2 files changed, 49 insertions, 25 deletions
diff --git a/library/alloc/src/vec/extract_if.rs b/library/alloc/src/vec/extract_if.rs
index a456d3d9e60..cb9e14f554d 100644
--- a/library/alloc/src/vec/extract_if.rs
+++ b/library/alloc/src/vec/extract_if.rs
@@ -64,27 +64,37 @@ where
     type Item = T;
 
     fn next(&mut self) -> Option<T> {
-        unsafe {
-            while self.idx < self.end {
-                let i = self.idx;
-                let v = slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len);
-                let drained = (self.pred)(&mut v[i]);
-                // Update the index *after* the predicate is called. If the index
-                // is updated prior and the predicate panics, the element at this
-                // index would be leaked.
-                self.idx += 1;
-                if drained {
-                    self.del += 1;
-                    return Some(ptr::read(&v[i]));
-                } else if self.del > 0 {
-                    let del = self.del;
-                    let src: *const T = &v[i];
-                    let dst: *mut T = &mut v[i - del];
-                    ptr::copy_nonoverlapping(src, dst, 1);
+        while self.idx < self.end {
+            let i = self.idx;
+            // SAFETY:
+            //  We know that `i < self.end` from the if guard and that `self.end <= self.old_len` from
+            //  the validity of `Self`. Therefore `i` points to an element within `vec`.
+            //
+            //  Additionally, the i-th element is valid because each element is visited at most once
+            //  and it is the first time we access vec[i].
+            //
+            //  Note: we can't use `vec.get_unchecked_mut(i)` here since the precondition for that
+            //  function is that i < vec.len(), but we've set vec's length to zero.
+            let cur = unsafe { &mut *self.vec.as_mut_ptr().add(i) };
+            let drained = (self.pred)(cur);
+            // Update the index *after* the predicate is called. If the index
+            // is updated prior and the predicate panics, the element at this
+            // index would be leaked.
+            self.idx += 1;
+            if drained {
+                self.del += 1;
+                // SAFETY: We never touch this element again after returning it.
+                return Some(unsafe { ptr::read(cur) });
+            } else if self.del > 0 {
+                // SAFETY: `self.del` > 0, so the hole slot must not overlap with current element.
+                // We use copy for move, and never touch this element again.
+                unsafe {
+                    let hole_slot = self.vec.as_mut_ptr().add(i - self.del);
+                    ptr::copy_nonoverlapping(cur, hole_slot, 1);
                 }
             }
-            None
         }
+        None
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
@@ -95,14 +105,18 @@ where
 #[stable(feature = "extract_if", since = "1.87.0")]
 impl<T, F, A: Allocator> Drop for ExtractIf<'_, T, F, A> {
     fn drop(&mut self) {
-        unsafe {
-            if self.idx < self.old_len && self.del > 0 {
-                let ptr = self.vec.as_mut_ptr();
-                let src = ptr.add(self.idx);
-                let dst = src.sub(self.del);
-                let tail_len = self.old_len - self.idx;
-                src.copy_to(dst, tail_len);
+        if self.del > 0 {
+            // SAFETY: Trailing unchecked items must be valid since we never touch them.
+            unsafe {
+                ptr::copy(
+                    self.vec.as_ptr().add(self.idx),
+                    self.vec.as_mut_ptr().add(self.idx - self.del),
+                    self.old_len - self.idx,
+                );
             }
+        }
+        // SAFETY: After filling holes, all items are in contiguous memory.
+        unsafe {
             self.vec.set_len(self.old_len - self.del);
         }
     }
diff --git a/src/tools/miri/tests/pass/vec.rs b/src/tools/miri/tests/pass/vec.rs
index 3e526813bb4..8b1b1e143b1 100644
--- a/src/tools/miri/tests/pass/vec.rs
+++ b/src/tools/miri/tests/pass/vec.rs
@@ -169,6 +169,15 @@ fn miri_issue_2759() {
     input.replace_range(0..0, "0");
 }
 
+/// This was skirting the edge of UB, let's make sure it remains on the sound side.
+/// Context: <https://github.com/rust-lang/rust/pull/141032>.
+fn extract_if() {
+    let mut v = vec![Box::new(0u64), Box::new(1u64)];
+    for item in v.extract_if(.., |x| **x == 0) {
+        drop(item);
+    }
+}
+
 fn main() {
     assert_eq!(vec_reallocate().len(), 5);
 
@@ -199,4 +208,5 @@ fn main() {
     swap_remove();
     reverse();
     miri_issue_2759();
+    extract_if();
 }