about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-02-07 09:34:17 +0000
committerbors <bors@rust-lang.org>2023-02-07 09:34:17 +0000
commit0d225bcf1eb270564297c1f094d54a685a5ed08e (patch)
tree2e1b2e5e0a66252af9572ab6d810e41b9fa84b1f /compiler/rustc_mir_transform/src
parent6c8d4073a4490671b05f3b5f5e6810cfdc23a641 (diff)
parent8461c0eb673e769cae656ce52dd662829437e200 (diff)
downloadrust-0d225bcf1eb270564297c1f094d54a685a5ed08e.tar.gz
rust-0d225bcf1eb270564297c1f094d54a685a5ed08e.zip
Auto merge of #2780 - RalfJung:rustup, r=RalfJung
Rustup
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/check_unsafety.rs1
-rw-r--r--compiler/rustc_mir_transform/src/copy_prop.rs25
-rw-r--r--compiler/rustc_mir_transform/src/dataflow_const_prop.rs25
-rw-r--r--compiler/rustc_mir_transform/src/deaggregator.rs45
-rw-r--r--compiler/rustc_mir_transform/src/generator.rs30
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs4
-rw-r--r--compiler/rustc_mir_transform/src/shim.rs25
-rw-r--r--compiler/rustc_mir_transform/src/sroa.rs431
8 files changed, 319 insertions, 267 deletions
diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs
index 8afa53313fc..d00ee1f4bab 100644
--- a/compiler/rustc_mir_transform/src/check_unsafety.rs
+++ b/compiler/rustc_mir_transform/src/check_unsafety.rs
@@ -126,6 +126,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> {
                     }
                 }
                 &AggregateKind::Closure(def_id, _) | &AggregateKind::Generator(def_id, _, _) => {
+                    let def_id = def_id.expect_local();
                     let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } =
                         self.tcx.unsafety_check_result(def_id);
                     self.register_violations(violations, used_unsafe_blocks.iter().copied());
diff --git a/compiler/rustc_mir_transform/src/copy_prop.rs b/compiler/rustc_mir_transform/src/copy_prop.rs
index 182b3015dd7..4c7d45be075 100644
--- a/compiler/rustc_mir_transform/src/copy_prop.rs
+++ b/compiler/rustc_mir_transform/src/copy_prop.rs
@@ -162,17 +162,20 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> {
     }
 
     fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, loc: Location) {
-        if let StatementKind::StorageDead(l) = stmt.kind
-            && self.storage_to_remove.contains(l)
-        {
-            stmt.make_nop();
-        } else if let StatementKind::Assign(box (ref place, ref mut rvalue)) = stmt.kind
-            && place.as_local().is_some()
-        {
-            // Do not replace assignments.
-            self.visit_rvalue(rvalue, loc)
-        } else {
-            self.super_statement(stmt, loc);
+        match stmt.kind {
+            // When removing storage statements, we need to remove both (#107511).
+            StatementKind::StorageLive(l) | StatementKind::StorageDead(l)
+                if self.storage_to_remove.contains(l) =>
+            {
+                stmt.make_nop()
+            }
+            StatementKind::Assign(box (ref place, ref mut rvalue))
+                if place.as_local().is_some() =>
+            {
+                // Do not replace assignments.
+                self.visit_rvalue(rvalue, loc)
+            }
+            _ => self.super_statement(stmt, loc),
         }
     }
 }
diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
index c75fe2327de..949a59a97bf 100644
--- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
+++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
@@ -5,6 +5,7 @@
 use rustc_const_eval::const_eval::CheckAlignment;
 use rustc_const_eval::interpret::{ConstValue, ImmTy, Immediate, InterpCx, Scalar};
 use rustc_data_structures::fx::FxHashMap;
