about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCamille GILLOT <gillot.camille@gmail.com>2023-02-05 12:08:42 +0000
committerCamille GILLOT <gillot.camille@gmail.com>2023-02-05 12:08:42 +0000
commit8e05ab04e54f2531198bb61c32ad2232d682a63c (patch)
tree49b3a1d580aede5372672bde86aafab7b9b47088
parent42c95146294c7773ca03e91e945fd545c6ce1ba2 (diff)
downloadrust-8e05ab04e54f2531198bb61c32ad2232d682a63c.tar.gz
rust-8e05ab04e54f2531198bb61c32ad2232d682a63c.zip
Run SROA to fixpoint.
-rw-r--r--compiler/rustc_mir_dataflow/src/value_analysis.rs2
-rw-r--r--compiler/rustc_mir_transform/src/sroa.rs74
-rw-r--r--tests/mir-opt/const_prop/mutable_variable_aggregate.main.ConstProp.diff27
-rw-r--r--tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff42
4 files changed, 78 insertions, 67 deletions
diff --git a/compiler/rustc_mir_dataflow/src/value_analysis.rs b/compiler/rustc_mir_dataflow/src/value_analysis.rs
index 90d07c81256..8bf6493be4b 100644
--- a/compiler/rustc_mir_dataflow/src/value_analysis.rs
+++ b/compiler/rustc_mir_dataflow/src/value_analysis.rs
@@ -824,7 +824,7 @@ pub fn iter_fields<'tcx>(
 }
 
 /// Returns all locals with projections that have their reference or address taken.
-fn excluded_locals(body: &Body<'_>) -> IndexVec<Local, bool> {
+pub fn excluded_locals(body: &Body<'_>) -> IndexVec<Local, bool> {
     struct Collector {
         result: IndexVec<Local, bool>,
     }
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
index 3cfa0b16499..28963c77aa5 100644
--- a/compiler/rustc_mir_transform/src/sroa.rs
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -6,7 +6,7 @@ use rustc_middle::mir::patch::MirPatch;
 use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
 use rustc_middle::ty::TyCtxt;
-use rustc_mir_dataflow::value_analysis::iter_fields;
+use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
 
 pub struct ScalarReplacementOfAggregates;
 
@@ -18,26 +18,38 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
     #[instrument(level = "debug", skip(self, tcx, body))]
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         debug!(def_id = ?body.source.def_id());
-        let escaping = escaping_locals(&*body);
-        debug!(?escaping);
-        let replacements = compute_flattening(tcx, body, escaping);
-        debug!(?replacements);
-        replace_flattened_locals(tcx, body, replacements);
+        let mut excluded = excluded_locals(body);
+        loop {
+            debug!(?excluded);
+            let escaping = escaping_locals(&excluded, body);
+            debug!(?escaping);
+            let replacements = compute_flattening(tcx, body, escaping);
+            debug!(?replacements);
+            let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
+            if !all_dead_locals.is_empty() && tcx.sess.mir_opt_level() >= 4 {
+                for local in excluded.indices() {
+                    excluded[local] |= all_dead_locals.contains(local) ;
+                }
+                excluded.raw.resize(body.local_decls.len(), false);
+            } else {
+                break
+            }
+        }
     }
 }
 
 /// Identify all locals that are not eligible for SROA.
 ///
 /// There are 3 cases:
-/// - the aggegated local is used or passed to other code (function parameters and arguments);
+/// - the aggregated local is used or passed to other code (function parameters and arguments);
 /// - the locals is a union or an enum;
 /// - the local's address is taken, and thus the relative addresses of the fields are observable to
 ///   client code.
-fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
+fn escaping_locals(excluded: &IndexVec<Local, bool>, body: &Body<'_>) -> BitSet<Local> {
     let mut set = BitSet::new_empty(body.local_decls.len());
     set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
     for (local, decl) in body.local_decls().iter_enumerated() {
-        if decl.ty.is_union() || decl.ty.is_enum() {
+        if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] {
             set.insert(local);
         }
     }
@@ -62,17 +74,6 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
             self.super_place(place, context, location);
         }
 
