about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/sroa.rs431
1 files changed, 253 insertions, 178 deletions
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
index 42124f5a480..26acd406ed8 100644
--- a/compiler/rustc_mir_transform/src/sroa.rs
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -1,10 +1,11 @@
 use crate::MirPass;
-use rustc_data_structures::fx::{FxIndexMap, IndexEntry};
 use rustc_index::bit_set::BitSet;
 use rustc_index::vec::IndexVec;
+use rustc_middle::mir::patch::MirPatch;
 use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
-use rustc_middle::ty::TyCtxt;
+use rustc_middle::ty::{Ty, TyCtxt};
+use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
 
 pub struct ScalarReplacementOfAggregates;
 
@@ -13,27 +14,41 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
         sess.mir_opt_level() >= 3
     }
 
+    #[instrument(level = "debug", skip(self, tcx, body))]
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
-        let escaping = escaping_locals(&*body);
-        debug!(?escaping);
-        let replacements = compute_flattening(tcx, body, escaping);
-        debug!(?replacements);
-        replace_flattened_locals(tcx, body, replacements);
+        debug!(def_id = ?body.source.def_id());
+        let mut excluded = excluded_locals(body);
+        loop {
+            debug!(?excluded);
+            let escaping = escaping_locals(&excluded, body);
+            debug!(?escaping);
+            let replacements = compute_flattening(tcx, body, escaping);
+            debug!(?replacements);
+            let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
+            if !all_dead_locals.is_empty() {
+                for local in excluded.indices() {
+                    excluded[local] |= all_dead_locals.contains(local);
+                }
+                excluded.raw.resize(body.local_decls.len(), false);
+            } else {
+                break;
+            }
+        }
     }
 }
 
 /// Identify all locals that are not eligible for SROA.
 ///
 /// There are 3 cases:
-/// - the aggegated local is used or passed to other code (function parameters and arguments);
+/// - the aggregated local is used or passed to other code (function parameters and arguments);
 /// - the locals is a union or an enum;
 /// - the local's address is taken, and thus the relative addresses of the fields are observable to
 ///   client code.
-fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
+fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<Local> {
     let mut set = BitSet::new_empty(body.local_decls.len());
     set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
     for (local, decl) in body.local_decls().iter_enumerated() {
-        if decl.ty.is_union() || decl.ty.is_enum() {
+        if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] {
             set.insert(local);
         }
     }
@@ -58,41 +73,33 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
             self.super_place(place, context, location);
         }
 
-        fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
-            if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue {
-                if !place.is_indirect() {
-                    // Raw pointers may be used to access anything inside the enclosing place.
-                    self.set.insert(place.local);
-                    return;
+        fn visit_assign(
+            &mut self,
+            lvalue: &Place<'tcx>,
+            rvalue: &Rvalue<'tcx>,
+            location: Location,
+        ) {
+            if lvalue.as_local().is_some() {
+                match rvalue {
+                    // Aggregate assignments are expanded in run_pass.
+                    Rvalue::Aggregate(..) | Rvalue::Use(..) => {
+                        self.visit_rvalue(rvalue, location);
+                        return;
+                    }
+                    _ => {}
                 }
             }
-            self.super_rvalue(rvalue, location)
+            self.super_assign(lvalue, rvalue, location)
         }
 
         fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
-            if let StatementKind::StorageLive(..)
-            | StatementKind::StorageDead(..)
-            | StatementKind::Deinit(..) = statement.kind
-            {
+            match statement.kind {
                 // Storage statements are expanded in run_pass.
-                return;
+                StatementKind::StorageLive(..)
+                | StatementKind::StorageDead(..)
+                | StatementKind::Deinit(..) => return,
+                _ => self.super_statement(statement, location),
             }
-            self.super_statement(statement, location)
-        }
-
-        fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
-            // Drop implicitly calls `drop_in_place`, which takes a `&mut`.
-            // This implies that `Drop` implicitly takes the address of the place.
-            if let TerminatorKind::Drop { place, .. }
-            | TerminatorKind::DropAndReplace { place, .. } = terminator.kind
-            {
-                if !place.is_indirect() {
-                    // Raw pointers may be used to access anything inside the enclosing place.
-                    self.set.insert(place.local);
-                    return;
-                }
-            }
-            self.super_terminator(terminator, location);
         }
 
         // We ignore anything that happens in debuginfo, since we expand it using
@@ -103,7 +110,30 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
 
 #[derive(Default, Debug)]
 struct ReplacementMap<'tcx> {
-    fields: FxIndexMap<PlaceRef<'tcx>, Local>,
+    /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
+    /// and deinit statement and debuginfo.
+    fragments: IndexVec<Local, Option<IndexVec<Field, Option<(Ty<'tcx>, Local)>>>>,
+}
+
+impl<'tcx> ReplacementMap<'tcx> {
+    fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
+        let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { return None; };
+        let fields = self.fragments[place.local].as_ref()?;
+        let (_, new_local) = fields[f]?;
+        Some(Place { local: new_local, projection: tcx.intern_place_elems(&rest) })
+    }
+
+    fn place_fragments(
+        &self,
+        place: Place<'tcx>,
+    ) -> Option<impl Iterator<Item = (Field, Ty<'tcx>, Local)> + '_> {
+        let local = place.as_local()?;
+        let fields = self.fragments[local].as_ref()?;
+        Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
+            let (ty, local) = opt_ty_local?;
+            Some((field, ty, local))
+        }))
+    }
 }
 
 /// Compute the replacement of flattened places into locals.
