about summary refs log tree commit diff
diff options
context:
space:
mode:
authorManish Goregaokar <manishsmail@gmail.com>2020-06-23 13:10:03 -0700
committerGitHub <noreply@github.com>2020-06-23 13:10:03 -0700
commit781b5899970f08488a540392a5f48c03bbf98a1e (patch)
tree549b4962927dc04a384b2567ab3c957decfc7830
parent317a15142e842656f913d7debd2bc18a367948e4 (diff)
parentb2ec645d16ce9a3345b2f9cb527ca52e86e54324 (diff)
downloadrust-781b5899970f08488a540392a5f48c03bbf98a1e.tar.gz
rust-781b5899970f08488a540392a5f48c03bbf98a1e.zip
Rollup merge of #73244 - ecstatic-morse:validate-generator-mir, r=tmandry
Check for assignments between non-conflicting generator saved locals

This is to prevent future changes to the generator transform from reintroducing the problem that caused #73137. Namely, a store between two generator saved locals whose storage does not conflict.

My ultimate goal is to introduce a modified version of #71956 that handles this case properly.

r? @tmandry
-rw-r--r--src/librustc_mir/transform/generator.rs269
1 files changed, 213 insertions, 56 deletions
diff --git a/src/librustc_mir/transform/generator.rs b/src/librustc_mir/transform/generator.rs
index c8702eeae1d..59be8dc224d 100644
--- a/src/librustc_mir/transform/generator.rs
+++ b/src/librustc_mir/transform/generator.rs
@@ -64,7 +64,7 @@ use rustc_hir::def_id::DefId;
 use rustc_hir::lang_items::{GeneratorStateLangItem, PinTypeLangItem};
 use rustc_index::bit_set::{BitMatrix, BitSet};
 use rustc_index::vec::{Idx, IndexVec};
-use rustc_middle::mir::visit::{MutVisitor, PlaceContext};
+use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
 use rustc_middle::mir::*;
 use rustc_middle::ty::subst::SubstsRef;
 use rustc_middle::ty::GeneratorSubsts;
@@ -72,7 +72,7 @@ use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
 use rustc_target::abi::VariantIdx;
 use rustc_target::spec::PanicStrategy;
 use std::borrow::Cow;
-use std::iter;
+use std::{iter, ops};
 
 pub struct StateTransform;
 
@@ -417,11 +417,7 @@ fn replace_local<'tcx>(
 
 struct LivenessInfo {
     /// Which locals are live across any suspension point.
-    ///
-    /// GeneratorSavedLocal is indexed in terms of the elements in this set;
-    /// i.e. GeneratorSavedLocal::new(1) corresponds to the second local
-    /// included in this set.
-    live_locals: BitSet<Local>,
+    saved_locals: GeneratorSavedLocals,
 
     /// The set of saved locals live at each suspension point.
     live_locals_at_suspension_points: Vec<BitSet<GeneratorSavedLocal>>,
@@ -524,49 +520,75 @@ fn locals_live_across_suspend_points(
             live_locals_at_suspension_points.push(live_locals);
         }
     }
+
     debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point);
+    let saved_locals = GeneratorSavedLocals(live_locals_at_any_suspension_point);
 
     // Renumber our liveness_map bitsets to include only the locals we are
     // saving.
     let live_locals_at_suspension_points = live_locals_at_suspension_points
         .iter()
-        .map(|live_here| renumber_bitset(&live_here, &live_locals_at_any_suspension_point))
+        .map(|live_here| saved_locals.renumber_bitset(&live_here))
         .collect();
 
     let storage_conflicts = compute_storage_conflicts(
         body_ref,
-        &live_locals_at_any_suspension_point,
+        &saved_locals,
         always_live_locals.clone(),
         requires_storage_results,
     );
 
     LivenessInfo {
-        live_locals: live_locals_at_any_suspension_point,
+        saved_locals,
         live_locals_at_suspension_points,
         storage_conflicts,
         storage_liveness: storage_liveness_map,
     }
 }
 
-/// Renumbers the items present in `stored_locals` and applies the renumbering
-/// to 'input`.
+/// The set of `Local`s that must be saved across yield points.
 ///