-        fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
-            if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue {
-                if !place.is_indirect() {
-                    // Raw pointers may be used to access anything inside the enclosing place.
-                    self.set.insert(place.local);
-                    return;
-                }
-            }
-            self.super_rvalue(rvalue, location)
-        }
-
         fn visit_assign(
             &mut self,
             lvalue: &Place<'tcx>,
@@ -102,21 +103,6 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
             }
         }
 
-        fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
-            // Drop implicitly calls `drop_in_place`, which takes a `&mut`.
-            // This implies that `Drop` implicitly takes the address of the place.
-            if let TerminatorKind::Drop { place, .. }
-            | TerminatorKind::DropAndReplace { place, .. } = terminator.kind
-            {
-                if !place.is_indirect() {
-                    // Raw pointers may be used to access anything inside the enclosing place.
-                    self.set.insert(place.local);
-                    return;
-                }
-            }
-            self.super_terminator(terminator, location);
-        }
-
         // We ignore anything that happens in debuginfo, since we expand it using
         // `VarDebugInfoContents::Composite`.
         fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
@@ -198,14 +184,14 @@ fn replace_flattened_locals<'tcx>(
     tcx: TyCtxt<'tcx>,
     body: &mut Body<'tcx>,
     replacements: ReplacementMap<'tcx>,
-) {
+) -> BitSet<Local> {
     let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
     for p in replacements.fields.keys() {
         all_dead_locals.insert(p.local);
     }
     debug!(?all_dead_locals);
     if all_dead_locals.is_empty() {
-        return;
+        return all_dead_locals;
     }
 
     let mut visitor = ReplacementVisitor {
@@ -227,7 +213,9 @@ fn replace_flattened_locals<'tcx>(
     for var_debug_info in &mut body.var_debug_info {
         visitor.visit_var_debug_info(var_debug_info);
     }
-    visitor.patch.apply(body);
+    let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
+    patch.apply(body);
+    all_dead_locals
 }
 
 struct ReplacementVisitor<'tcx, 'll> {
@@ -361,6 +349,7 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
         }
     }
 
+    #[instrument(level = "trace", skip(self))]
     fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
         match &mut var_debug_info.value {
             VarDebugInfoContents::Place(ref mut place) => {
@@ -375,11 +364,12 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
             }
             VarDebugInfoContents::Composite { ty: _, ref mut fragments } => {
                 let mut new_fragments = Vec::new();
+                debug!(?fragments);
                 fragments
                     .drain_filter(|fragment| {
                         if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
                             fragment.contents = repl;
-                            true
+                            false
                         } else if let Some(frg) = self
                             .replacements
                             .gather_debug_info_fragments(fragment.contents.as_ref())
@@ -388,12 +378,14 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
                                 f.projection.splice(0..0, fragment.projection.iter().copied());
                                 f
                             }));
-                            false
-                        } else {
                             true
+                        } else {
+                            false
                         }
                     })
                     .for_each(drop);
+                debug!(?fragments);
+                debug!(?new_fragments);
                 fragments.extend(new_fragments);
             }
             VarDebugInfoContents::Const(_) => {}