+use rustc_hir::def::DefKind;
 use rustc_middle::mir::visit::{MutVisitor, Visitor};
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -85,6 +86,30 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
         state: &mut State<Self::Value>,
     ) {
         match rvalue {
+            Rvalue::Aggregate(kind, operands) => {
+                let target = self.map().find(target.as_ref());
+                if let Some(target) = target {
+                    state.flood_idx_with(target, self.map(), FlatSet::Bottom);
+                    let field_based = match **kind {
+                        AggregateKind::Tuple | AggregateKind::Closure(..) => true,
+                        AggregateKind::Adt(def_id, ..) => {
+                            matches!(self.tcx.def_kind(def_id), DefKind::Struct)
+                        }
+                        _ => false,
+                    };
+                    if field_based {
+                        for (field_index, operand) in operands.iter().enumerate() {
+                            if let Some(field) = self
+                                .map()
+                                .apply(target, TrackElem::Field(Field::from_usize(field_index)))
+                            {
+                                let result = self.handle_operand(operand, state);
+                                state.assign_idx(field, result, self.map());
+                            }
+                        }
+                    }
+                }
+            }
             Rvalue::CheckedBinaryOp(op, box (left, right)) => {
                 let target = self.map().find(target.as_ref());
                 if let Some(target) = target {
diff --git a/compiler/rustc_mir_transform/src/deaggregator.rs b/compiler/rustc_mir_transform/src/deaggregator.rs
deleted file mode 100644
index fe272de20f8..00000000000
--- a/compiler/rustc_mir_transform/src/deaggregator.rs
+++ /dev/null
@@ -1,45 +0,0 @@
-use crate::util::expand_aggregate;
-use crate::MirPass;
-use rustc_middle::mir::*;
-use rustc_middle::ty::TyCtxt;
-
-pub struct Deaggregator;
-
-impl<'tcx> MirPass<'tcx> for Deaggregator {
-    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
-        let basic_blocks = body.basic_blocks.as_mut_preserves_cfg();
-        for bb in basic_blocks {
-            bb.expand_statements(|stmt| {
-                // FIXME(eddyb) don't match twice on `stmt.kind` (post-NLL).
-                match stmt.kind {
-                    // FIXME(#48193) Deaggregate arrays when it's cheaper to do so.
-                    StatementKind::Assign(box (
-                        _,
-                        Rvalue::Aggregate(box AggregateKind::Array(_), _),
-                    )) => {
-                        return None;
-                    }
-                    StatementKind::Assign(box (_, Rvalue::Aggregate(_, _))) => {}
-                    _ => return None,
-                }
-
-                let stmt = stmt.replace_nop();
-                let source_info = stmt.source_info;
-                let StatementKind::Assign(box (lhs, Rvalue::Aggregate(kind, operands))) = stmt.kind else {
-                    bug!();
-                };
-
-                Some(expand_aggregate(
-                    lhs,
-                    operands.into_iter().map(|op| {
-                        let ty = op.ty(&body.local_decls, tcx);
-                        (op, ty)
-                    }),
-                    *kind,
-                    source_info,
-                    tcx,
-                ))
-            });
-        }
-    }
-}
diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs
index 5624e312da1..47f9d35a4f7 100644
--- a/compiler/rustc_mir_transform/src/generator.rs
+++ b/compiler/rustc_mir_transform/src/generator.rs
@@ -52,7 +52,6 @@
 
 use crate::deref_separator::deref_finder;
 use crate::simplify;
-use crate::util::expand_aggregate;
 use crate::MirPass;
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_errors::pluralize;
@@ -272,31 +271,26 @@ impl<'tcx> TransformVisitor<'tcx> {
             assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
 
             // FIXME(swatinem): assert that `val` is indeed unit?
-            statements.extend(expand_aggregate(
-                Place::return_place(),
-                std::iter::empty(),
-                kind,
+            statements.push(Statement {
+                kind: StatementKind::Assign(Box::new((
+                    Place::return_place(),
+                    Rvalue::Aggregate(Box::new(kind), vec![]),
+                ))),
                 source_info,
-                self.tcx,
-            ));
+            });
             return;
         }
 
         // else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)`
         assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
 
-        let ty = self
-            .tcx
-            .bound_type_of(self.state_adt_ref.variant(idx).fields[0].did)
-            .subst(self.tcx, self.state_substs);
-
-        statements.extend(expand_aggregate(
-            Place::return_place(),
-            std::iter::once((val, ty)),
-            kind,
+        statements.push(Statement {
+            kind: StatementKind::Assign(Box::new((
+                Place::return_place(),
+                Rvalue::Aggregate(Box::new(kind), vec![val]),
+            ))),
             source_info,
-            self.tcx,
-        ));
+        });
     }
 
     // Create a Place referencing a generator struct field
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 6815289776e..9070a7368b1 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -60,7 +60,6 @@ mod coverage;
 mod ctfe_limit;
 mod dataflow_const_prop;
 mod dead_store_elimination;
-mod deaggregator;
 mod deduce_param_attrs;
 mod deduplicate_blocks;
 mod deref_separator;
@@ -523,9 +522,6 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         &elaborate_box_derefs::ElaborateBoxDerefs,
         &generator::StateTransform,
         &add_retag::AddRetag,
-        // Deaggregator is necessary for const prop. We may want to consider implementing
-        // CTFE support for aggregates.
-        &deaggregator::Deaggregator,
         &Lint(const_prop_lint::ConstProp),
     ];
     pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial)));
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index e9ca6f7c93c..551422386f6 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -15,7 +15,6 @@ use rustc_target::spec::abi::Abi;
 use std::fmt;
 use std::iter;
 