@@ -115,53 +145,25 @@ fn compute_flattening<'tcx>(
     body: &mut Body<'tcx>,
     escaping: BitSet<Local>,
 ) -> ReplacementMap<'tcx> {
-    let mut visitor = PreFlattenVisitor {
-        tcx,
-        escaping,
-        local_decls: &mut body.local_decls,
-        map: Default::default(),
-    };
-    for (block, bbdata) in body.basic_blocks.iter_enumerated() {
-        visitor.visit_basic_block_data(block, bbdata);
-    }
-    return visitor.map;
-
-    struct PreFlattenVisitor<'tcx, 'll> {
-        tcx: TyCtxt<'tcx>,
-        local_decls: &'ll mut LocalDecls<'tcx>,
-        escaping: BitSet<Local>,
-        map: ReplacementMap<'tcx>,
-    }
-
-    impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> {
-        fn create_place(&mut self, place: PlaceRef<'tcx>) {
-            if self.escaping.contains(place.local) {
-                return;
-            }
+    let mut fragments = IndexVec::from_elem(None, &body.local_decls);
 
-            match self.map.fields.entry(place) {
-                IndexEntry::Occupied(_) => {}
-                IndexEntry::Vacant(v) => {
-                    let ty = place.ty(&*self.local_decls, self.tcx).ty;
-                    let local = self.local_decls.push(LocalDecl {
-                        ty,
-                        user_ty: None,
-                        ..self.local_decls[place.local].clone()
-                    });
-                    v.insert(local);
-                }
-            }
-        }
-    }
-
-    impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> {
-        fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) {
-            if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
-                let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
-                self.create_place(pr)
-            }
+    for local in body.local_decls.indices() {
+        if escaping.contains(local) {
+            continue;
         }
+        let decl = body.local_decls[local].clone();
+        let ty = decl.ty;
+        iter_fields(ty, tcx, |variant, field, field_ty| {
+            if variant.is_some() {
+                // Downcasts are currently not supported.
+                return;
+            };
+            let new_local =
+                body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
+            fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
+        });
     }
+    ReplacementMap { fragments }
 }
 
 /// Perform the replacement computed by `compute_flattening`.
@@ -169,29 +171,24 @@ fn replace_flattened_locals<'tcx>(
     tcx: TyCtxt<'tcx>,
     body: &mut Body<'tcx>,
     replacements: ReplacementMap<'tcx>,
