diff options
| -rw-r--r-- | compiler/rustc_mir_transform/src/match_branches.rs | 338 |
1 files changed, 205 insertions, 133 deletions
diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index 6d4332793af..be1158683ac 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,11 +1,116 @@ +use rustc_index::IndexVec; use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::{ParamEnv, Ty, TyCtxt}; use std::iter; use super::simplify::simplify_cfg; pub struct MatchBranchSimplification; +impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 1 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let def_id = body.source.def_id(); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + + let bbs = body.basic_blocks.as_mut(); + let mut should_cleanup = false; + for bb_idx in bbs.indices() { + if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { + continue; + } + + match bbs[bb_idx].terminator().kind { + TerminatorKind::SwitchInt { + discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)), + ref targets, + .. + // We require that the possible target blocks don't contain this block. + } if !targets.all_targets().contains(&bb_idx) => {} + // Only optimize switch int statements + _ => continue, + }; + + if SimplifyToIf.simplify(tcx, &mut body.local_decls, bbs, bb_idx, param_env) { + should_cleanup = true; + continue; + } + } + + if should_cleanup { + simplify_cfg(body); + } + } +} + +trait SimplifyMatch<'tcx> { + fn simplify( + &self, + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + bbs: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>, + switch_bb_idx: BasicBlock, + param_env: ParamEnv<'tcx>, + ) -> bool { + let (discr, targets) = match bbs[switch_bb_idx].terminator().kind { + TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets), + _ => unreachable!(), + }; + + if !self.can_simplify(tcx, targets, param_env, bbs) { + return false; + } + + // Take ownership of items now that we know we can optimize. + let discr = discr.clone(); + let discr_ty = discr.ty(local_decls, tcx); + + // Introduce a temporary for the discriminant value. + let source_info = bbs[switch_bb_idx].terminator().source_info; + let discr_local = local_decls.push(LocalDecl::new(discr_ty, source_info.span)); + + // We already checked that first and second are different blocks, + // and bb_idx has a different terminator from both of them. + let new_stmts = self.new_stmts(tcx, targets, param_env, bbs, discr_local.clone(), discr_ty); + let (_, first) = targets.iter().next().unwrap(); + let (from, first) = bbs.pick2_mut(switch_bb_idx, first); + from.statements + .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); + from.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((Place::from(discr_local), Rvalue::Use(discr)))), + }); + from.statements.extend(new_stmts); + from.statements + .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); + from.terminator_mut().kind = first.terminator().kind.clone(); + true + } + + fn can_simplify( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, + ) -> bool; + + fn new_stmts( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) -> Vec<Statement<'tcx>>; +} + +struct SimplifyToIf; + /// If a source block is found that switches between two blocks that are exactly /// the same modulo const bool assignments (e.g., one assigns true another false /// to the same place), merge a target block statements into the source block, @@ -37,144 +142,111 @@ pub struct MatchBranchSimplification; /// goto -> bb3; /// } /// ``` +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { + fn can_simplify( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, + ) -> bool { + if targets.iter().len() != 1 { + return false; + } + // We require that the possible target blocks all be distinct. + let (_, first) = targets.iter().next().unwrap(); + let second = targets.otherwise(); + if first == second { + return false; + } + // Check that destinations are identical, and if not, then don't optimize this block + if bbs[first].terminator().kind != bbs[second].terminator().kind { + return false; + } -impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 1 - } - - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let def_id = body.source.def_id(); - let param_env = tcx.param_env_reveal_all_normalized(def_id); - - let bbs = body.basic_blocks.as_mut(); - let mut should_cleanup = false; - 'outer: for bb_idx in bbs.indices() { - if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {def_id:?} ")) { - continue; - } - - let (discr, val, first, second) = match bbs[bb_idx].terminator().kind { - TerminatorKind::SwitchInt { - discr: ref discr @ (Operand::Copy(_) | Operand::Move(_)), - ref targets, - .. - } if targets.iter().len() == 1 => { - let (value, target) = targets.iter().next().unwrap(); - // We require that this block and the two possible target blocks all be - // distinct. - if target == targets.otherwise() - || bb_idx == target - || bb_idx == targets.otherwise() - { - continue; - } - (discr, value, target, targets.otherwise()) - } - // Only optimize switch int statements - _ => continue, - }; - - // Check that destinations are identical, and if not, then don't optimize this block - if bbs[first].terminator().kind != bbs[second].terminator().kind { - continue; + // Check that blocks are assignments of consts to the same place or same statement, + // and match up 1-1, if not don't optimize this block. + let first_stmts = &bbs[first].statements; + let second_stmts = &bbs[second].statements; + if first_stmts.len() != second_stmts.len() { + return false; + } + for (f, s) in iter::zip(first_stmts, second_stmts) { + match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. + (f_s, s_s) if f_s == s_s => {} + + // If two statements are const bool assignments to the same place, we can optimize. + ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.const_.ty().is_bool() + && s_c.const_.ty().is_bool() + && f_c.const_.try_eval_bool(tcx, param_env).is_some() + && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} + + // Otherwise we cannot optimize. Try another block. + _ => return false, } + } + true + } - // Check that blocks are assignments of consts to the same place or same statement, - // and match up 1-1, if not don't optimize this block. - let first_stmts = &bbs[first].statements; - let scnd_stmts = &bbs[second].statements; - if first_stmts.len() != scnd_stmts.len() { - continue; - } - for (f, s) in iter::zip(first_stmts, scnd_stmts) { - match (&f.kind, &s.kind) { - // If two statements are exactly the same, we can optimize. - (f_s, s_s) if f_s == s_s => {} - - // If two statements are const bool assignments to the same place, we can optimize. - ( - StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if lhs_f == lhs_s - && f_c.const_.ty().is_bool() - && s_c.const_.ty().is_bool() - && f_c.const_.try_eval_bool(tcx, param_env).is_some() - && s_c.const_.try_eval_bool(tcx, param_env).is_some() => {} - - // Otherwise we cannot optimize. Try another block. - _ => continue 'outer, - } - } - // Take ownership of items now that we know we can optimize. - let discr = discr.clone(); - let discr_ty = discr.ty(&body.local_decls, tcx); - - // Introduce a temporary for the discriminant value. - let source_info = bbs[bb_idx].terminator().source_info; - let discr_local = body.local_decls.push(LocalDecl::new(discr_ty, source_info.span)); - - // We already checked that first and second are different blocks, - // and bb_idx has a different terminator from both of them. - let (from, first, second) = bbs.pick3_mut(bb_idx, first, second); - - let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { - match (&f.kind, &s.kind) { - (f_s, s_s) if f_s == s_s => (*f).clone(), - - ( - StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), - ) => { - // From earlier loop we know that we are dealing with bool constants only: - let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); - let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); - if f_b == s_b { - // Same value in both blocks. Use statement as is. - (*f).clone() - } else { - // Different value between blocks. Make value conditional on switch condition. - let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; - let const_cmp = Operand::const_from_scalar( - tcx, - discr_ty, - rustc_const_eval::interpret::Scalar::from_uint(val, size), - rustc_span::DUMMY_SP, - ); - let op = if f_b { BinOp::Eq } else { BinOp::Ne }; - let rhs = Rvalue::BinaryOp( - op, - Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), - ); - Statement { - source_info: f.source_info, - kind: StatementKind::Assign(Box::new((*lhs, rhs))), - } + fn new_stmts( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + param_env: ParamEnv<'tcx>, + bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) -> Vec<Statement<'tcx>> { + let (val, first) = targets.iter().next().unwrap(); + let second = targets.otherwise(); + // We already checked that first and second are different blocks, + // and bb_idx has a different terminator from both of them. + let first = &bbs[first]; + let second = &bbs[second]; + + let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { + match (&f.kind, &s.kind) { + (f_s, s_s) if f_s == s_s => (*f).clone(), + + ( + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), + ) => { + // From earlier loop we know that we are dealing with bool constants only: + let f_b = f_c.const_.try_eval_bool(tcx, param_env).unwrap(); + let s_b = s_c.const_.try_eval_bool(tcx, param_env).unwrap(); + if f_b == s_b { + // Same value in both blocks. Use statement as is. + (*f).clone() + } else { + // Different value between blocks. Make value conditional on switch condition. + let size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; + let const_cmp = Operand::const_from_scalar( + tcx, + discr_ty, + rustc_const_eval::interpret::Scalar::from_uint(val, size), + rustc_span::DUMMY_SP, + ); + let op = if f_b { BinOp::Eq } else { BinOp::Ne }; + let rhs = Rvalue::BinaryOp( + op, + Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), + ); + Statement { + source_info: f.source_info, + kind: StatementKind::Assign(Box::new((*lhs, rhs))), } } - - _ => unreachable!(), } - }); - - from.statements - .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); - from.statements.push(Statement { - source_info, - kind: StatementKind::Assign(Box::new(( - Place::from(discr_local), - Rvalue::Use(discr), - ))), - }); - from.statements.extend(new_stmts); - from.statements - .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); - from.terminator_mut().kind = first.terminator().kind.clone(); - should_cleanup = true; - } - if should_cleanup { - simplify_cfg(body); - } + _ => unreachable!(), + } + }); + new_stmts.collect() } } |