-use crate::util::expand_aggregate;
 use crate::{
     abort_unwinding_calls, add_call_guards, add_moves_for_packed_drops, deref_separator,
     pass_manager as pm, remove_noop_landing_pads, simplify,
@@ -831,19 +830,23 @@ pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> {
     // return;
     debug!("build_ctor: variant_index={:?}", variant_index);
 
-    let statements = expand_aggregate(
-        Place::return_place(),
-        adt_def.variant(variant_index).fields.iter().enumerate().map(|(idx, field_def)| {
-            (Operand::Move(Place::from(Local::new(idx + 1))), field_def.ty(tcx, substs))
-        }),
-        AggregateKind::Adt(adt_def.did(), variant_index, substs, None, None),
+    let kind = AggregateKind::Adt(adt_def.did(), variant_index, substs, None, None);
+    let variant = adt_def.variant(variant_index);
+    let statement = Statement {
+        kind: StatementKind::Assign(Box::new((
+            Place::return_place(),
+            Rvalue::Aggregate(
+                Box::new(kind),
+                (0..variant.fields.len())
+                    .map(|idx| Operand::Move(Place::from(Local::new(idx + 1))))
+                    .collect(),
+            ),
+        ))),
         source_info,
-        tcx,
-    )
-    .collect();
+    };
 
     let start_block = BasicBlockData {
-        statements,
+        statements: vec![statement],
         terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
         is_cleanup: false,
     };
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
index 42124f5a480..26acd406ed8 100644
--- a/compiler/rustc_mir_transform/src/sroa.rs
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -1,10 +1,11 @@
 use crate::MirPass;
-use rustc_data_structures::fx::{FxIndexMap, IndexEntry};
 use rustc_index::bit_set::BitSet;
 use rustc_index::vec::IndexVec;
+use rustc_middle::mir::patch::MirPatch;
 use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
-use rustc_middle::ty::TyCtxt;
+use rustc_middle::ty::{Ty, TyCtxt};
+use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
 
 pub struct ScalarReplacementOfAggregates;
 
@@ -13,27 +14,41 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
         sess.mir_opt_level() >= 3
     }
 
+    #[instrument(level = "debug", skip(self, tcx, body))]
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
-        let escaping = escaping_locals(&*body);
-        debug!(?escaping);
-        let replacements = compute_flattening(tcx, body, escaping);
-        debug!(?replacements);
-        replace_flattened_locals(tcx, body, replacements);
+        debug!(def_id = ?body.source.def_id());
+        let mut excluded = excluded_locals(body);
+        loop {
+            debug!(?excluded);
+            let escaping = escaping_locals(&excluded, body);
+            debug!(?escaping);
+            let replacements = compute_flattening(tcx, body, escaping);
+            debug!(?replacements);
+            let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
+            if !all_dead_locals.is_empty() {
+                for local in excluded.indices() {
+                    excluded[local] |= all_dead_locals.contains(local);
+                }
+                excluded.raw.resize(body.local_decls.len(), false);
+            } else {
+                break;
+            }
+        }
     }
 }
 
 /// Identify all locals that are not eligible for SROA.
 ///
 /// There are 3 cases:
-/// - the aggegated local is used or passed to other code (function parameters and arguments);
+/// - the aggregated local is used or passed to other code (function parameters and arguments);
 /// - the locals is a union or an enum;
 /// - the local's address is taken, and thus the relative addresses of the fields are observable to
 ///   client code.
-fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
+fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<Local> {
     let mut set = BitSet::new_empty(body.local_decls.len());
     set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
     for (local, decl) in body.local_decls().iter_enumerated() {
-        if decl.ty.is_union() || decl.ty.is_enum() {
+        if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] {
             set.insert(local);
         }
     }
@@ -58,41 +73,33 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
             self.super_place(place, context, location);
         }
 
-        fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
-            if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue {
-                if !place.is_indirect() {
-                    // Raw pointers may be used to access anything inside the enclosing place.
-                    self.set.insert(place.local);
-                    return;
+        fn visit_assign(
+            &mut self,
+            lvalue: &Place<'tcx>,
+            rvalue: &Rvalue<'tcx>,
+            location: Location,
+        ) {
+            if lvalue.as_local().is_some() {
+                match rvalue {
+                    // Aggregate assignments are expanded in run_pass.
+                    Rvalue::Aggregate(..) | Rvalue::Use(..) => {
+                        self.visit_rvalue(rvalue, location);
+                        return;
+                    }
+                    _ => {}
                 }
             }
-            self.super_rvalue(rvalue, location)
+            self.super_assign(lvalue, rvalue, location)
         }
 
         fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
-            if let StatementKind::StorageLive(..)
-            | StatementKind::StorageDead(..)
-            | StatementKind::Deinit(..) = statement.kind
-            {
+            match statement.kind {
                 // Storage statements are expanded in run_pass.
-                return;
+                StatementKind::StorageLive(..)
+                | StatementKind::StorageDead(..)
+                | StatementKind::Deinit(..) => return,
+                _ => self.super_statement(statement, location),
             }
-            self.super_statement(statement, location)
-        }
-
-        fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
-            // Drop implicitly calls `drop_in_place`, which takes a `&mut`.
-            // This implies that `Drop` implicitly takes the address of the place.
-            if let TerminatorKind::Drop { place, .. }
-            | TerminatorKind::DropAndReplace { place, .. } = terminator.kind
-            {
-                if !place.is_indirect() {
-                    // Raw pointers may be used to access anything inside the enclosing place.
-                    self.set.insert(place.local);
-                    return;
-                }
-            }
-            self.super_terminator(terminator, location);
         }
 
         // We ignore anything that happens in debuginfo, since we expand it using
@@ -103,7 +110,30 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
 
 #[derive(Default, Debug)]
 struct ReplacementMap<'tcx> {
-    fields: FxIndexMap<PlaceRef<'tcx>, Local>,
+    /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
+    /// and deinit statement and debuginfo.
+    fragments: IndexVec<Local, Option<IndexVec<Field, Option<(Ty<'tcx>, Local)>>>>,
+}
+
+impl<'tcx> ReplacementMap<'tcx> {
+    fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
+        let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { return None; };
+        let fields = self.fragments[place.local].as_ref()?;
+        let (_, new_local) = fields[f]?;
+        Some(Place { local: new_local, projection: tcx.intern_place_elems(&rest) })
+    }
+
+    fn place_fragments(
+        &self,
+        place: Place<'tcx>,
+    ) -> Option<impl Iterator<Item = (Field, Ty<'tcx>, Local)> + '_> {
+        let local = place.as_local()?;
+        let fields = self.fragments[local].as_ref()?;
+        Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
+            let (ty, local) = opt_ty_local?;
+            Some((field, ty, local))
+        }))
+    }
 }
 
 /// Compute the replacement of flattened places into locals.
@@ -115,53 +145,25 @@ fn compute_flattening<'tcx>(
     body: &mut Body<'tcx>,
     escaping: BitSet<Local>,
 ) -> ReplacementMap<'tcx> {
