use crate::transform::{MirPass, MirSource}; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; pub struct MatchBranchSimplification; /// If a source block is found that switches between two blocks that are exactly /// the same modulo const bool assignments (e.g., one assigns true another false /// to the same place), merge a target block statements into the source block, /// using Eq / Ne comparison with switch value where const bools value differ. /// /// For example: /// /// ```rust /// bb0: { /// switchInt(move _3) -> [42_isize: bb1, otherwise: bb2]; /// } /// /// bb1: { /// _2 = const true; /// goto -> bb3; /// } /// /// bb2: { /// _2 = const false; /// goto -> bb3; /// } /// ``` /// /// into: /// /// ```rust /// bb0: { /// _2 = Eq(move _3, const 42_isize); /// goto -> bb3; /// } /// ``` impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, body: &mut Body<'tcx>) { let param_env = tcx.param_env(src.def_id()); let bbs = body.basic_blocks_mut(); 'outer: for bb_idx in bbs.indices() { let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind { TerminatorKind::SwitchInt { discr: Operand::Copy(ref place) | Operand::Move(ref place), switch_ty, ref targets, ref values, .. } if targets.len() == 2 && values.len() == 1 && targets[0] != targets[1] => { (place, values[0], switch_ty, targets[0], targets[1]) } // Only optimize switch int statements _ => continue, }; // Check that destinations are identical, and if not, then don't optimize this block if &bbs[first].terminator().kind != &bbs[second].terminator().kind { continue; } // Check that blocks are assignments of consts to the same place or same statement, // and match up 1-1, if not don't optimize this block. let first_stmts = &bbs[first].statements; let scnd_stmts = &bbs[second].statements; if first_stmts.len() != scnd_stmts.len() { continue; } for (f, s) in first_stmts.iter().zip(scnd_stmts.iter()) { match (&f.kind, &s.kind) { // If two statements are exactly the same, we can optimize. (f_s, s_s) if f_s == s_s => {} // If two statements are const bool assignments to the same place, we can optimize. ( StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), ) if lhs_f == lhs_s && f_c.literal.ty.is_bool() && s_c.literal.ty.is_bool() && f_c.literal.try_eval_bool(tcx, param_env).is_some() && s_c.literal.try_eval_bool(tcx, param_env).is_some() => {} // Otherwise we cannot optimize. Try another block. _ => continue 'outer, } } // Take ownership of items now that we know we can optimize. let discr = discr.clone(); // We already checked that first and second are different blocks, // and bb_idx has a different terminator from both of them. let (from, first, second) = bbs.pick3_mut(bb_idx, first, second); let new_stmts = first.statements.iter().zip(second.statements.iter()).map(|(f, s)| { match (&f.kind, &s.kind) { (f_s, s_s) if f_s == s_s => (*f).clone(), ( StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), ) => { // From earlier loop we know that we are dealing with bool constants only: let f_b = f_c.literal.try_eval_bool(tcx, param_env).unwrap(); let s_b = s_c.literal.try_eval_bool(tcx, param_env).unwrap(); if f_b == s_b { // Same value in both blocks. Use statement as is. (*f).clone() } else { // Different value between blocks. Make value conditional on switch condition. let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size; let const_cmp = Operand::const_from_scalar( tcx, switch_ty, crate::interpret::Scalar::from_uint(val, size), rustc_span::DUMMY_SP, ); let op = if f_b { BinOp::Eq } else { BinOp::Ne }; let rhs = Rvalue::BinaryOp(op, Operand::Copy(discr.clone()), const_cmp); Statement { source_info: f.source_info, kind: StatementKind::Assign(box (*lhs, rhs)), } } } _ => unreachable!(), } }); from.statements.extend(new_stmts); from.terminator_mut().kind = first.terminator().kind.clone(); } } }