diff --git a/tests/mir-opt/const_prop/mutable_variable_aggregate.main.ConstProp.diff b/tests/mir-opt/const_prop/mutable_variable_aggregate.main.ConstProp.diff
index 37fbcf9dd49..d088c4f662b 100644
--- a/tests/mir-opt/const_prop/mutable_variable_aggregate.main.ConstProp.diff
+++ b/tests/mir-opt/const_prop/mutable_variable_aggregate.main.ConstProp.diff
@@ -3,30 +3,27 @@
   
   fn main() -> () {
       let mut _0: ();                      // return place in scope 0 at $DIR/mutable_variable_aggregate.rs:+0:11: +0:11
-      let mut _1: (i32, i32);              // in scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
+      let mut _3: i32;                     // in scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
+      let mut _4: i32;                     // in scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
       scope 1 {
-          debug x => _1;                   // in scope 1 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
+          debug x => (i32, i32){ .0 => _3, .1 => _4, }; // in scope 1 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
+          let _1: i32;                     // in scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
           let _2: i32;                     // in scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
-          let _3: i32;                     // in scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
           scope 2 {
-              debug y => (i32, i32){ .0 => _2, .1 => _3, }; // in scope 2 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
+              debug y => (i32, i32){ .0 => _3, .1 => _2, }; // in scope 2 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
           }
       }
   
       bb0: {
-          StorageLive(_1);                 // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
--         _1 = (const 42_i32, const 43_i32); // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25
-+         _1 = const (42_i32, 43_i32);     // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25
-          (_1.1: i32) = const 99_i32;      // scope 1 at $DIR/mutable_variable_aggregate.rs:+2:5: +2:13
+          StorageLive(_4);                 // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:9: +1:14
+          _3 = const 42_i32;               // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25
+          _4 = const 43_i32;               // scope 0 at $DIR/mutable_variable_aggregate.rs:+1:17: +1:25
+          _4 = const 99_i32;               // scope 1 at $DIR/mutable_variable_aggregate.rs:+2:5: +2:13
           StorageLive(_2);                 // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
-          StorageLive(_3);                 // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:9: +3:10
--         _2 = (_1.0: i32);                // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
--         _3 = (_1.1: i32);                // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
-+         _2 = const 42_i32;               // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
-+         _3 = const 99_i32;               // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
+-         _2 = _4;                         // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
++         _2 = const 99_i32;               // scope 1 at $DIR/mutable_variable_aggregate.rs:+3:13: +3:14
           StorageDead(_2);                 // scope 1 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2
-          StorageDead(_3);                 // scope 1 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2
-          StorageDead(_1);                 // scope 0 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2
+          StorageDead(_4);                 // scope 0 at $DIR/mutable_variable_aggregate.rs:+4:1: +4:2
           return;                          // scope 0 at $DIR/mutable_variable_aggregate.rs:+4:2: +4:2
       }
   }
diff --git a/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff b/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff
index b76e2d6d0f2..976f6d44b75 100644
--- a/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff
+++ b/tests/mir-opt/sroa.copies.ScalarReplacementOfAggregates.diff
@@ -5,8 +5,13 @@
       debug x => _1;                       // in scope 0 at $DIR/sroa.rs:+0:11: +0:12
       let mut _0: ();                      // return place in scope 0 at $DIR/sroa.rs:+0:19: +0:19
       let _2: Foo;                         // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