-    let mut visitor = PreFlattenVisitor {
-        tcx,
-        escaping,
-        local_decls: &mut body.local_decls,
-        map: Default::default(),
-    };
-    for (block, bbdata) in body.basic_blocks.iter_enumerated() {
-        visitor.visit_basic_block_data(block, bbdata);
-    }
-    return visitor.map;
-
-    struct PreFlattenVisitor<'tcx, 'll> {
-        tcx: TyCtxt<'tcx>,
-        local_decls: &'ll mut LocalDecls<'tcx>,
-        escaping: BitSet<Local>,
-        map: ReplacementMap<'tcx>,
-    }
-
-    impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> {
-        fn create_place(&mut self, place: PlaceRef<'tcx>) {
-            if self.escaping.contains(place.local) {
-                return;
-            }
+    let mut fragments = IndexVec::from_elem(None, &body.local_decls);
 
-            match self.map.fields.entry(place) {
-                IndexEntry::Occupied(_) => {}
-                IndexEntry::Vacant(v) => {
-                    let ty = place.ty(&*self.local_decls, self.tcx).ty;
-                    let local = self.local_decls.push(LocalDecl {
-                        ty,
-                        user_ty: None,
-                        ..self.local_decls[place.local].clone()
-                    });
-                    v.insert(local);
-                }
-            }
-        }
-    }
-
-    impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> {
-        fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) {
-            if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
-                let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
-                self.create_place(pr)
-            }
+    for local in body.local_decls.indices() {
+        if escaping.contains(local) {
+            continue;
         }
+        let decl = body.local_decls[local].clone();
+        let ty = decl.ty;
+        iter_fields(ty, tcx, |variant, field, field_ty| {
+            if variant.is_some() {
+                // Downcasts are currently not supported.
+                return;
+            };
+            let new_local =
+                body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
+            fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
+        });
     }
+    ReplacementMap { fragments }
 }
 
 /// Perform the replacement computed by `compute_flattening`.
@@ -169,29 +171,24 @@ fn replace_flattened_locals<'tcx>(
     tcx: TyCtxt<'tcx>,
     body: &mut Body<'tcx>,
     replacements: ReplacementMap<'tcx>,
-) {
+) -> BitSet<Local> {
     let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
-    for p in replacements.fields.keys() {
-        all_dead_locals.insert(p.local);
+    for (local, replacements) in replacements.fragments.iter_enumerated() {
+        if replacements.is_some() {
+            all_dead_locals.insert(local);
+        }
     }
     debug!(?all_dead_locals);
     if all_dead_locals.is_empty() {
-        return;
+        return all_dead_locals;
     }
 
-    let mut fragments = IndexVec::new();
-    for (k, v) in &replacements.fields {
-        fragments.ensure_contains_elem(k.local, || Vec::new());
-        fragments[k.local].push((k.projection, *v));
-    }
-    debug!(?fragments);
-
     let mut visitor = ReplacementVisitor {
         tcx,
         local_decls: &body.local_decls,
-        replacements,
+        replacements: &replacements,
         all_dead_locals,
-        fragments,
+        patch: MirPatch::new(body),
     };
     for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
         visitor.visit_basic_block_data(bb, data);
@@ -205,6 +202,9 @@ fn replace_flattened_locals<'tcx>(
     for var_debug_info in &mut body.var_debug_info {
         visitor.visit_var_debug_info(var_debug_info);
     }
+    let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
+    patch.apply(body);
+    all_dead_locals
 }
 
 struct ReplacementVisitor<'tcx, 'll> {
@@ -212,40 +212,23 @@ struct ReplacementVisitor<'tcx, 'll> {
     /// This is only used to compute the type for `VarDebugInfoContents::Composite`.
     local_decls: &'ll LocalDecls<'tcx>,
     /// Work to do.
-    replacements: ReplacementMap<'tcx>,
+    replacements: &'ll ReplacementMap<'tcx>,
     /// This is used to check that we are not leaving references to replaced locals behind.
     all_dead_locals: BitSet<Local>,
-    /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
-    /// and deinit statement and debuginfo.
-    fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
+    patch: MirPatch<'tcx>,
 }
 
-impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
-    fn gather_debug_info_fragments(
-        &self,
-        place: PlaceRef<'tcx>,
-    ) -> Vec<VarDebugInfoFragment<'tcx>> {
+impl<'tcx> ReplacementVisitor<'tcx, '_> {
+    fn gather_debug_info_fragments(&self, local: Local) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
         let mut fragments = Vec::new();
-        let parts = &self.fragments[place.local];
-        for (proj, replacement_local) in parts {
-            if proj.starts_with(place.projection) {
-                fragments.push(VarDebugInfoFragment {
-                    projection: proj[place.projection.len()..].to_vec(),
-                    contents: Place::from(*replacement_local),
-                });
-            }
-        }
-        fragments
-    }
-
-    fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
-        if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection {
-            let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
-            let local = self.replacements.fields.get(&pr)?;
-            Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) })
-        } else {
-            None
+        let parts = self.replacements.place_fragments(local.into())?;
+        for (field, ty, replacement_local) in parts {
+            fragments.push(VarDebugInfoFragment {
+                projection: vec![PlaceElem::Field(field, ty)],
+                contents: Place::from(replacement_local),
+            });
         }
+        Some(fragments)
     }
 }
 