-) {
+) -> BitSet<Local> {
     let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
-    for p in replacements.fields.keys() {
-        all_dead_locals.insert(p.local);
+    for (local, replacements) in replacements.fragments.iter_enumerated() {
+        if replacements.is_some() {
+            all_dead_locals.insert(local);
+        }
     }
     debug!(?all_dead_locals);
     if all_dead_locals.is_empty() {
-        return;
+        return all_dead_locals;
     }
 
-    let mut fragments = IndexVec::new();
-    for (k, v) in &replacements.fields {
-        fragments.ensure_contains_elem(k.local, || Vec::new());
-        fragments[k.local].push((k.projection, *v));
-    }
-    debug!(?fragments);
-
     let mut visitor = ReplacementVisitor {
         tcx,
         local_decls: &body.local_decls,
-        replacements,
+        replacements: &replacements,
         all_dead_locals,
-        fragments,
+        patch: MirPatch::new(body),
     };
     for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
         visitor.visit_basic_block_data(bb, data);
@@ -205,6 +202,9 @@ fn replace_flattened_locals<'tcx>(
     for var_debug_info in &mut body.var_debug_info {
         visitor.visit_var_debug_info(var_debug_info);
     }
+    let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
+    patch.apply(body);
+    all_dead_locals
 }
 
 struct ReplacementVisitor<'tcx, 'll> {
@@ -212,40 +212,23 @@ struct ReplacementVisitor<'tcx, 'll> {
     /// This is only used to compute the type for `VarDebugInfoContents::Composite`.
     local_decls: &'ll LocalDecls<'tcx>,
     /// Work to do.
-    replacements: ReplacementMap<'tcx>,
+    replacements: &'ll ReplacementMap<'tcx>,
     /// This is used to check that we are not leaving references to replaced locals behind.
     all_dead_locals: BitSet<Local>,
-    /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
-    /// and deinit statement and debuginfo.
-    fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
+    patch: MirPatch<'tcx>,
 }
 
-impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
-    fn gather_debug_info_fragments(
-        &self,
-        place: PlaceRef<'tcx>,
-    ) -> Vec<VarDebugInfoFragment<'tcx>> {
+impl<'tcx> ReplacementVisitor<'tcx, '_> {
+    fn gather_debug_info_fragments(&self, local: Local) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
         let mut fragments = Vec::new();
-        let parts = &self.fragments[place.local];
-        for (proj, replacement_local) in parts {
-            if proj.starts_with(place.projection) {
-                fragments.push(VarDebugInfoFragment {
-                    projection: proj[place.projection.len()..].to_vec(),
-                    contents: Place::from(*replacement_local),
-                });
-            }
-        }
-        fragments
-    }
-
-    fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
-        if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection {
-            let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
-            let local = self.replacements.fields.get(&pr)?;
-            Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) })
-        } else {
-            None
+        let parts = self.replacements.place_fragments(local.into())?;
+        for (field, ty, replacement_local) in parts {
+            fragments.push(VarDebugInfoFragment {
+                projection: vec![PlaceElem::Field(field, ty)],
+                contents: Place::from(replacement_local),
+            });
         }
+        Some(fragments)
     }
 }
 
@@ -254,94 +237,186 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
         self.tcx
     }
 
-    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
-        if let StatementKind::StorageLive(..)
-        | StatementKind::StorageDead(..)
-        | StatementKind::Deinit(..) = statement.kind
-        {
-            // Storage statements are expanded in run_pass.
-            return;
-        }
-        self.super_statement(statement, location)
-    }
-
     fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
-        if let Some(repl) = self.replace_place(place.as_ref()) {
+        if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
             *place = repl
         } else {
             self.super_place(place, context, location)
         }
     }
 
