about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_mir_transform/src/match_branches.rs104
1 files changed, 54 insertions, 50 deletions
diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs
index a444df34048..e9203043769 100644
--- a/compiler/rustc_mir_transform/src/match_branches.rs
+++ b/compiler/rustc_mir_transform/src/match_branches.rs
@@ -1,4 +1,5 @@
-use rustc_index::IndexVec;
+use rustc_index::IndexSlice;
+use rustc_middle::mir::patch::MirPatch;
 use rustc_middle::mir::*;
 use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
 use rustc_target::abi::Size;
@@ -17,9 +18,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
         let def_id = body.source.def_id();
         let param_env = tcx.param_env_reveal_all_normalized(def_id);
 
-        let bbs = body.basic_blocks.as_mut();
         let mut should_cleanup = false;
-        for bb_idx in bbs.indices() {
+        for i in 0..body.basic_blocks.len() {
+            let bbs = &*body.basic_blocks;
+            let bb_idx = BasicBlock::from_usize(i);
             if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) {
                 continue;
             }
@@ -35,12 +37,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
                 _ => continue,
             };
 
-            if SimplifyToIf.simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) {
+            if SimplifyToIf.simplify(tcx, body, bb_idx, param_env) {
                 should_cleanup = true;
                 continue;
             }
-            if SimplifyToExp::default().simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env)
-            {
+            if SimplifyToExp::default().simplify(tcx, body, bb_idx, param_env) {
                 should_cleanup = true;
                 continue;
             }
@@ -58,41 +59,39 @@ trait SimplifyMatch<'tcx> {
     fn simplify(
         &mut self,
         tcx: TyCtxt<'tcx>,
-        local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
-        bbs: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+        body: &mut Body<'tcx>,
         switch_bb_idx: BasicBlock,
         param_env: ParamEnv<'tcx>,
     ) -> bool {
+        let bbs = &body.basic_blocks;
         let (discr, targets) = match bbs[switch_bb_idx].terminator().kind {
             TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets),
             _ => unreachable!(),
         };
 
-        let discr_ty = discr.ty(local_decls, tcx);
+        let discr_ty = discr.ty(body.local_decls(), tcx);
         if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) {
             return false;
         }
 
+        let mut patch = MirPatch::new(body);
+
         // Take ownership of items now that we know we can optimize.
         let discr = discr.clone();
 
         // Introduce a temporary for the discriminant value.
         let source_info = bbs[switch_bb_idx].terminator().source_info;
-        let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span));
+        let discr_local = patch.new_temp(discr_ty, source_info.span);
 
-        let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local, discr_ty);
         let (_, first) = targets.iter().next().unwrap();
-        let (from, first) = bbs.pick2_mut(switch_bb_idx, first);
-        from.statements
-            .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) });
-        from.statements.push(Statement {
-            source_info,
-            kind: StatementKind::Assign(Box::new((Place::from(discr_local), Rvalue::Use(discr)))),
-        });
-        from.statements.extend(new_stmts);
-        from.statements
-            .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) });
-        from.terminator_mut().kind = first.terminator().kind.clone();
+        let statement_index = bbs[switch_bb_idx].statements.len();
+        let parent_end = Location { block: switch_bb_idx, statement_index };
+        patch.add_statement(parent_end, StatementKind::StorageLive(discr_local));
+        patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr));
+        self.new_stmts(tcx, targets, param_env, &mut patch, parent_end, bbs, discr_local, discr_ty);
+        patch.add_statement(parent_end, StatementKind::StorageDead(discr_local));
+        patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone());
+        patch.apply(body);
         true
     }
 
@@ -104,7 +103,7 @@ trait SimplifyMatch<'tcx> {
         tcx: TyCtxt<'tcx>,
         targets: &SwitchTargets,
         param_env: ParamEnv<'tcx>,
-        bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
         discr_ty: Ty<'tcx>,
     ) -> bool;
 
@@ -113,10 +112,12 @@ trait SimplifyMatch<'tcx> {
         tcx: TyCtxt<'tcx>,
         targets: &SwitchTargets,
         param_env: ParamEnv<'tcx>,
-        bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+        patch: &mut MirPatch<'tcx>,
+        parent_end: Location,
+        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
         discr_local: Local,
         discr_ty: Ty<'tcx>,
-    ) -> Vec<Statement<'tcx>>;
+    );
 }
 
 struct SimplifyToIf;
