diff options
| author | bors <bors@rust-lang.org> | 2020-09-20 17:54:44 +0000 |
|---|---|---|
| committer | bors <bors@rust-lang.org> | 2020-09-20 17:54:44 +0000 |
| commit | 2e0edc0f28c5647141bedba02e7a222d3a5dc9c3 (patch) | |
| tree | 8c6c0ef494292f83366eccc7cb4b639b73013112 /compiler | |
| parent | 81e02708f1f4760244756548981277d5199baa9a (diff) | |
| parent | 0363694c7ff72d0a4b1c52ebf2320930c3b60da8 (diff) | |
| download | rust-2e0edc0f28c5647141bedba02e7a222d3a5dc9c3.tar.gz rust-2e0edc0f28c5647141bedba02e7a222d3a5dc9c3.zip | |
Auto merge of #75119 - simonvandel:early-otherwise, r=oli-obk
New MIR optimization pass to reduce branches on match of tuples of enums
Fixes #68867 by adding a new pass that turns something like
```rust
let x: Option<()>;
let y: Option<()>;
match (x,y) {
(Some(_), Some(_)) => {0},
_ => {1}
}
```
into something like
```rust
let x: Option<()>;
let y: Option<()>;
let discriminant_x = // get discriminant of x
let discriminant_y = // get discriminant of x
if discriminant_x != discriminant_y {1} else {0}
```
The opt-diffs still have the old basic blocks like
```
bb3: {
_8 = discriminant((*(_4.1: &ViewportPercentageLength))); // scope 0 at $DIR/early-otherwise-branch-68867.rs:21:21: 21:30
switchInt(move _8) -> [1_isize: bb7, otherwise: bb2]; // scope 0 at $DIR/early-otherwise-branch-68867.rs:21:21: 21:30
}
bb4: {
_9 = discriminant((*(_4.1: &ViewportPercentageLength))); // scope 0 at $DIR/early-otherwise-branch-68867.rs:22:23: 22:34
switchInt(move _9) -> [2_isize: bb8, otherwise: bb2]; // scope 0 at $DIR/early-otherwise-branch-68867.rs:22:23: 22:34
}
bb5: {
_10 = discriminant((*(_4.1: &ViewportPercentageLength))); // scope 0 at $DIR/early-otherwise-branch-68867.rs:23:23: 23:34
switchInt(move _10) -> [3_isize: bb9, otherwise: bb2]; // scope 0 at $DIR/early-otherwise-branch-68867.rs:23:23: 23:34
}
```
These do get removed on later passes. I'm not sure if I should include those passes in the test to make it clear?
Diffstat (limited to 'compiler')
| -rw-r--r-- | compiler/rustc_mir/src/transform/early_otherwise_branch.rs | 339 | ||||
| -rw-r--r-- | compiler/rustc_mir/src/transform/mod.rs | 2 |
2 files changed, 341 insertions, 0 deletions
diff --git a/compiler/rustc_mir/src/transform/early_otherwise_branch.rs b/compiler/rustc_mir/src/transform/early_otherwise_branch.rs new file mode 100644 index 00000000000..67e679a8b08 --- /dev/null +++ b/compiler/rustc_mir/src/transform/early_otherwise_branch.rs @@ -0,0 +1,339 @@ +use crate::{ + transform::{MirPass, MirSource}, + util::patch::MirPatch, +}; +use rustc_middle::mir::*; +use rustc_middle::ty::{Ty, TyCtxt}; +use std::{borrow::Cow, fmt::Debug}; + +use super::simplify::simplify_cfg; + +/// This pass optimizes something like +/// ```text +/// let x: Option<()>; +/// let y: Option<()>; +/// match (x,y) { +/// (Some(_), Some(_)) => {0}, +/// _ => {1} +/// } +/// ``` +/// into something like +/// ```text +/// let x: Option<()>; +/// let y: Option<()>; +/// let discriminant_x = // get discriminant of x +/// let discriminant_y = // get discriminant of y +/// if discriminant_x != discriminant_y || discriminant_x == None {1} else {0} +/// ``` +pub struct EarlyOtherwiseBranch; + +impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { + fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) { + if tcx.sess.opts.debugging_opts.mir_opt_level < 1 { + return; + } + trace!("running EarlyOtherwiseBranch on {:?}", source); + // we are only interested in this bb if the terminator is a switchInt + let bbs_with_switch = + body.basic_blocks().iter_enumerated().filter(|(_, bb)| is_switch(bb.terminator())); + + let opts_to_apply: Vec<OptimizationToApply<'tcx>> = bbs_with_switch + .flat_map(|(bb_idx, bb)| { + let switch = bb.terminator(); + let helper = Helper { body, tcx }; + let infos = helper.go(bb, switch)?; + Some(OptimizationToApply { infos, basic_block_first_switch: bb_idx }) + }) + .collect(); + + let should_cleanup = !opts_to_apply.is_empty(); + + for opt_to_apply in opts_to_apply { + trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_to_apply); + + let statements_before = + body.basic_blocks()[opt_to_apply.basic_block_first_switch].statements.len(); + let end_of_block_location = Location { + block: opt_to_apply.basic_block_first_switch, + statement_index: statements_before, + }; + + let mut patch = MirPatch::new(body); + + // create temp to store second discriminant in + let discr_type = opt_to_apply.infos[0].second_switch_info.discr_ty; + let discr_span = opt_to_apply.infos[0].second_switch_info.discr_source_info.span; + let second_discriminant_temp = patch.new_temp(discr_type, discr_span); + + patch.add_statement( + end_of_block_location, + StatementKind::StorageLive(second_discriminant_temp), + ); + + // create assignment of discriminant + let place_of_adt_to_get_discriminant_of = + opt_to_apply.infos[0].second_switch_info.place_of_adt_discr_read; + patch.add_assign( + end_of_block_location, + Place::from(second_discriminant_temp), + Rvalue::Discriminant(place_of_adt_to_get_discriminant_of), + ); + + // create temp to store NotEqual comparison between the two discriminants + let not_equal = BinOp::Ne; + let not_equal_res_type = not_equal.ty(tcx, discr_type, discr_type); + let not_equal_temp = patch.new_temp(not_equal_res_type, discr_span); + patch.add_statement(end_of_block_location, StatementKind::StorageLive(not_equal_temp)); + + // create NotEqual comparison between the two discriminants + let first_descriminant_place = + opt_to_apply.infos[0].first_switch_info.discr_used_in_switch; + let not_equal_rvalue = Rvalue::BinaryOp( + not_equal, + Operand::Copy(Place::from(second_discriminant_temp)), + Operand::Copy(Place::from(first_descriminant_place)), + ); + patch.add_statement( + end_of_block_location, + StatementKind::Assign(box (Place::from(not_equal_temp), not_equal_rvalue)), + ); + + let (mut targets_to_jump_to, values_to_jump_to): (Vec<_>, Vec<_>) = opt_to_apply + .infos + .iter() + .flat_map(|x| x.second_switch_info.targets_with_values.iter()) + .cloned() + .unzip(); + + // add otherwise case in the end + targets_to_jump_to.push(opt_to_apply.infos[0].first_switch_info.otherwise_bb); + // new block that jumps to the correct discriminant case. This block is switched to if the discriminants are equal + let new_switch_data = BasicBlockData::new(Some(Terminator { + source_info: opt_to_apply.infos[0].second_switch_info.discr_source_info, + kind: TerminatorKind::SwitchInt { + // the first and second discriminants are equal, so just pick one + discr: Operand::Copy(first_descriminant_place), + switch_ty: discr_type, + values: Cow::from(values_to_jump_to), + targets: targets_to_jump_to, + }, + })); + + let new_switch_bb = patch.new_block(new_switch_data); + + // switch on the NotEqual. If true, then jump to the `otherwise` case. + // If false, then jump to a basic block that then jumps to the correct disciminant case + let true_case = opt_to_apply.infos[0].first_switch_info.otherwise_bb; + let false_case = new_switch_bb; + patch.patch_terminator( + opt_to_apply.basic_block_first_switch, + TerminatorKind::if_( + tcx, + Operand::Move(Place::from(not_equal_temp)), + true_case, + false_case, + ), + ); + + // generate StorageDead for the second_discriminant_temp not in use anymore + patch.add_statement( + end_of_block_location, + StatementKind::StorageDead(second_discriminant_temp), + ); + + // Generate a StorageDead for not_equal_temp in each of the targets, since we moved it into the switch + for bb in [false_case, true_case].iter() { + patch.add_statement( + Location { block: *bb, statement_index: 0 }, + StatementKind::StorageDead(not_equal_temp), + ); + } + + patch.apply(body); + } + + // Since this optimization adds new basic blocks and invalidates others, + // clean up the cfg to make it nicer for other passes + if should_cleanup { + simplify_cfg(body); + } + } +} + +fn is_switch<'tcx>(terminator: &Terminator<'tcx>) -> bool { + match terminator.kind { + TerminatorKind::SwitchInt { .. } => true, + _ => false, + } +} + +struct Helper<'a, 'tcx> { + body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, +} + +#[derive(Debug, Clone)] +struct SwitchDiscriminantInfo<'tcx> { + /// Type of the discriminant being switched on + discr_ty: Ty<'tcx>, + /// The basic block that the otherwise branch points to + otherwise_bb: BasicBlock, + /// Target along with the value being branched from. Otherwise is not included + targets_with_values: Vec<(BasicBlock, u128)>, + discr_source_info: SourceInfo, + /// The place of the discriminant used in the switch + discr_used_in_switch: Place<'tcx>, + /// The place of the adt that has its discriminant read + place_of_adt_discr_read: Place<'tcx>, + /// The type of the adt that has its discriminant read + type_adt_matched_on: Ty<'tcx>, +} + +#[derive(Debug)] +struct OptimizationToApply<'tcx> { + infos: Vec<OptimizationInfo<'tcx>>, + /// Basic block of the original first switch + basic_block_first_switch: BasicBlock, +} + +#[derive(Debug)] +struct OptimizationInfo<'tcx> { + /// Info about the first switch and discriminant + first_switch_info: SwitchDiscriminantInfo<'tcx>, + /// Info about the second switch and discriminant + second_switch_info: SwitchDiscriminantInfo<'tcx>, +} + +impl<'a, 'tcx> Helper<'a, 'tcx> { + pub fn go( + &self, + bb: &BasicBlockData<'tcx>, + switch: &Terminator<'tcx>, + ) -> Option<Vec<OptimizationInfo<'tcx>>> { + // try to find the statement that defines the discriminant that is used for the switch + let discr = self.find_switch_discriminant_info(bb, switch)?; + + // go through each target, finding a discriminant read, and a switch + let results = discr.targets_with_values.iter().map(|(target, value)| { + self.find_discriminant_switch_pairing(&discr, target.clone(), value.clone()) + }); + + // if the optimization did not apply for one of the targets, then abort + if results.clone().any(|x| x.is_none()) || results.len() == 0 { + trace!("NO: not all of the targets matched the pattern for optimization"); + return None; + } + + Some(results.flatten().collect()) + } + + fn find_discriminant_switch_pairing( + &self, + discr_info: &SwitchDiscriminantInfo<'tcx>, + target: BasicBlock, + value: u128, + ) -> Option<OptimizationInfo<'tcx>> { + let bb = &self.body.basic_blocks()[target]; + // find switch + let terminator = bb.terminator(); + if is_switch(terminator) { + let this_bb_discr_info = self.find_switch_discriminant_info(bb, terminator)?; + + // the types of the two adts matched on have to be equalfor this optimization to apply + if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on { + trace!( + "NO: types do not match. LHS: {:?}, RHS: {:?}", + discr_info.type_adt_matched_on, + this_bb_discr_info.type_adt_matched_on + ); + return None; + } + + // the otherwise branch of the two switches have to point to the same bb + if discr_info.otherwise_bb != this_bb_discr_info.otherwise_bb { + trace!("NO: otherwise target is not the same"); + return None; + } + + // check that the value being matched on is the same. The + if this_bb_discr_info.targets_with_values.iter().find(|x| x.1 == value).is_none() { + trace!("NO: values being matched on are not the same"); + return None; + } + + // only allow optimization if the left and right of the tuple being matched are the same variants. + // so the following should not optimize + // ```rust + // let x: Option<()>; + // let y: Option<()>; + // match (x,y) { + // (Some(_), None) => {}, + // _ => {} + // } + // ``` + // We check this by seeing that the value of the first discriminant is the only other discriminant value being used as a target in the second switch + if !(this_bb_discr_info.targets_with_values.len() == 1 + && this_bb_discr_info.targets_with_values[0].1 == value) + { + trace!( + "NO: The second switch did not have only 1 target (besides otherwise) that had the same value as the value from the first switch that got us here" + ); + return None; + } + + // if we reach this point, the optimization applies, and we should be able to optimize this case + // store the info that is needed to apply the optimization + + Some(OptimizationInfo { + first_switch_info: discr_info.clone(), + second_switch_info: this_bb_discr_info, + }) + } else { + None + } + } + + fn find_switch_discriminant_info( + &self, + bb: &BasicBlockData<'tcx>, + switch: &Terminator<'tcx>, + ) -> Option<SwitchDiscriminantInfo<'tcx>> { + match &switch.kind { + TerminatorKind::SwitchInt { discr, targets, values, .. } => { + let discr_local = discr.place()?.as_local()?; + // the declaration of the discriminant read. Place of this read is being used in the switch + let discr_decl = &self.body.local_decls()[discr_local]; + let discr_ty = discr_decl.ty; + // the otherwise target lies as the last element + let otherwise_bb = targets.get(values.len())?.clone(); + let targets_with_values = targets + .iter() + .zip(values.iter()) + .map(|(t, v)| (t.clone(), v.clone())) + .collect(); + + // find the place of the adt where the discriminant is being read from + // assume this is the last statement of the block + let place_of_adt_discr_read = match bb.statements.last()?.kind { + StatementKind::Assign(box (_, Rvalue::Discriminant(adt_place))) => { + Some(adt_place) + } + _ => None, + }?; + + let type_adt_matched_on = place_of_adt_discr_read.ty(self.body, self.tcx).ty; + + Some(SwitchDiscriminantInfo { + discr_used_in_switch: discr.place()?, + discr_ty, + otherwise_bb, + targets_with_values, + discr_source_info: discr_decl.source_info, + place_of_adt_discr_read, + type_adt_matched_on, + }) + } + _ => unreachable!("must only be passed terminator that is a switch"), + } + } +} diff --git a/compiler/rustc_mir/src/transform/mod.rs b/compiler/rustc_mir/src/transform/mod.rs index fc9854ba499..abe2dc496a6 100644 --- a/compiler/rustc_mir/src/transform/mod.rs +++ b/compiler/rustc_mir/src/transform/mod.rs @@ -26,6 +26,7 @@ pub mod copy_prop; pub mod deaggregator; pub mod dest_prop; pub mod dump_mir; +pub mod early_otherwise_branch; pub mod elaborate_drops; pub mod generator; pub mod inline; @@ -465,6 +466,7 @@ fn run_optimization_passes<'tcx>( &instcombine::InstCombine, &const_prop::ConstProp, &simplify_branches::SimplifyBranches::new("after-const-prop"), + &early_otherwise_branch::EarlyOtherwiseBranch, &simplify_comparison_integral::SimplifyComparisonIntegral, &simplify_try::SimplifyArmIdentity, &simplify_try::SimplifyBranchSame, |