++     let _11: u8;                         // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
++     let _12: ();                         // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
++     let _13: &str;                       // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
++     let _14: std::option::Option<isize>; // in scope 0 at $DIR/sroa.rs:+1:9: +1:10
       scope 1 {
-          debug y => _2;                   // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
+-         debug y => _2;                   // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
++         debug y => Foo{ .0 => _11, .1 => _12, .2 => _13, .3 => _14, }; // in scope 1 at $DIR/sroa.rs:+1:9: +1:10
           let _3: u8;                      // in scope 1 at $DIR/sroa.rs:+2:9: +2:10
           scope 2 {
               debug t => _3;               // in scope 2 at $DIR/sroa.rs:+2:9: +2:10
@@ -31,23 +36,35 @@
       }
   
       bb0: {
-          StorageLive(_2);                 // scope 0 at $DIR/sroa.rs:+1:9: +1:10
-          _2 = _1;                         // scope 0 at $DIR/sroa.rs:+1:13: +1:14
+-         StorageLive(_2);                 // scope 0 at $DIR/sroa.rs:+1:9: +1:10
+-         _2 = _1;                         // scope 0 at $DIR/sroa.rs:+1:13: +1:14
++         StorageLive(_11);                // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         StorageLive(_12);                // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         StorageLive(_13);                // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         StorageLive(_14);                // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         nop;                             // scope 0 at $DIR/sroa.rs:+1:9: +1:10
++         _11 = (_1.0: u8);                // scope 0 at $DIR/sroa.rs:+1:13: +1:14
++         _12 = (_1.1: ());                // scope 0 at $DIR/sroa.rs:+1:13: +1:14
++         _13 = (_1.2: &str);              // scope 0 at $DIR/sroa.rs:+1:13: +1:14
++         _14 = (_1.3: std::option::Option<isize>); // scope 0 at $DIR/sroa.rs:+1:13: +1:14
++         nop;                             // scope 0 at $DIR/sroa.rs:+1:13: +1:14
           StorageLive(_3);                 // scope 1 at $DIR/sroa.rs:+2:9: +2:10
-          _3 = (_2.0: u8);                 // scope 1 at $DIR/sroa.rs:+2:13: +2:16
+-         _3 = (_2.0: u8);                 // scope 1 at $DIR/sroa.rs:+2:13: +2:16
++         _3 = _11;                        // scope 1 at $DIR/sroa.rs:+2:13: +2:16
           StorageLive(_4);                 // scope 2 at $DIR/sroa.rs:+3:9: +3:10
-          _4 = (_2.2: &str);               // scope 2 at $DIR/sroa.rs:+3:13: +3:16
+-         _4 = (_2.2: &str);               // scope 2 at $DIR/sroa.rs:+3:13: +3:16
 -         StorageLive(_5);                 // scope 3 at $DIR/sroa.rs:+4:9: +4:10
 -         _5 = _2;                         // scope 3 at $DIR/sroa.rs:+4:13: +4:14
++         _4 = _13;                        // scope 2 at $DIR/sroa.rs:+3:13: +3:16
 +         StorageLive(_7);                 // scope 3 at $DIR/sroa.rs:+4:9: +4:10
 +         StorageLive(_8);                 // scope 3 at $DIR/sroa.rs:+4:9: +4:10
 +         StorageLive(_9);                 // scope 3 at $DIR/sroa.rs:+4:9: +4:10
 +         StorageLive(_10);                // scope 3 at $DIR/sroa.rs:+4:9: +4:10
 +         nop;                             // scope 3 at $DIR/sroa.rs:+4:9: +4:10
-+         _7 = (_2.0: u8);                 // scope 3 at $DIR/sroa.rs:+4:13: +4:14
-+         _8 = (_2.1: ());                 // scope 3 at $DIR/sroa.rs:+4:13: +4:14
-+         _9 = (_2.2: &str);               // scope 3 at $DIR/sroa.rs:+4:13: +4:14
-+         _10 = (_2.3: std::option::Option<isize>); // scope 3 at $DIR/sroa.rs:+4:13: +4:14
++         _7 = _11;                        // scope 3 at $DIR/sroa.rs:+4:13: +4:14
++         _8 = _12;                        // scope 3 at $DIR/sroa.rs:+4:13: +4:14
++         _9 = _13;                        // scope 3 at $DIR/sroa.rs:+4:13: +4:14
++         _10 = _14;                       // scope 3 at $DIR/sroa.rs:+4:13: +4:14
 +         nop;                             // scope 3 at $DIR/sroa.rs:+4:13: +4:14
           StorageLive(_6);                 // scope 4 at $DIR/sroa.rs:+5:9: +5:10
 -         _6 = (_5.1: ());                 // scope 4 at $DIR/sroa.rs:+5:13: +5:16
@@ -62,7 +79,12 @@
 +         nop;                             // scope 3 at $DIR/sroa.rs:+6:1: +6:2
           StorageDead(_4);                 // scope 2 at $DIR/sroa.rs:+6:1: +6:2
           StorageDead(_3);                 // scope 1 at $DIR/sroa.rs:+6:1: +6:2
-          StorageDead(_2);                 // scope 0 at $DIR/sroa.rs:+6:1: +6:2
+-         StorageDead(_2);                 // scope 0 at $DIR/sroa.rs:+6:1: +6:2
++         StorageDead(_11);                // scope 0 at $DIR/sroa.rs:+6:1: +6:2
++         StorageDead(_12);                // scope 0 at $DIR/sroa.rs:+6:1: +6:2
++         StorageDead(_13);                // scope 0 at $DIR/sroa.rs:+6:1: +6:2
++         StorageDead(_14);                // scope 0 at $DIR/sroa.rs:+6:1: +6:2
++         nop;                             // scope 0 at $DIR/sroa.rs:+6:1: +6:2
           return;                          // scope 0 at $DIR/sroa.rs:+6:2: +6:2
       }
   }