about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2024-03-15 11:59:38 +0100
committerRalf Jung <post@ralfj.de>2024-03-15 15:58:37 +0100
commit7be47b219df42cba96560a97de707967cca8c762 (patch)
tree7bdb3b2cb894364261f8f034ecbc3ce4d97fb30d
parentee03c286cfdca26fa5b2a4ee40957625d2c826ff (diff)
downloadrust-7be47b219df42cba96560a97de707967cca8c762.tar.gz
rust-7be47b219df42cba96560a97de707967cca8c762.zip
interpret/allocation: fix aliasing issue in interpreter and refactor getters a bit
- rename mutating functions to be more scary
- add a new raw bytes getter
-rw-r--r--compiler/rustc_const_eval/src/interpret/memory.rs17
-rw-r--r--compiler/rustc_middle/src/mir/interpret/allocation.rs42
2 files changed, 45 insertions, 14 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/memory.rs b/compiler/rustc_const_eval/src/interpret/memory.rs
index 86aad2e1642..9e6776ab55f 100644
--- a/compiler/rustc_const_eval/src/interpret/memory.rs
+++ b/compiler/rustc_const_eval/src/interpret/memory.rs
@@ -1157,11 +1157,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
         };
 
         // Side-step AllocRef and directly access the underlying bytes more efficiently.
-        // (We are staying inside the bounds here so all is good.)
+        // (We are staying inside the bounds here and all bytes do get overwritten so all is good.)
         let alloc_id = alloc_ref.alloc_id;
         let bytes = alloc_ref
             .alloc
-            .get_bytes_mut(&alloc_ref.tcx, alloc_ref.range)
+            .get_bytes_unchecked_for_overwrite(&alloc_ref.tcx, alloc_ref.range)
             .map_err(move |e| e.to_interp_error(alloc_id))?;
         // `zip` would stop when the first iterator ends; we want to definitely
         // cover all of `bytes`.
@@ -1182,6 +1182,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
         self.mem_copy_repeatedly(src, dest, size, 1, nonoverlapping)
     }
 
+    /// Performs `num_copies` many copies of `size` many bytes from `src` to `dest + i*size` (where
+    /// `i` is the index of the copy).
+    ///
+    /// Either `nonoverlapping` must be true or `num_copies` must be 1; doing repeated copies that
+    /// may overlap is not supported.
     pub fn mem_copy_repeatedly(
         &mut self,
         src: Pointer<Option<M::Provenance>>,
@@ -1243,8 +1248,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
             (dest_alloc_id, dest_prov),
             dest_range,
         )?;
+        // Yes we do overwrite all bytes in `dest_bytes`.
         let dest_bytes = dest_alloc
-            .get_bytes_mut_ptr(&tcx, dest_range)
+            .get_bytes_unchecked_for_overwrite_ptr(&tcx, dest_range)
             .map_err(|e| e.to_interp_error(dest_alloc_id))?
             .as_mut_ptr();
 
@@ -1278,6 +1284,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                     }
                 }
             }
+            if num_copies > 1 {
+                assert!(nonoverlapping, "multi-copy only supported in non-overlapping mode");
+            }
 
             let size_in_bytes = size.bytes_usize();
             // For particularly large arrays (where this is perf-sensitive) it's common that
@@ -1290,6 +1299,8 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
             } else if src_alloc_id == dest_alloc_id {
                 let mut dest_ptr = dest_bytes;
                 for _ in 0..num_copies {
+                    // Here we rely on `src` and `dest` being non-overlapping if there is more than
+                    // one copy.
                     ptr::copy(src_bytes, dest_ptr, size_in_bytes);
                     dest_ptr = dest_ptr.add(size_in_bytes);
                 }
diff --git a/compiler/rustc_middle/src/mir/interpret/allocation.rs b/compiler/rustc_middle/src/mir/interpret/allocation.rs
index 4047891d769..00faa211853 100644
--- a/compiler/rustc_middle/src/mir/interpret/allocation.rs
+++ b/compiler/rustc_middle/src/mir/interpret/allocation.rs
@@ -37,9 +37,16 @@ pub trait AllocBytes:
     /// Create a zeroed `AllocBytes` of the specified size and alignment.
     /// Returns `None` if we ran out of memory on the host.
     fn zeroed(size: Size, _align: Align) -> Option<Self>;
+
+    /// Gives direct access to the raw underlying storage.
+    ///
+    /// Crucially this pointer is compatible with:
+    /// - other pointers retunred by this method, and
+    /// - references returned from `deref()`, as long as there was no write.
+    fn as_mut_ptr(&mut self) -> *mut u8;
 }
 