@@ -158,7 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
         tcx: TyCtxt<'tcx>,
         targets: &SwitchTargets,
         param_env: ParamEnv<'tcx>,
-        bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
         _discr_ty: Ty<'tcx>,
     ) -> bool {
         if targets.iter().len() != 1 {
@@ -209,20 +210,23 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
         tcx: TyCtxt<'tcx>,
         targets: &SwitchTargets,
         param_env: ParamEnv<'tcx>,
-        bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+        patch: &mut MirPatch<'tcx>,
+        parent_end: Location,
+        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
         discr_local: Local,
         discr_ty: Ty<'tcx>,
-    ) -> Vec<Statement<'tcx>> {
+    ) {
         let (val, first) = targets.iter().next().unwrap();
         let second = targets.otherwise();
         // We already checked that first and second are different blocks,
         // and bb_idx has a different terminator from both of them.
         let first = &bbs[first];
         let second = &bbs[second];
-
-        let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| {
+        for (f, s) in iter::zip(&first.statements, &second.statements) {
             match (&f.kind, &s.kind) {
-                (f_s, s_s) if f_s == s_s => (*f).clone(),
+                (f_s, s_s) if f_s == s_s => {
+                    patch.add_statement(parent_end, f.kind.clone());
+                }
 
                 (
                     StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
@@ -233,7 +237,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
                     let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap();
                     if f_b == s_b {
                         // Same value in both blocks. Use statement as is.
-                        (*f).clone()
+                        patch.add_statement(parent_end, f.kind.clone());
                     } else {
                         // Different value between blocks. Make value conditional on switch condition.
                         let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
@@ -248,17 +252,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
                             op,
                             Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)),
                         );
-                        Statement {
-                            source_info: f.source_info,
-                            kind: StatementKind::Assign(Box::new((*lhs, rhs))),
-                        }
+                        patch.add_assign(parent_end, *lhs, rhs);
                     }
                 }
 
                 _ => unreachable!(),
             }
-        });
-        new_stmts.collect()
+        }
     }
 }
 
@@ -335,7 +335,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
         tcx: TyCtxt<'tcx>,
         targets: &SwitchTargets,
         param_env: ParamEnv<'tcx>,
-        bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
         discr_ty: Ty<'tcx>,
     ) -> bool {
         if targets.iter().len() < 2 || targets.iter().len() > 64 {
@@ -372,6 +372,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
                 == ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap()
         }
 
+        // We first compare the two branches, and then the other branches need to fulfill the same conditions.
         let mut compare_types = Vec::new();
         for (f, s) in iter::zip(first_stmts, second_stmts) {
             let compare_type = match (&f.kind, &s.kind) {
@@ -391,6 +392,8 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
                         s_c.const_.try_eval_scalar_int(tcx, param_env),
                     ) {
                         (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
+                        // Enum variants can also be simplified to an assignment statement if their values are equal.
+                        // We need to consider both unsigned and signed scenarios here.
                         (Some(f), Some(s))
                             if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
                                 && int_equal(f, first_val, discr_size)
@@ -463,16 +466,20 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
         _tcx: TyCtxt<'tcx>,
         targets: &SwitchTargets,
         _param_env: ParamEnv<'tcx>,
-        bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+        patch: &mut MirPatch<'tcx>,
+        parent_end: Location,
+        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
         discr_local: Local,
         discr_ty: Ty<'tcx>,
-    ) -> Vec<Statement<'tcx>> {
+    ) {
         let (_, first) = targets.iter().next().unwrap();
         let first = &bbs[first];
 
-        let new_stmts =
-            iter::zip(&self.transfrom_types, &first.statements).map(|(t, s)| match (t, &s.kind) {
-                (TransfromType::Same, _) | (TransfromType::Eq, _) => (*s).clone(),
+        for (t, s) in iter::zip(&self.transfrom_types, &first.statements) {
+            match (t, &s.kind) {
+                (TransfromType::Same, _) | (TransfromType::Eq, _) => {
+                    patch.add_statement(parent_end, s.kind.clone());
+                }
                 (
                     TransfromType::Discr,
                     StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
@@ -483,13 +490,10 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
                     } else {
                         Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty())
                     };
-                    Statement {
-                        source_info: s.source_info,
-                        kind: StatementKind::Assign(Box::new((*lhs, r_val))),
-                    }
+                    patch.add_assign(parent_end, *lhs, r_val);
                 }
                 _ => unreachable!(),
-            });
-        new_stmts.collect()
+            }
+        }
     }
 }