about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/copy_prop.rs34
-rw-r--r--compiler/rustc_mir_transform/src/ssa.rs23
2 files changed, 30 insertions, 27 deletions
diff --git a/compiler/rustc_mir_transform/src/copy_prop.rs b/compiler/rustc_mir_transform/src/copy_prop.rs
index 27af5818982..fe78a104fa0 100644
--- a/compiler/rustc_mir_transform/src/copy_prop.rs
+++ b/compiler/rustc_mir_transform/src/copy_prop.rs
@@ -30,6 +30,8 @@ impl<'tcx> crate::MirPass<'tcx> for CopyProp {
 
         let typing_env = body.typing_env(tcx);
         let ssa = SsaLocals::new(tcx, body, typing_env);
+        debug!(borrowed_locals = ?ssa.borrowed_locals());
+        debug!(copy_classes = ?ssa.copy_classes());
 
         let fully_moved = fully_moved_locals(&ssa, body);
         debug!(?fully_moved);
@@ -43,14 +45,8 @@ impl<'tcx> crate::MirPass<'tcx> for CopyProp {
 
         let any_replacement = ssa.copy_classes().iter_enumerated().any(|(l, &h)| l != h);
 
-        Replacer {
-            tcx,
-            copy_classes: ssa.copy_classes(),
-            fully_moved,
-            borrowed_locals: ssa.borrowed_locals(),
-            storage_to_remove,
-        }
-        .visit_body_preserves_cfg(body);
+        Replacer { tcx, copy_classes: ssa.copy_classes(), fully_moved, storage_to_remove }
+            .visit_body_preserves_cfg(body);
 
         if any_replacement {
             crate::simplify::remove_unused_definitions(body);
@@ -102,7 +98,6 @@ struct Replacer<'a, 'tcx> {
     tcx: TyCtxt<'tcx>,
     fully_moved: DenseBitSet<Local>,
     storage_to_remove: DenseBitSet<Local>,
-    borrowed_locals: &'a DenseBitSet<Local>,
     copy_classes: &'a IndexSlice<Local, Local>,
 }
 
@@ -111,34 +106,18 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> {
         self.tcx
     }
 
+    #[tracing::instrument(level = "trace", skip(self))]
     fn visit_local(&mut self, local: &mut Local, ctxt: PlaceContext, _: Location) {
         let new_local = self.copy_classes[*local];
-        // We must not unify two locals that are borrowed. But this is fine if one is borrowed and
-        // the other is not. We chose to check the original local, and not the target. That way, if
-        // the original local is borrowed and the target is not, we do not pessimize the whole class.
-        if self.borrowed_locals.contains(*local) {
-            return;
-        }
         match ctxt {
             // Do not modify the local in storage statements.
             PlaceContext::NonUse(NonUseContext::StorageLive | NonUseContext::StorageDead) => {}
-            // The local should have been marked as non-SSA.
-            PlaceContext::MutatingUse(_) => assert_eq!(*local, new_local),
             // We access the value.
             _ => *local = new_local,
         }
     }
 
-    fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, loc: Location) {
-        if let Some(new_projection) = self.process_projection(place.projection, loc) {
-            place.projection = self.tcx().mk_place_elems(&new_projection);
-        }
-
-        // Any non-mutating use context is ok.
-        let ctxt = PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy);
-        self.visit_local(&mut place.local, ctxt, loc)
-    }
-
+    #[tracing::instrument(level = "trace", skip(self))]
     fn visit_operand(&mut self, operand: &mut Operand<'tcx>, loc: Location) {
         if let Operand::Move(place) = *operand
             // A move out of a projection of a copy is equivalent to a copy of the original
@@ -151,6 +130,7 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> {
         self.super_operand(operand, loc);
     }
 
+    #[tracing::instrument(level = "trace", skip(self))]
     fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, loc: Location) {
         // When removing storage statements, we need to remove both (#107511).
         if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = stmt.kind
diff --git a/compiler/rustc_mir_transform/src/ssa.rs b/compiler/rustc_mir_transform/src/ssa.rs
index edd0cabca49..03b6f9b7ff3 100644
--- a/compiler/rustc_mir_transform/src/ssa.rs
+++ b/compiler/rustc_mir_transform/src/ssa.rs
@@ -293,6 +293,10 @@ impl<'tcx> Visitor<'tcx> for SsaVisitor<'_, 'tcx> {
 fn compute_copy_classes(ssa: &mut SsaLocals, body: &Body<'_>) {
     let mut direct_uses = std::mem::take(&mut ssa.direct_uses);
     let mut copies = IndexVec::from_fn_n(|l| l, body.local_decls.len());
+    // We must not unify two locals that are borrowed. But this is fine if one is borrowed and
+    // the other is not. This bitset is keyed by *class head* and contains whether any member of
+    // the class is borrowed.
+    let mut borrowed_classes = ssa.borrowed_locals().clone();
 
     for (local, rvalue, _) in ssa.assignments(body) {
         let (Rvalue::Use(Operand::Copy(place) | Operand::Move(place))
@@ -318,6 +322,11 @@ fn compute_copy_classes(ssa: &mut SsaLocals, body: &Body<'_>) {
         // visited before `local`, and we just have to copy the representing local.
         let head = copies[rhs];
 
+        // Do not unify two borrowed locals.
+        if borrowed_classes.contains(local) && borrowed_classes.contains(head) {
+            continue;
+        }
+
         if local == RETURN_PLACE {
             // `_0` is special, we cannot rename it. Instead, rename the class of `rhs` to
             // `RETURN_PLACE`. This is only possible if the class head is a temporary, not an
@@ -330,14 +339,21 @@ fn compute_copy_classes(ssa: &mut SsaLocals, body: &Body<'_>) {
                     *h = RETURN_PLACE;
                 }
             }
+            if borrowed_classes.contains(head) {
+                borrowed_classes.insert(RETURN_PLACE);
+            }
         } else {
             copies[local] = head;
+            if borrowed_classes.contains(local) {
+                borrowed_classes.insert(head);
+            }
         }
         direct_uses[rhs] -= 1;
     }
 
     debug!(?copies);
     debug!(?direct_uses);
+    debug!(?borrowed_classes);
 
     // Invariant: `copies` must point to the head of an equivalence class.
     #[cfg(debug_assertions)]
@@ -346,6 +362,13 @@ fn compute_copy_classes(ssa: &mut SsaLocals, body: &Body<'_>) {
     }
     debug_assert_eq!(copies[RETURN_PLACE], RETURN_PLACE);
 
+    // Invariant: `borrowed_classes` must be true if any member of the class is borrowed.
+    #[cfg(debug_assertions)]
+    for &head in copies.iter() {
+        let any_borrowed = ssa.borrowed_locals.iter().any(|l| copies[l] == head);
+        assert_eq!(borrowed_classes.contains(head), any_borrowed);
+    }
+
     ssa.direct_uses = direct_uses;
     ssa.copy_classes = copies;
 }