about summary refs log tree commit diff
diff options
context:
space:
mode:
authorPetros Angelatos <petrosagg@gmail.com>2025-05-15 15:46:53 +0300
committerPetros Angelatos <petrosagg@gmail.com>2025-09-22 10:59:52 +0300
commite9b2c4f395673399b0445733c7e8d3955e42ff0c (patch)
tree8904930357d9c6d8eed262aa706db7b1b01f1223
parentb00998aaa50a78b1e45fb107a6c0cfd3f1dc44dd (diff)
downloadrust-e9b2c4f395673399b0445733c7e8d3955e42ff0c.tar.gz
rust-e9b2c4f395673399b0445733c7e8d3955e42ff0c.zip
avoid violating `slice::from_raw_parts` safety contract in `Vec::extract_if`
The implementation of the `Vec::extract_if` iterator violates the safety
contract adverized by `slice::from_raw_parts` by always constructing a
mutable slice for the entire length of the vector even though that span
of memory can contain holes from items already drained. The safety
contract of `slice::from_raw_parts` requires that all elements must be
properly initialized.

As an example we can look at the following code:

```rust
let mut v = vec![Box::new(0u64), Box::new(1u64)];
for item in v.extract_if(.., |x| **x == 0) {
    drop(item);
}
```

In the second iteration a `&mut [Box<u64>]` slice of length 2 will be
constructed. The first slot of the slice contains the bitpattern of an
already deallocated box, which is invalid.

This fixes the issue by only creating references to valid items and
using pointer manipulation for the rest. I have also taken the liberty
to remove the big `unsafe` blocks in place of targetted ones with a
SAFETY comment. The approach closely mirrors the implementation of
`Vec::retain_mut`.

Signed-off-by: Petros Angelatos <petrosagg@gmail.com>
-rw-r--r--library/alloc/src/vec/extract_if.rs64
-rw-r--r--src/tools/miri/tests/pass/vec.rs10
2 files changed, 49 insertions, 25 deletions
diff --git a/library/alloc/src/vec/extract_if.rs b/library/alloc/src/vec/extract_if.rs
index a456d3d9e60..cb9e14f554d 100644
--- a/library/alloc/src/vec/extract_if.rs
+++ b/library/alloc/src/vec/extract_if.rs
@@ -64,27 +64,37 @@ where
     type Item = T;
 
     fn next(&mut self) -> Option<T> {
-        unsafe {
-            while self.idx < self.end {
-                let i = self.idx;
-                let v = slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len);
-                let drained = (self.pred)(&mut v[i]);
-                // Update the index *after* the predicate is called. If the index
-                // is updated prior and the predicate panics, the element at this
-                // index would be leaked.
-                self.idx += 1;
-                if drained {
-                    self.del += 1;
-                    return Some(ptr::read(&v[i]));
-                } else if self.del > 0 {
-                    let del = self.del;
-                    let src: *const T = &v[i];
-                    let dst: *mut T = &mut v[i - del];
-                    ptr::copy_nonoverlapping(src, dst, 1);
+        while self.idx < self.end {
+            let i = self.idx;
+            // SAFETY:
+            //  We know that `i < self.end` from the if guard and that `self.end <= self.old_len` from
+            //  the validity of `Self`. Therefore `i` points to an element within `vec`.
+            //
+            //  Additionally, the i-th element is valid because each element is visited at most once
+            //  and it is the first time we access vec[i].
+            //
+            //  Note: we can't use `vec.get_unchecked_mut(i)` here since the precondition for that
+            //  function is that i < vec.len(), but we've set vec's length to zero.
+            let cur = unsafe { &mut *self.vec.as_mut_ptr().add(i) };
+            let drained = (self.pred)(cur);
+            // Update the index *after* the predicate is called. If the index
+            // is updated prior and the predicate panics, the element at this
+            // index would be leaked.
+            self.idx += 1;
+            if drained {
+                self.del += 1;
+                // SAFETY: We never touch this element again after returning it.
+                return Some(unsafe { ptr::read(cur) });
+            } else if self.del > 0 {
+                // SAFETY: `self.del` > 0, so the hole slot must not overlap with current element.
+                // We use copy for move, and never touch this element again.
+                unsafe {
+                    let hole_slot = self.vec.as_mut_ptr().add(i - self.del);
+                    ptr::copy_nonoverlapping(cur, hole_slot, 1);
                 }
             }
-            None
         }
+        None
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
@@ -95,14 +105,18 @@ where
 #[stable(feature = "extract_if", since = "1.87.0")]
 impl<T, F, A: Allocator> Drop for ExtractIf<'_, T, F, A> {
     fn drop(&mut self) {
-        unsafe {
-            if self.idx < self.old_len && self.del > 0 {
-                let ptr = self.vec.as_mut_ptr();
-                let src = ptr.add(self.idx);
-                let dst = src.sub(self.del);
-                let tail_len = self.old_len - self.idx;
-                src.copy_to(dst, tail_len);
+        if self.del > 0 {
+            // SAFETY: Trailing unchecked items must be valid since we never touch them.
+            unsafe {
+                ptr::copy(
+                    self.vec.as_ptr().add(self.idx),
+                    self.vec.as_mut_ptr().add(self.idx - self.del),
+                    self.old_len - self.idx,
+                );
             }
+        }
+        // SAFETY: After filling holes, all items are in contiguous memory.
+        unsafe {
             self.vec.set_len(self.old_len - self.del);
         }
     }
diff --git a/src/tools/miri/tests/pass/vec.rs b/src/tools/miri/tests/pass/vec.rs
index 3e526813bb4..8b1b1e143b1 100644
--- a/src/tools/miri/tests/pass/vec.rs
+++ b/src/tools/miri/tests/pass/vec.rs
@@ -169,6 +169,15 @@ fn miri_issue_2759() {
     input.replace_range(0..0, "0");
 }
 
+/// This was skirting the edge of UB, let's make sure it remains on the sound side.
+/// Context: <https://github.com/rust-lang/rust/pull/141032>.
+fn extract_if() {
+    let mut v = vec![Box::new(0u64), Box::new(1u64)];
+    for item in v.extract_if(.., |x| **x == 0) {
+        drop(item);
+    }
+}
+
 fn main() {
     assert_eq!(vec_reallocate().len(), 5);
 
@@ -199,4 +208,5 @@ fn main() {
     swap_remove();
     reverse();
     miri_issue_2759();
+    extract_if();
 }