about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/librustc_infer/infer/error_reporting/mod.rs2
-rw-r--r--src/librustc_mir/transform/match_branches.rs93
-rw-r--r--src/librustc_mir/transform/mod.rs2
3 files changed, 40 insertions, 57 deletions
diff --git a/src/librustc_infer/infer/error_reporting/mod.rs b/src/librustc_infer/infer/error_reporting/mod.rs
index 26220698190..063246f79fe 100644
--- a/src/librustc_infer/infer/error_reporting/mod.rs
+++ b/src/librustc_infer/infer/error_reporting/mod.rs
@@ -827,7 +827,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
                 ty::GenericParamDefKind::Type { has_default, .. } => {
                     Some((param.def_id, has_default))
                 }
-                ty::GenericParamDefKind::Const { .. } => None, // FIXME(const_generics:defaults)
+                ty::GenericParamDefKind::Const => None, // FIXME(const_generics:defaults)
             })
             .peekable();
         let has_default = {
diff --git a/src/librustc_mir/transform/match_branches.rs b/src/librustc_mir/transform/match_branches.rs
index 5dc84955add..5fab46f029f 100644
--- a/src/librustc_mir/transform/match_branches.rs
+++ b/src/librustc_mir/transform/match_branches.rs
@@ -1,4 +1,4 @@
-use crate::transform::{simplify, MirPass, MirSource};
+use crate::transform::{MirPass, MirSource};
 use rustc_middle::mir::*;
 use rustc_middle::ty::TyCtxt;
 
@@ -12,10 +12,9 @@ pub struct MatchBranchSimplification;
 impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, body: &mut Body<'tcx>) {
         let param_env = tcx.param_env(src.def_id());
-        let mut did_remove_blocks = false;
         let bbs = body.basic_blocks_mut();
         'outer: for bb_idx in bbs.indices() {
-            let (discr, val, switch_ty, targets) = match bbs[bb_idx].terminator().kind {
+            let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind {
                 TerminatorKind::SwitchInt {
                     discr: Operand::Move(ref place),
                     switch_ty,
@@ -23,60 +22,53 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
                     ref values,
                     ..
                 } if targets.len() == 2 && values.len() == 1 => {
-                    (place.clone(), values[0], switch_ty, targets)
+                    (place, values[0], switch_ty, targets[0], targets[1])
                 }
+                // Only optimize switch int statements
                 _ => continue,
             };
-            let (first, rest) = if let ([first], rest) = targets.split_at(1) {
-                (*first, rest)
-            } else {
-                unreachable!();
-            };
-            let first_dest = bbs[first].terminator().kind.clone();
-            let same_destinations = rest
-                .iter()
-                .map(|target| &bbs[*target].terminator().kind)
-                .all(|t_kind| t_kind == &first_dest);
-            if !same_destinations {
+
+            // Check that destinations are identical, and if not, then don't optimize this block
+            if &bbs[first].terminator().kind != &bbs[second].terminator().kind {
                 continue;
             }
+
+            // Check that blocks are assignments of consts to the same place or same statement,
+            // and match up 1-1, if not don't optimize this block.
             let first_stmts = &bbs[first].statements;
-            for s in first_stmts.iter() {
-                match &s.kind {
-                    StatementKind::Assign(box (_, rhs)) => {
-                        if let Rvalue::Use(Operand::Constant(_)) = rhs {
-                        } else {
-                            continue 'outer;
-                        }
-                    }
-                    _ => continue 'outer,
-                }
+            let scnd_stmts = &bbs[second].statements;
+            if first_stmts.len() != scnd_stmts.len() {
+                continue;
             }
-            for target in rest.iter() {
-                for s in bbs[*target].statements.iter() {
-                    if let StatementKind::Assign(box (ref lhs, rhs)) = &s.kind {
-                        if let Rvalue::Use(Operand::Constant(_)) = rhs {
-                            let has_matching_assn = first_stmts
-                                .iter()
-                                .find(|s| {
-                                    if let StatementKind::Assign(box (lhs_f, _)) = &s.kind {
-                                        lhs_f == lhs
-                                    } else {
-                                        false
-                                    }
-                                })
-                                .is_some();
-                            if has_matching_assn {
-                                continue;
+            for (f, s) in first_stmts.iter().zip(scnd_stmts.iter()) {
+                match (&f.kind, &s.kind) {
+                    // If two statements are exactly the same just ignore them.
+                    (f_s, s_s) if f_s == s_s => (),
+
+                    (
+                        StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
+                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
+                    ) if lhs_f == lhs_s => {
+                        if let Some(f_c) = f_c.literal.try_eval_bool(tcx, param_env) {
+                            // This should also be a bool because it's writing to the same place
+                            let s_c = s_c.literal.try_eval_bool(tcx, param_env).unwrap();
+                            // Check that only const assignments of opposite bool values are
+                            // permitted.
+                            if f_c != s_c {
+                              continue
                             }
                         }
+                        continue 'outer;
                     }
-
-                    continue 'outer;
+                    // If there are not exclusively assignments, then ignore this
+                    _ => continue 'outer,
                 }
             }
-            let (first_block, to_add) = bbs.pick2_mut(first, bb_idx);
-            let new_stmts = first_block.statements.iter().cloned().map(|mut s| {
+            // Take owenership of items now that we know we can optimize.
+            let discr = discr.clone();
+
+            bbs[bb_idx].terminator_mut().kind = TerminatorKind::Goto { target: first };
+            for s in bbs[first].statements.iter_mut() {
                 if let StatementKind::Assign(box (_, ref mut rhs)) = s.kind {
                     let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size;
                     let const_cmp = Operand::const_from_scalar(
@@ -86,17 +78,8 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
                         rustc_span::DUMMY_SP,
                     );
                     *rhs = Rvalue::BinaryOp(BinOp::Eq, Operand::Move(discr), const_cmp);
-                } else {
-                    unreachable!()
                 }
-                s
-            });
-            to_add.statements.extend(new_stmts);
-            to_add.terminator_mut().kind = first_dest;
-            did_remove_blocks = true;
-        }
-        if did_remove_blocks {
-            simplify::remove_dead_blocks(body);
+            }
         }
     }
 }
diff --git a/src/librustc_mir/transform/mod.rs b/src/librustc_mir/transform/mod.rs
index b84514d1e94..4f26f3bb459 100644
--- a/src/librustc_mir/transform/mod.rs
+++ b/src/librustc_mir/transform/mod.rs
@@ -441,6 +441,7 @@ fn run_optimization_passes<'tcx>(
         // with async primitives.
         &generator::StateTransform,
         &instcombine::InstCombine,
+        &match_branches::MatchBranchSimplification,
         &const_prop::ConstProp,
         &simplify_branches::SimplifyBranches::new("after-const-prop"),
         &simplify_try::SimplifyArmIdentity,
@@ -452,7 +453,6 @@ fn run_optimization_passes<'tcx>(
         &simplify::SimplifyCfg::new("final"),
         &nrvo::RenameReturnPlace,
         &simplify::SimplifyLocals,
-        &match_branches::MatchBranchSimplification,
     ];
 
     let no_optimizations: &[&dyn MirPass<'tcx>] = &[