about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorManish Goregaokar <manishsmail@gmail.com>2020-06-28 08:30:24 -0700
committerGitHub <noreply@github.com>2020-06-28 08:30:24 -0700
commitccc1bf79c8be0b4be549f6f82141104a34efec80 (patch)
treeb9a736ae696425d03799e543641a436c101883e7 /src
parentec4898977a849fc73c8d3198e45c6f17c2bf177a (diff)
parentb9f4e0dbfd3181f9f57a650e7451c807116c5921 (diff)
downloadrust-ccc1bf79c8be0b4be549f6f82141104a34efec80.tar.gz
rust-ccc1bf79c8be0b4be549f6f82141104a34efec80.zip
Rollup merge of #73757 - oli-obk:const_prop_hardening, r=wesleywiser
Const prop: erase all block-only locals at the end of every block

I messed up this erasure in https://github.com/rust-lang/rust/pull/73656#discussion_r446040140. I think it is too fragile to have the previous scheme. Let's benchmark the new scheme and see what happens.

r? @wesleywiser

cc @felix91gr
Diffstat (limited to 'src')
-rw-r--r--src/librustc_mir/interpret/eval_context.rs7
-rw-r--r--src/librustc_mir/interpret/machine.rs19
-rw-r--r--src/librustc_mir/interpret/operand.rs6
-rw-r--r--src/librustc_mir/interpret/place.rs6
-rw-r--r--src/librustc_mir/transform/const_prop.rs68
5 files changed, 90 insertions, 16 deletions
diff --git a/src/librustc_mir/interpret/eval_context.rs b/src/librustc_mir/interpret/eval_context.rs
index 9c72a18c6d4..602876e3de1 100644
--- a/src/librustc_mir/interpret/eval_context.rs
+++ b/src/librustc_mir/interpret/eval_context.rs
@@ -132,6 +132,10 @@ pub enum LocalValue<Tag = ()> {
 }
 
 impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
+    /// Read the local's value or error if the local is not yet live or not live anymore.
+    ///
+    /// Note: This may only be invoked from the `Machine::access_local` hook and not from
+    /// anywhere else. You may be invalidating machine invariants if you do!
     pub fn access(&self) -> InterpResult<'tcx, Operand<Tag>> {
         match self.value {
             LocalValue::Dead => throw_ub!(DeadLocal),
@@ -144,6 +148,9 @@ impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
 
     /// Overwrite the local.  If the local can be overwritten in place, return a reference
     /// to do so; otherwise return the `MemPlace` to consult instead.
+    ///
+    /// Note: This may only be invoked from the `Machine::access_local_mut` hook and not from
+    /// anywhere else. You may be invalidating machine invariants if you do!
     pub fn access_mut(
         &mut self,
     ) -> InterpResult<'tcx, Result<&mut LocalValue<Tag>, MemPlace<Tag>>> {
diff --git a/src/librustc_mir/interpret/machine.rs b/src/librustc_mir/interpret/machine.rs
index b5dc40d9551..ec1c93c8165 100644
--- a/src/librustc_mir/interpret/machine.rs
+++ b/src/librustc_mir/interpret/machine.rs
@@ -11,7 +11,7 @@ use rustc_span::def_id::DefId;
 
 use super::{
     AllocId, Allocation, AllocationExtra, CheckInAllocMsg, Frame, ImmTy, InterpCx, InterpResult,
-    Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
+    LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
 };
 
 /// Data returned by Machine::stack_pop,
@@ -192,6 +192,8 @@ pub trait Machine<'mir, 'tcx>: Sized {
     ) -> InterpResult<'tcx>;
 
     /// Called to read the specified `local` from the `frame`.
+    /// Since reading a ZST is not actually accessing memory or locals, this is never invoked
+    /// for ZST reads.
     #[inline]
     fn access_local(
         _ecx: &InterpCx<'mir, 'tcx, Self>,
@@ -201,6 +203,21 @@ pub trait Machine<'mir, 'tcx>: Sized {
         frame.locals[local].access()
     }
 
+    /// Called to write the specified `local` from the `frame`.
+    /// Since writing a ZST is not actually accessing memory or locals, this is never invoked
+    /// for ZST reads.
+    #[inline]
+    fn access_local_mut<'a>(
+        ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
+        frame: usize,
+        local: mir::Local,
+    ) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
+    where
+        'tcx: 'mir,
+    {
+        ecx.stack_mut()[frame].locals[local].access_mut()
+    }
+
     /// Called before a basic block terminator is executed.
     /// You can use this to detect endlessly running programs.
     #[inline]
diff --git a/src/librustc_mir/interpret/operand.rs b/src/librustc_mir/interpret/operand.rs
index fd55deaf83b..b02b5219ba1 100644
--- a/src/librustc_mir/interpret/operand.rs
+++ b/src/librustc_mir/interpret/operand.rs
@@ -432,7 +432,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
         })
     }
 