+    #[instrument(level = "trace", skip(self))]
+    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
+        match statement.kind {
+            // Duplicate storage and deinit statements, as they pretty much apply to all fields.
+            StatementKind::StorageLive(l) => {
+                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
+                    for (_, _, fl) in final_locals {
+                        self.patch.add_statement(location, StatementKind::StorageLive(fl));
+                    }
+                    statement.make_nop();
+                }
+                return;
+            }
+            StatementKind::StorageDead(l) => {
+                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
+                    for (_, _, fl) in final_locals {
+                        self.patch.add_statement(location, StatementKind::StorageDead(fl));
+                    }
+                    statement.make_nop();
+                }
+                return;
+            }
+            StatementKind::Deinit(box place) => {
+                if let Some(final_locals) = self.replacements.place_fragments(place) {
+                    for (_, _, fl) in final_locals {
+                        self.patch
+                            .add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
+                    }
+                    statement.make_nop();
+                    return;
+                }
+            }
+
+            // We have `a = Struct { 0: x, 1: y, .. }`.
+            // We replace it by
+            // ```
+            // a_0 = x
+            // a_1 = y
+            // ...
+            // ```
+            StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
+                if let Some(local) = place.as_local()
+                    && let Some(final_locals) = &self.replacements.fragments[local]
+                {
+                    // This is ok as we delete the statement later.
+                    let operands = std::mem::take(operands);
+                    for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
+                        if let Some((_, new_local)) = opt_ty_local {
+                            // Replace mentions of SROA'd locals that appear in the operand.
+                            self.visit_operand(&mut operand, location);
+
+                            let rvalue = Rvalue::Use(operand);
+                            self.patch.add_statement(
+                                location,
+                                StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+                            );
+                        }
+                    }
+                    statement.make_nop();
+                    return;
+                }
+            }
+
+            // We have `a = some constant`
+            // We add the projections.
+            // ```
+            // a_0 = a.0
+            // a_1 = a.1
+            // ...
+            // ```
+            // ConstProp will pick up the pieces and replace them by actual constants.
+            StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
+                if let Some(final_locals) = self.replacements.place_fragments(place) {
+                    for (field, ty, new_local) in final_locals {
+                        let rplace = self.tcx.mk_place_field(place, field, ty);
+                        let rvalue = Rvalue::Use(Operand::Move(rplace));
+                        self.patch.add_statement(
+                            location,
+                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+                        );
+                    }
+                    // We still need `place.local` to exist, so don't make it nop.
+                    return;
+                }
+            }
+
+            // We have `a = move? place`
+            // We replace it by
+            // ```
+            // a_0 = move? place.0
+            // a_1 = move? place.1
+            // ...
+            // ```
+            StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
+                let (rplace, copy) = match *op {
+                    Operand::Copy(rplace) => (rplace, true),
+                    Operand::Move(rplace) => (rplace, false),
+                    Operand::Constant(_) => bug!(),
+                };
+                if let Some(final_locals) = self.replacements.place_fragments(lhs) {
+                    for (field, ty, new_local) in final_locals {
+                        let rplace = self.tcx.mk_place_field(rplace, field, ty);
+                        debug!(?rplace);
+                        let rplace = self
+                            .replacements
+                            .replace_place(self.tcx, rplace.as_ref())
+                            .unwrap_or(rplace);
+                        debug!(?rplace);
+                        let rvalue = if copy {
+                            Rvalue::Use(Operand::Copy(rplace))
+                        } else {
+                            Rvalue::Use(Operand::Move(rplace))
+                        };
+                        self.patch.add_statement(
+                            location,
+                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+                        );
+                    }
+                    statement.make_nop();
+                    return;
+                }
+            }
+
+            _ => {}
+        }
+        self.super_statement(statement, location)
+    }
+
+    #[instrument(level = "trace", skip(self))]
     fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
         match &mut var_debug_info.value {
             VarDebugInfoContents::Place(ref mut place) => {
-                if let Some(repl) = self.replace_place(place.as_ref()) {
+                if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
                     *place = repl;
-                } else if self.all_dead_locals.contains(place.local) {
+                } else if let Some(local) = place.as_local()
+                    && let Some(fragments) = self.gather_debug_info_fragments(local)
+                {
                     let ty = place.ty(self.local_decls, self.tcx).ty;
-                    let fragments = self.gather_debug_info_fragments(place.as_ref());
                     var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
                 }
             }
             VarDebugInfoContents::Composite { ty: _, ref mut fragments } => {
                 let mut new_fragments = Vec::new();
+                debug!(?fragments);
                 fragments
                     .drain_filter(|fragment| {
-                        if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
+                        if let Some(repl) =
+                            self.replacements.replace_place(self.tcx, fragment.contents.as_ref())
+                        {
                             fragment.contents = repl;
-                            true
-                        } else if self.all_dead_locals.contains(fragment.contents.local) {
-                            let frg = self.gather_debug_info_fragments(fragment.contents.as_ref());
+                            false
+                        } else if let Some(local) = fragment.contents.as_local()
+                            && let Some(frg) = self.gather_debug_info_fragments(local)
+                        {
                             new_fragments.extend(frg.into_iter().map(|mut f| {
                                 f.projection.splice(0..0, fragment.projection.iter().copied());
                                 f
                             }));
-                            false
-                        } else {
                             true
+                        } else {
+                            false
                         }
                     })
                     .for_each(drop);
+                debug!(?fragments);
+                debug!(?new_fragments);
                 fragments.extend(new_fragments);
             }
             VarDebugInfoContents::Const(_) => {}
         }
     }
 
-    fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) {
-        self.super_basic_block_data(bb, bbdata);
-
-        #[derive(Debug)]
-        enum Stmt {
-            StorageLive,
-            StorageDead,
-            Deinit,
-        }
-
-        bbdata.expand_statements(|stmt| {
-            let source_info = stmt.source_info;
-            let (stmt, origin_local) = match &stmt.kind {
-                StatementKind::StorageLive(l) => (Stmt::StorageLive, *l),
-                StatementKind::StorageDead(l) => (Stmt::StorageDead, *l),
-                StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l),
-                _ => return None,
-            };
-            if !self.all_dead_locals.contains(origin_local) {
-                return None;
-            }
-            let final_locals = self.fragments.get(origin_local)?;
-            Some(final_locals.iter().map(move |&(_, l)| {
-                let kind = match stmt {
-                    Stmt::StorageLive => StatementKind::StorageLive(l),
-                    Stmt::StorageDead => StatementKind::StorageDead(l),
-                    Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())),
-                };
-                Statement { source_info, kind }
-            }))
-        });
-    }
-
     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
         assert!(!self.all_dead_locals.contains(*local));
     }