about summary refs log tree commit diff
path: root/compiler/rustc_const_eval/src/interpret/memory.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_const_eval/src/interpret/memory.rs')
-rw-r--r--compiler/rustc_const_eval/src/interpret/memory.rs49
1 files changed, 36 insertions, 13 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/memory.rs b/compiler/rustc_const_eval/src/interpret/memory.rs
index 45a5eb9bd52..d87588496c0 100644
--- a/compiler/rustc_const_eval/src/interpret/memory.rs
+++ b/compiler/rustc_const_eval/src/interpret/memory.rs
@@ -8,9 +8,8 @@
 
 use std::assert_matches::assert_matches;
 use std::borrow::Cow;
-use std::cell::Cell;
 use std::collections::VecDeque;
-use std::{fmt, ptr};
+use std::{fmt, mem, ptr};
 
 use rustc_ast::Mutability;
 use rustc_data_structures::fx::{FxHashSet, FxIndexMap};
@@ -118,7 +117,7 @@ pub struct Memory<'tcx, M: Machine<'tcx>> {
     /// This stores whether we are currently doing reads purely for the purpose of validation.
     /// Those reads do not trigger the machine's hooks for memory reads.
     /// Needless to say, this must only be set with great care!
-    validation_in_progress: Cell<bool>,
+    validation_in_progress: bool,
 }
 
 /// A reference to some allocation that was already bounds-checked for the given region
@@ -145,7 +144,7 @@ impl<'tcx, M: Machine<'tcx>> Memory<'tcx, M> {
             alloc_map: M::MemoryMap::default(),
             extra_fn_ptr_map: FxIndexMap::default(),
             dead_alloc_map: FxIndexMap::default(),
-            validation_in_progress: Cell::new(false),
+            validation_in_progress: false,
         }
     }
 
@@ -682,7 +681,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         // We want to call the hook on *all* accesses that involve an AllocId, including zero-sized
         // accesses. That means we cannot rely on the closure above or the `Some` branch below. We
         // do this after `check_and_deref_ptr` to ensure some basic sanity has already been checked.
-        if !self.memory.validation_in_progress.get() {
+        if !self.memory.validation_in_progress {
             if let Ok((alloc_id, ..)) = self.ptr_try_get_alloc_id(ptr, size_i64) {
                 M::before_alloc_read(self, alloc_id)?;
             }
@@ -690,7 +689,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
 
         if let Some((alloc_id, offset, prov, alloc)) = ptr_and_alloc {
             let range = alloc_range(offset, size);
-            if !self.memory.validation_in_progress.get() {
+            if !self.memory.validation_in_progress {
                 M::before_memory_read(
                     self.tcx,
                     &self.machine,
@@ -766,11 +765,14 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         let parts = self.get_ptr_access(ptr, size)?;
         if let Some((alloc_id, offset, prov)) = parts {
             let tcx = self.tcx;
+            let validation_in_progress = self.memory.validation_in_progress;
             // FIXME: can we somehow avoid looking up the allocation twice here?
             // We cannot call `get_raw_mut` inside `check_and_deref_ptr` as that would duplicate `&mut self`.
             let (alloc, machine) = self.get_alloc_raw_mut(alloc_id)?;
             let range = alloc_range(offset, size);
-            M::before_memory_write(tcx, machine, &mut alloc.extra, (alloc_id, prov), range)?;
+            if !validation_in_progress {
+                M::before_memory_write(tcx, machine, &mut alloc.extra, (alloc_id, prov), range)?;
+            }
             Ok(Some(AllocRefMut { alloc, range, tcx: *tcx, alloc_id }))
         } else {
             Ok(None)
@@ -1014,16 +1016,16 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
     ///
     /// We do this so Miri's allocation access tracking does not show the validation
     /// reads as spurious accesses.
-    pub fn run_for_validation<R>(&self, f: impl FnOnce() -> R) -> R {
+    pub fn run_for_validation<R>(&mut self, f: impl FnOnce(&mut Self) -> R) -> R {
         // This deliberately uses `==` on `bool` to follow the pattern
         // `assert!(val.replace(new) == old)`.
         assert!(
-            self.memory.validation_in_progress.replace(true) == false,
+            mem::replace(&mut self.memory.validation_in_progress, true) == false,
             "`validation_in_progress` was already set"
         );
-        let res = f();
+        let res = f(self);
         assert!(
-            self.memory.validation_in_progress.replace(false) == true,
+            mem::replace(&mut self.memory.validation_in_progress, false) == true,
             "`validation_in_progress` was unset by someone else"
         );
         res
@@ -1115,6 +1117,10 @@ impl<'a, 'tcx, M: Machine<'tcx>> std::fmt::Debug for DumpAllocs<'a, 'tcx, M> {
 impl<'tcx, 'a, Prov: Provenance, Extra, Bytes: AllocBytes>
     AllocRefMut<'a, 'tcx, Prov, Extra, Bytes>
 {
+    pub fn as_ref<'b>(&'b self) -> AllocRef<'b, 'tcx, Prov, Extra, Bytes> {
+        AllocRef { alloc: self.alloc, range: self.range, tcx: self.tcx, alloc_id: self.alloc_id }
+    }
+
     /// `range` is relative to this allocation reference, not the base of the allocation.
     pub fn write_scalar(&mut self, range: AllocRange, val: Scalar<Prov>) -> InterpResult<'tcx> {
         let range = self.range.subrange(range);
@@ -1130,13 +1136,30 @@ impl<'tcx, 'a, Prov: Provenance, Extra, Bytes: AllocBytes>
         self.write_scalar(alloc_range(offset, self.tcx.data_layout().pointer_size), val)
     }
 
+    /// Mark the given sub-range (relative to this allocation reference) as uninitialized.
+    pub fn write_uninit(&mut self, range: AllocRange) -> InterpResult<'tcx> {
+        let range = self.range.subrange(range);
+        Ok(self
+            .alloc
+            .write_uninit(&self.tcx, range)
+            .map_err(|e| e.to_interp_error(self.alloc_id))?)
+    }
+
     /// Mark the entire referenced range as uninitialized
-    pub fn write_uninit(&mut self) -> InterpResult<'tcx> {
+    pub fn write_uninit_full(&mut self) -> InterpResult<'tcx> {
         Ok(self
             .alloc
             .write_uninit(&self.tcx, self.range)
             .map_err(|e| e.to_interp_error(self.alloc_id))?)
     }
+
+    /// Remove all provenance in the reference range.
+    pub fn clear_provenance(&mut self) -> InterpResult<'tcx> {
+        Ok(self
+            .alloc
+            .clear_provenance(&self.tcx, self.range)
+            .map_err(|e| e.to_interp_error(self.alloc_id))?)
+    }
 }
 
 impl<'tcx, 'a, Prov: Provenance, Extra, Bytes: AllocBytes> AllocRef<'a, 'tcx, Prov, Extra, Bytes> {
@@ -1278,7 +1301,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         };
         let src_alloc = self.get_alloc_raw(src_alloc_id)?;
         let src_range = alloc_range(src_offset, size);
-        assert!(!self.memory.validation_in_progress.get(), "we can't be copying during validation");
+        assert!(!self.memory.validation_in_progress, "we can't be copying during validation");
         M::before_memory_read(
             tcx,
             &self.machine,