-    /// This is used by [priroda](https://github.com/oli-obk/priroda) to get an OpTy from a local
+    /// Read from a local. Will not actually access the local if reading from a ZST.
+    /// Will not access memory, instead an indirect `Operand` is returned.
+    ///
+    /// This is public because it is used by [priroda](https://github.com/oli-obk/priroda) to get an
+    /// OpTy from a local
     pub fn access_local(
         &self,
         frame: &super::Frame<'mir, 'tcx, M::PointerTag, M::FrameExtra>,
diff --git a/src/librustc_mir/interpret/place.rs b/src/librustc_mir/interpret/place.rs
index 98a1cea97e2..3868150c6bd 100644
--- a/src/librustc_mir/interpret/place.rs
+++ b/src/librustc_mir/interpret/place.rs
@@ -741,7 +741,7 @@ where
         // but not factored as a separate function.
         let mplace = match dest.place {
             Place::Local { frame, local } => {
-                match self.stack_mut()[frame].locals[local].access_mut()? {
+                match M::access_local_mut(self, frame, local)? {
                     Ok(local) => {
                         // Local can be updated in-place.
                         *local = LocalValue::Live(Operand::Immediate(src));
@@ -974,7 +974,7 @@ where
     ) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, Option<Size>)> {
         let (mplace, size) = match place.place {
             Place::Local { frame, local } => {
-                match self.stack_mut()[frame].locals[local].access_mut()? {
+                match M::access_local_mut(self, frame, local)? {
                     Ok(&mut local_val) => {
                         // We need to make an allocation.
 
@@ -998,7 +998,7 @@ where
                         }
                         // Now we can call `access_mut` again, asserting it goes well,
                         // and actually overwrite things.
-                        *self.stack_mut()[frame].locals[local].access_mut().unwrap().unwrap() =
+                        *M::access_local_mut(self, frame, local).unwrap().unwrap() =
                             LocalValue::Live(Operand::Indirect(mplace));
                         (mplace, Some(size))
                     }
diff --git a/src/librustc_mir/transform/const_prop.rs b/src/librustc_mir/transform/const_prop.rs
index a891f12c8e1..841f1c2b647 100644
--- a/src/librustc_mir/transform/const_prop.rs
+++ b/src/librustc_mir/transform/const_prop.rs
@@ -4,6 +4,7 @@
 use std::cell::Cell;
 
 use rustc_ast::ast::Mutability;
+use rustc_data_structures::fx::FxHashSet;
 use rustc_hir::def::DefKind;
 use rustc_hir::HirId;
 use rustc_index::bit_set::BitSet;
@@ -28,7 +29,7 @@ use rustc_trait_selection::traits;
 use crate::const_eval::error_to_const_error;
 use crate::interpret::{
     self, compile_time_machine, AllocId, Allocation, Frame, ImmTy, Immediate, InterpCx, LocalState,
-    LocalValue, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
+    LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
     ScalarMaybeUninit, StackPopCleanup,
 };
 use crate::transform::{MirPass, MirSource};
@@ -151,11 +152,19 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
 struct ConstPropMachine<'mir, 'tcx> {
     /// The virtual call stack.
     stack: Vec<Frame<'mir, 'tcx, (), ()>>,
+    /// `OnlyInsideOwnBlock` locals that were written in the current block get erased at the end.
+    written_only_inside_own_block_locals: FxHashSet<Local>,
+    /// Locals that need to be cleared after every block terminates.
+    only_propagate_inside_block_locals: BitSet<Local>,
 }
 
 impl<'mir, 'tcx> ConstPropMachine<'mir, 'tcx> {
-    fn new() -> Self {
-        Self { stack: Vec::new() }
+    fn new(only_propagate_inside_block_locals: BitSet<Local>) -> Self {
+        Self {
+            stack: Vec::new(),
+            written_only_inside_own_block_locals: Default::default(),
+            only_propagate_inside_block_locals,
+        }
     }
 }
 
@@ -227,6 +236,18 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx>
         l.access()
     }
 
+    fn access_local_mut<'a>(
+        ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
+        frame: usize,
+        local: Local,
+    ) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
+    {
+        if frame == 0 && ecx.machine.only_propagate_inside_block_locals.contains(local) {
+            ecx.machine.written_only_inside_own_block_locals.insert(local);
+        }
+        ecx.machine.stack[frame].locals[local].access_mut()
+    }
+
     fn before_access_global(
         _memory_extra: &(),
         _alloc_id: AllocId,
@@ -274,8 +295,6 @@ struct ConstPropagator<'mir, 'tcx> {
     // Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store
     // the last known `SourceInfo` here and just keep revisiting it.
     source_info: Option<SourceInfo>,
-    // Locals we need to forget at the end of the current block
-    locals_of_current_block: BitSet<Local>,
 }
 
 impl<'mir, 'tcx> LayoutOf for ConstPropagator<'mir, 'tcx> {
@@ -313,8 +332,20 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
         let param_env = tcx.param_env(def_id).with_reveal_all();
 
         let span = tcx.def_span(def_id);
-        let mut ecx = InterpCx::new(tcx, span, param_env, ConstPropMachine::new(), ());
         let can_const_prop = CanConstProp::check(body);
+        let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len());
+        for (l, mode) in can_const_prop.iter_enumerated() {
+            if *mode == ConstPropMode::OnlyInsideOwnBlock {
+                only_propagate_inside_block_locals.insert(l);
+            }
+        }
+        let mut ecx = InterpCx::new(
+            tcx,
+            span,
+            param_env,
+            ConstPropMachine::new(only_propagate_inside_block_locals),
+            (),
+        );
 
         let ret = ecx
             .layout_of(body.return_ty().subst(tcx, substs))
@@ -345,7 +376,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
             //FIXME(wesleywiser) we can't steal this because `Visitor::super_visit_body()` needs it
             local_decls: body.local_decls.clone(),
             source_info: None,
-            locals_of_current_block: BitSet::new_empty(body.local_decls.len()),
         }
     }
 
@@ -900,7 +930,6 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
                                 Will remove it from const-prop after block is finished. Local: {:?}",
                                 place.local
                             );
-                            self.locals_of_current_block.insert(place.local);
                         }
                         ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
                             trace!("can't propagate into {:?}", place);
@@ -1089,10 +1118,27 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
                 }
             }
         }
-        // We remove all Locals which are restricted in propagation to their containing blocks.
-        for local in self.locals_of_current_block.iter() {
+
+        // We remove all Locals which are restricted in propagation to their containing blocks and
+        // which were modified in the current block.
+        // Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`
+        let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
+        for &local in locals.iter() {
             Self::remove_const(&mut self.ecx, local);
         }
-        self.locals_of_current_block.clear();
+        locals.clear();
+        // Put it back so we reuse the heap of the storage
+        self.ecx.machine.written_only_inside_own_block_locals = locals;
+        if cfg!(debug_assertions) {
+            // Ensure we are correctly erasing locals with the non-debug-assert logic.
+            for local in self.ecx.machine.only_propagate_inside_block_locals.iter() {
+                assert!(
+                    self.get_const(local.into()).is_none()
+                        || self
+                            .layout_of(self.local_decls[local].ty)
+                            .map_or(true, |layout| layout.is_zst())
+                )
+            }
+        }
     }
 }