-/// For example, if `stored_locals = [1, 3, 5]`, this would be renumbered to
-/// `[0, 1, 2]`. Thus, if `input = [3, 5]` we would return `[1, 2]`.
-fn renumber_bitset(
-    input: &BitSet<Local>,
-    stored_locals: &BitSet<Local>,
-) -> BitSet<GeneratorSavedLocal> {
-    assert!(stored_locals.superset(&input), "{:?} not a superset of {:?}", stored_locals, input);
-    let mut out = BitSet::new_empty(stored_locals.count());
-    for (idx, local) in stored_locals.iter().enumerate() {
-        let saved_local = GeneratorSavedLocal::from(idx);
-        if input.contains(local) {
-            out.insert(saved_local);
+/// `GeneratorSavedLocal` is indexed in terms of the elements in this set;
+/// i.e. `GeneratorSavedLocal::new(1)` corresponds to the second local
+/// included in this set.
+struct GeneratorSavedLocals(BitSet<Local>);
+
+impl GeneratorSavedLocals {
+    /// Returns an iterator over each `GeneratorSavedLocal` along with the `Local` it corresponds
+    /// to.
+    fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (GeneratorSavedLocal, Local)> {
+        self.iter().enumerate().map(|(i, l)| (GeneratorSavedLocal::from(i), l))
+    }
+
+    /// Transforms a `BitSet<Local>` that contains only locals saved across yield points to the
+    /// equivalent `BitSet<GeneratorSavedLocal>`.
+    fn renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<GeneratorSavedLocal> {
+        assert!(self.superset(&input), "{:?} not a superset of {:?}", self.0, input);
+        let mut out = BitSet::new_empty(self.count());
+        for (saved_local, local) in self.iter_enumerated() {
+            if input.contains(local) {
+                out.insert(saved_local);
+            }
         }
+        out
+    }
+
+    fn get(&self, local: Local) -> Option<GeneratorSavedLocal> {
+        if !self.contains(local) {
+            return None;
+        }
+
+        let idx = self.iter().take_while(|&l| l < local).count();
+        Some(GeneratorSavedLocal::new(idx))
+    }
+}
+
+impl ops::Deref for GeneratorSavedLocals {
+    type Target = BitSet<Local>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
     }
-    debug!("renumber_bitset({:?}, {:?}) => {:?}", input, stored_locals, out);
-    out
 }
 
 /// For every saved local, looks for which locals are StorageLive at the same
@@ -575,11 +597,11 @@ fn renumber_bitset(
 /// computation; see `GeneratorLayout` for more.
 fn compute_storage_conflicts(
     body: &'mir Body<'tcx>,
-    stored_locals: &BitSet<Local>,
+    saved_locals: &GeneratorSavedLocals,
     always_live_locals: storage::AlwaysLiveLocals,
     requires_storage: dataflow::Results<'tcx, MaybeRequiresStorage<'mir, 'tcx>>,
 ) -> BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal> {
-    assert_eq!(body.local_decls.len(), stored_locals.domain_size());
+    assert_eq!(body.local_decls.len(), saved_locals.domain_size());
 
     debug!("compute_storage_conflicts({:?})", body.span);
     debug!("always_live = {:?}", always_live_locals);
@@ -587,12 +609,12 @@ fn compute_storage_conflicts(
     // Locals that are always live or ones that need to be stored across
     // suspension points are not eligible for overlap.
     let mut ineligible_locals = always_live_locals.into_inner();
-    ineligible_locals.intersect(stored_locals);
+    ineligible_locals.intersect(saved_locals);
 
     // Compute the storage conflicts for all eligible locals.
     let mut visitor = StorageConflictVisitor {
         body,
-        stored_locals: &stored_locals,
+        saved_locals: &saved_locals,
         local_conflicts: BitMatrix::from_row_n(&ineligible_locals, body.local_decls.len()),
     };
 
@@ -609,16 +631,14 @@ fn compute_storage_conflicts(
     // However, in practice these bitsets are not usually large. The layout code
     // also needs to keep track of how many conflicts each local has, so it's
     // simpler to keep it this way for now.
-    let mut storage_conflicts = BitMatrix::new(stored_locals.count(), stored_locals.count());
-    for (idx_a, local_a) in stored_locals.iter().enumerate() {
-        let saved_local_a = GeneratorSavedLocal::new(idx_a);
+    let mut storage_conflicts = BitMatrix::new(saved_locals.count(), saved_locals.count());
+    for (saved_local_a, local_a) in saved_locals.iter_enumerated() {
         if ineligible_locals.contains(local_a) {
             // Conflicts with everything.
             storage_conflicts.insert_all_into_row(saved_local_a);
         } else {
             // Keep overlap information only for stored locals.
-            for (idx_b, local_b) in stored_locals.iter().enumerate() {
-                let saved_local_b = GeneratorSavedLocal::new(idx_b);
+            for (saved_local_b, local_b) in saved_locals.iter_enumerated() {
                 if local_conflicts.contains(local_a, local_b) {
                     storage_conflicts.insert(saved_local_a, saved_local_b);
                 }
@@ -630,7 +650,7 @@ fn compute_storage_conflicts(
 
 struct StorageConflictVisitor<'mir, 'tcx, 's> {
     body: &'mir Body<'tcx>,
-    stored_locals: &'s BitSet<Local>,
+    saved_locals: &'s GeneratorSavedLocals,
     // FIXME(tmandry): Consider using sparse bitsets here once we have good
     // benchmarks for generators.
     local_conflicts: BitMatrix<Local, Local>,
@@ -666,7 +686,7 @@ impl<'body, 'tcx, 's> StorageConflictVisitor<'body, 'tcx, 's> {
         }
 
         let mut eligible_storage_live = flow_state.clone();
-        eligible_storage_live.intersect(&self.stored_locals);
+        eligible_storage_live.intersect(&self.saved_locals);
 
         for local in eligible_storage_live.iter() {
             self.local_conflicts.union_row_with(&eligible_storage_live, local);
@@ -678,7 +698,7 @@ impl<'body, 'tcx, 's> StorageConflictVisitor<'body, 'tcx, 's> {
     }
 }
 
-/// Validates the typeck view of the generator against the actual set of types retained between
+/// Validates the typeck view of the generator against the actual set of types saved between
 /// yield points.
 fn sanitize_witness<'tcx>(
     tcx: TyCtxt<'tcx>,
@@ -686,7 +706,7 @@ fn sanitize_witness<'tcx>(
     did: DefId,
     witness: Ty<'tcx>,
     upvars: &Vec<Ty<'tcx>>,
-    retained: &BitSet<Local>,
+    saved_locals: &GeneratorSavedLocals,
 ) {
     let allowed_upvars = tcx.erase_regions(upvars);
     let allowed = match witness.kind {
@@ -703,8 +723,8 @@ fn sanitize_witness<'tcx>(
     let param_env = tcx.param_env(did);
 
     for (local, decl) in body.local_decls.iter_enumerated() {
-        // Ignore locals which are internal or not retained between yields.
-        if !retained.contains(local) || decl.internal {
+        // Ignore locals which are internal or not saved between yields.
+        if !saved_locals.contains(local) || decl.internal {
             continue;
         }
         let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty);
@@ -724,35 +744,27 @@ fn sanitize_witness<'tcx>(
 }
 
 fn compute_layout<'tcx>(
-    tcx: TyCtxt<'tcx>,
-    source: MirSource<'tcx>,
-    upvars: &Vec<Ty<'tcx>>,
-    interior: Ty<'tcx>,
-    always_live_locals: &storage::AlwaysLiveLocals,
-    movable: bool,
+    liveness: LivenessInfo,
     body: &mut Body<'tcx>,
 ) -> (
     FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
     GeneratorLayout<'tcx>,
     IndexVec<BasicBlock, Option<BitSet<Local>>>,
 ) {
-    // Use a liveness analysis to compute locals which are live across a suspension point
     let LivenessInfo {
-        live_locals,
+        saved_locals,
         live_locals_at_suspension_points,
         storage_conflicts,
         storage_liveness,
-    } = locals_live_across_suspend_points(tcx, body, source, always_live_locals, movable);
-
-    sanitize_witness(tcx, body, source.def_id(), interior, upvars, &live_locals);
+    } = liveness;
 
     // Gather live local types and their indices.
     let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
     let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
-    for (idx, local) in live_locals.iter().enumerate() {
+    for (saved_local, local) in saved_locals.iter_enumerated() {
         locals.push(local);
         tys.push(body.local_decls[local].ty);
-        debug!("generator saved local {:?} => {:?}", GeneratorSavedLocal::from(idx), local);
+        debug!("generator saved local {:?} => {:?}", saved_local, local);
     }
 
     // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states.
@@ -1260,11 +1272,25 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
 
         let always_live_locals = storage::AlwaysLiveLocals::new(&body);
 
+        let liveness_info =
+            locals_live_across_suspend_points(tcx, body, source, &always_live_locals, movable);
+
+        sanitize_witness(tcx, body, def_id, interior, &upvars, &liveness_info.saved_locals);
+
+        if tcx.sess.opts.debugging_opts.validate_mir {
+            let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias {
+                assigned_local: None,
+                saved_locals: &liveness_info.saved_locals,
+                storage_conflicts: &liveness_info.storage_conflicts,
+            };
+
+            vis.visit_body(body);
+        }
+
         // Extract locals which are live across suspension point into `layout`
         // `remap` gives a mapping from local indices onto generator struct indices
         // `storage_liveness` tells us which locals have live storage at suspension points
-        let (remap, layout, storage_liveness) =
-            compute_layout(tcx, source, &upvars, interior, &always_live_locals, movable, body);
+        let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
 
         let can_return = can_return(tcx, body);
 
@@ -1315,3 +1341,134 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         create_generator_resume_function(tcx, transform, source, body, can_return);
     }
 }
+
+/// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields
+/// in the generator state machine but whose storage is not marked as conflicting
+///
+/// Validation needs to happen immediately *before* `TransformVisitor` is invoked, not after.
+///
+/// This condition would arise when the assignment is the last use of `_5` but the initial
+/// definition of `_4` if we weren't extra careful to mark all locals used inside a statement as
+/// conflicting. Non-conflicting generator saved locals may be stored at the same location within
+/// the generator state machine, which would result in ill-formed MIR: the left-hand and right-hand
+/// sides of an assignment may not alias. This caused a miscompilation in [#73137].
+///
+/// [#73137]: https://github.com/rust-lang/rust/issues/73137
+struct EnsureGeneratorFieldAssignmentsNeverAlias<'a> {
+    saved_locals: &'a GeneratorSavedLocals,
+    storage_conflicts: &'a BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>,
+    assigned_local: Option<GeneratorSavedLocal>,
+}
+
+impl EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
+    fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<GeneratorSavedLocal> {
+        if place.is_indirect() {
+            return None;
+        }
+
+        self.saved_locals.get(place.local)
+    }
+
+    fn check_assigned_place(&mut self, place: Place<'tcx>, f: impl FnOnce(&mut Self)) {
+        if let Some(assigned_local) = self.saved_local_for_direct_place(place) {
+            assert!(self.assigned_local.is_none(), "`check_assigned_place` must not recurse");
+
+            self.assigned_local = Some(assigned_local);
+            f(self);
+            self.assigned_local = None;
+        }
+    }
+}
+
+impl Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
+    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
+        let lhs = match self.assigned_local {
+            Some(l) => l,
+            None => {
+                // This visitor only invokes `visit_place` for the right-hand side of an assignment
+                // and only after setting `self.assigned_local`. However, the default impl of
+                // `Visitor::super_body` may call `visit_place` with a `NonUseContext` for places
+                // with debuginfo. Ignore them here.
+                assert!(!context.is_use());
+                return;
+            }
+        };
+
+        let rhs = match self.saved_local_for_direct_place(*place) {
+            Some(l) => l,
+            None => return,
+        };
+
+        if !self.storage_conflicts.contains(lhs, rhs) {
+            bug!(
+                "Assignment between generator saved locals whose storage is not \
+                    marked as conflicting: {:?}: {:?} = {:?}",
+                location,
+                lhs,
+                rhs,
+            );
+        }
+    }
+
+    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+        match &statement.kind {
+            StatementKind::Assign(box (lhs, rhs)) => {
+                self.check_assigned_place(*lhs, |this| this.visit_rvalue(rhs, location));
+            }
+
+            // FIXME: Does `llvm_asm!` have any aliasing requirements?
+            StatementKind::LlvmInlineAsm(_) => {}
+
+            StatementKind::FakeRead(..)
+            | StatementKind::SetDiscriminant { .. }
+            | StatementKind::StorageLive(_)
+            | StatementKind::StorageDead(_)
+            | StatementKind::Retag(..)
+            | StatementKind::AscribeUserType(..)
+            | StatementKind::Nop => {}
+        }
+    }
+
+    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
+        // Checking for aliasing in terminators is probably overkill, but until we have actual
+        // semantics, we should be conservative here.
+        match &terminator.kind {
+            TerminatorKind::Call {
+                func,
+                args,
+                destination: Some((dest, _)),
+                cleanup: _,
+                from_hir_call: _,
+                fn_span: _,
+            } => {
+                self.check_assigned_place(*dest, |this| {
+                    this.visit_operand(func, location);
+                    for arg in args {
+                        this.visit_operand(arg, location);
+                    }
+                });
+            }
+
+            TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => {
+                self.check_assigned_place(*resume_arg, |this| this.visit_operand(value, location));
+            }
+
+            // FIXME: Does `asm!` have any aliasing requirements?
+            TerminatorKind::InlineAsm { .. } => {}
+
+            TerminatorKind::Call { .. }
+            | TerminatorKind::Goto { .. }
+            | TerminatorKind::SwitchInt { .. }
+            | TerminatorKind::Resume
+            | TerminatorKind::Abort
+            | TerminatorKind::Return
+            | TerminatorKind::Unreachable
+            | TerminatorKind::Drop { .. }
+            | TerminatorKind::DropAndReplace { .. }
+            | TerminatorKind::Assert { .. }
+            | TerminatorKind::GeneratorDrop
+            | TerminatorKind::FalseEdge { .. }
+            | TerminatorKind::FalseUnwind { .. } => {}
+        }
+    }
+}