@@ -254,94 +237,186 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
         self.tcx
     }
 
-    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
-        if let StatementKind::StorageLive(..)
-        | StatementKind::StorageDead(..)
-        | StatementKind::Deinit(..) = statement.kind
-        {
-            // Storage statements are expanded in run_pass.
-            return;
-        }
-        self.super_statement(statement, location)
-    }
-
     fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
-        if let Some(repl) = self.replace_place(place.as_ref()) {
+        if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
             *place = repl
         } else {
             self.super_place(place, context, location)
         }
     }
 
+    #[instrument(level = "trace", skip(self))]
+    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
+        match statement.kind {
+            // Duplicate storage and deinit statements, as they pretty much apply to all fields.
+            StatementKind::StorageLive(l) => {
+                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
+                    for (_, _, fl) in final_locals {
+                        self.patch.add_statement(location, StatementKind::StorageLive(fl));
+                    }
+                    statement.make_nop();
+                }
+                return;
+            }
+            StatementKind::StorageDead(l) => {
+                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
+                    for (_, _, fl) in final_locals {
+                        self.patch.add_statement(location, StatementKind::StorageDead(fl));
+                    }
+                    statement.make_nop();
+                }
+                return;
+            }
+            StatementKind::Deinit(box place) => {
+                if let Some(final_locals) = self.replacements.place_fragments(place) {
+                    for (_, _, fl) in final_locals {
+                        self.patch
+                            .add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
+                    }
+                    statement.make_nop();
+                    return;
+                }
+            }
+
+            // We have `a = Struct { 0: x, 1: y, .. }`.
+            // We replace it by
+            // ```
+            // a_0 = x
+            // a_1 = y
+            // ...
+            // ```
+            StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
+                if let Some(local) = place.as_local()
+                    && let Some(final_locals) = &self.replacements.fragments[local]
+                {
+                    // This is ok as we delete the statement later.
+                    let operands = std::mem::take(operands);
+                    for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
+                        if let Some((_, new_local)) = opt_ty_local {
+                            // Replace mentions of SROA'd locals that appear in the operand.
+                            self.visit_operand(&mut operand, location);
+
+                            let rvalue = Rvalue::Use(operand);
+                            self.patch.add_statement(
+                                location,
+                                StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+                            );
+                        }
+                    }
+                    statement.make_nop();
+                    return;
+                }
+            }
+
+            // We have `a = some constant`
+            // We add the projections.
+            // ```
+            // a_0 = a.0
+            // a_1 = a.1
+            // ...
+            // ```
+            // ConstProp will pick up the pieces and replace them by actual constants.
+            StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
+                if let Some(final_locals) = self.replacements.place_fragments(place) {
+                    for (field, ty, new_local) in final_locals {
+                        let rplace = self.tcx.mk_place_field(place, field, ty);
+                        let rvalue = Rvalue::Use(Operand::Move(rplace));
+                        self.patch.add_statement(
+                            location,
+                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+                        );
+                    }
+                    // We still need `place.local` to exist, so don't make it nop.
+                    return;
+                }
+            }
+
+            // We have `a = move? place`
+            // We replace it by
+            // ```
+            // a_0 = move? place.0
+            // a_1 = move? place.1
+            // ...
+            // ```
+            StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
+                let (rplace, copy) = match *op {
+                    Operand::Copy(rplace) => (rplace, true),
+                    Operand::Move(rplace) => (rplace, false),
+                    Operand::Constant(_) => bug!(),
+                };
+                if let Some(final_locals) = self.replacements.place_fragments(lhs) {
+                    for (field, ty, new_local) in final_locals {
+                        let rplace = self.tcx.mk_place_field(rplace, field, ty);
+                        debug!(?rplace);
+                        let rplace = self
+                            .replacements
+                            .replace_place(self.tcx, rplace.as_ref())
+                            .unwrap_or(rplace);
+                        debug!(?rplace);
+                        let rvalue = if copy {
+                            Rvalue::Use(Operand::Copy(rplace))
+                        } else {
+                            Rvalue::Use(Operand::Move(rplace))
+                        };
+                        self.patch.add_statement(
+                            location,
+                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+                        );
+                    }
+                    statement.make_nop();
+                    return;
+                }
+            }
+
+            _ => {}
+        }
+        self.super_statement(statement, location)
+    }
+
+    #[instrument(level = "trace", skip(self))]
     fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
         match &mut var_debug_info.value {
             VarDebugInfoContents::Place(ref mut place) => {
-                if let Some(repl) = self.replace_place(place.as_ref()) {
+                if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
                     *place = repl;
-                } else if self.all_dead_locals.contains(place.local) {
+                } else if let Some(local) = place.as_local()
+                    && let Some(fragments) = self.gather_debug_info_fragments(local)
+                {
                     let ty = place.ty(self.local_decls, self.tcx).ty;
-                    let fragments = self.gather_debug_info_fragments(place.as_ref());
                     var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
                 }
             }
             VarDebugInfoContents::Composite { ty: _, ref mut fragments } => {
                 let mut new_fragments = Vec::new();
+                debug!(?fragments);
                 fragments
                     .drain_filter(|fragment| {
-                        if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
+                        if let Some(repl) =
+                            self.replacements.replace_place(self.tcx, fragment.contents.as_ref())
+                        {
                             fragment.contents = repl;
-                            true
-                        } else if self.all_dead_locals.contains(fragment.contents.local) {
-                            let frg = self.gather_debug_info_fragments(fragment.contents.as_ref());
+                            false
+                        } else if let Some(local) = fragment.contents.as_local()
+                            && let Some(frg) = self.gather_debug_info_fragments(local)
+                        {
                             new_fragments.extend(frg.into_iter().map(|mut f| {
                                 f.projection.splice(0..0, fragment.projection.iter().copied());
                                 f
                             }));
-                            false
-                        } else {
                             true
+                        } else {
+                            false
                         }
                     })
                     .for_each(drop);
+                debug!(?fragments);
+                debug!(?new_fragments);
                 fragments.extend(new_fragments);
             }
             VarDebugInfoContents::Const(_) => {}
         }
     }
 
-    fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) {
-        self.super_basic_block_data(bb, bbdata);
-
-        #[derive(Debug)]
-        enum Stmt {
-            StorageLive,
-            StorageDead,
-            Deinit,
-        }
-
-        bbdata.expand_statements(|stmt| {
-            let source_info = stmt.source_info;
-            let (stmt, origin_local) = match &stmt.kind {
-                StatementKind::StorageLive(l) => (Stmt::StorageLive, *l),
-                StatementKind::StorageDead(l) => (Stmt::StorageDead, *l),
-                StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l),
-                _ => return None,
-            };
-            if !self.all_dead_locals.contains(origin_local) {
-                return None;
-            }
-            let final_locals = self.fragments.get(origin_local)?;
-            Some(final_locals.iter().map(move |&(_, l)| {
-                let kind = match stmt {
-                    Stmt::StorageLive => StatementKind::StorageLive(l),
-                    Stmt::StorageDead => StatementKind::StorageDead(l),
-                    Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())),
-                };
-                Statement { source_info, kind }
-            }))
-        });
-    }
-
     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
         assert!(!self.all_dead_locals.contains(*local));
     }