-// Default `bytes` for `Allocation` is a `Box<[u8]>`.
+/// Default `bytes` for `Allocation` is a `Box<u8>`.
 impl AllocBytes for Box<[u8]> {
     fn from_bytes<'a>(slice: impl Into<Cow<'a, [u8]>>, _align: Align) -> Self {
         Box::<[u8]>::from(slice.into())
@@ -51,6 +58,11 @@ impl AllocBytes for Box<[u8]> {
         let bytes = unsafe { bytes.assume_init() };
         Some(bytes)
     }
+
+    fn as_mut_ptr(&mut self) -> *mut u8 {
+        // Carefully avoiding any intermediate references.
+        ptr::addr_of_mut!(**self).cast()
+    }
 }
 
 /// This type represents an Allocation in the Miri/CTFE core engine.
@@ -399,10 +411,6 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
 
 /// Byte accessors.
 impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes> {
-    pub fn base_addr(&self) -> *const u8 {
-        self.bytes.as_ptr()
-    }
-
     /// This is the entirely abstraction-violating way to just grab the raw bytes without
     /// caring about provenance or initialization.
     ///
@@ -452,13 +460,14 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
         Ok(self.get_bytes_unchecked(range))
     }
 
-    /// Just calling this already marks everything as defined and removes provenance,
-    /// so be sure to actually put data there!
+    /// This is the entirely abstraction-violating way to just get mutable access to the raw bytes.
+    /// Just calling this already marks everything as defined and removes provenance, so be sure to
+    /// actually overwrite all the data there!
     ///
     /// It is the caller's responsibility to check bounds and alignment beforehand.
     /// Most likely, you want to use the `PlaceTy` and `OperandTy`-based methods
     /// on `InterpCx` instead.
-    pub fn get_bytes_mut(
+    pub fn get_bytes_unchecked_for_overwrite(
         &mut self,
         cx: &impl HasDataLayout,
         range: AllocRange,
@@ -469,8 +478,9 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
         Ok(&mut self.bytes[range.start.bytes_usize()..range.end().bytes_usize()])
     }
 
-    /// A raw pointer variant of `get_bytes_mut` that avoids invalidating existing aliases into this memory.
-    pub fn get_bytes_mut_ptr(
+    /// A raw pointer variant of `get_bytes_unchecked_for_overwrite` that avoids invalidating existing immutable aliases
+    /// into this memory.
+    pub fn get_bytes_unchecked_for_overwrite_ptr(
         &mut self,
         cx: &impl HasDataLayout,
         range: AllocRange,
@@ -479,10 +489,19 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
         self.provenance.clear(range, cx)?;
 
         assert!(range.end().bytes_usize() <= self.bytes.len()); // need to do our own bounds-check
+        // Cruciall, we go via `AllocBytes::as_mut_ptr`, not `AllocBytes::deref_mut`.
         let begin_ptr = self.bytes.as_mut_ptr().wrapping_add(range.start.bytes_usize());
         let len = range.end().bytes_usize() - range.start.bytes_usize();
         Ok(ptr::slice_from_raw_parts_mut(begin_ptr, len))
     }
+
+    /// This gives direct mutable access to the entire buffer, just exposing their internal state
+    /// without reseting anything. Directly exposes `AllocBytes::as_mut_ptr`. Only works if
+    /// `OFFSET_IS_ADDR` is true.
+    pub fn get_bytes_unchecked_raw_mut(&mut self) -> *mut u8 {
+        assert!(Prov::OFFSET_IS_ADDR);
+        self.bytes.as_mut_ptr()
+    }
 }
 
 /// Reading and writing.
@@ -589,7 +608,8 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
         };
 
         let endian = cx.data_layout().endian;
-        let dst = self.get_bytes_mut(cx, range)?;
+        // Yes we do overwrite all the bytes in `dst`.
+        let dst = self.get_bytes_unchecked_for_overwrite(cx, range)?;
         write_target_uint(endian, dst, bytes).unwrap();
 
         // See if we have to also store some provenance.