diff options
| author | DianQK <dianqk@dianqk.net> | 2024-02-20 21:55:46 +0800 |
|---|---|---|
| committer | DianQK <dianqk@dianqk.net> | 2024-04-08 19:00:57 +0800 |
| commit | e752af765ea04ba663d82524cfdcc2b7b6cb58aa (patch) | |
| tree | e73c40cde72bee2575b9e6f087fc815b1988a0b8 | |
| parent | 1f061f47e2903e90651f63368e3ff0aebac8e3e6 (diff) | |
| download | rust-e752af765ea04ba663d82524cfdcc2b7b6cb58aa.tar.gz rust-e752af765ea04ba663d82524cfdcc2b7b6cb58aa.zip | |
Transforms a match containing negative numbers into an assignment statement as well
4 files changed, 100 insertions, 59 deletions
diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index e766c1ae0f6..a444df34048 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,6 +1,7 @@ use rustc_index::IndexVec; use rustc_middle::mir::*; use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt}; +use rustc_target::abi::Size; use std::iter; use super::simplify::simplify_cfg; @@ -67,13 +68,13 @@ trait SimplifyMatch<'tcx> { _ => unreachable!(), }; - if !self.can_simplify(tcx, targets, param_env, bbs) { + let discr_ty = discr.ty(local_decls, tcx); + if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) { return false; } // Take ownership of items now that we know we can optimize. let discr = discr.clone(); - let discr_ty = discr.ty(local_decls, tcx); // Introduce a temporary for the discriminant value. let source_info = bbs[switch_bb_idx].terminator().source_info; @@ -104,6 +105,7 @@ trait SimplifyMatch<'tcx> { targets: &SwitchTargets, param_env: ParamEnv<'tcx>, bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, + discr_ty: Ty<'tcx>, ) -> bool; fn new_stmts( @@ -157,6 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { targets: &SwitchTargets, param_env: ParamEnv<'tcx>, bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, + _discr_ty: Ty<'tcx>, ) -> bool { if targets.iter().len() != 1 { return false; @@ -268,7 +271,7 @@ struct SimplifyToExp { enum CompareType<'tcx, 'a> { Same(&'a StatementKind<'tcx>), Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt), - Discr(&'a Place<'tcx>, Ty<'tcx>), + Discr(&'a Place<'tcx>, Ty<'tcx>, bool), } enum TransfromType { @@ -282,7 +285,7 @@ impl From<CompareType<'_, '_>> for TransfromType { match compare_type { CompareType::Same(_) => TransfromType::Same, CompareType::Eq(_, _, _) => TransfromType::Eq, - CompareType::Discr(_, _) => TransfromType::Discr, + CompareType::Discr(_, _, _) => TransfromType::Discr, } } } @@ -333,6 +336,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { targets: &SwitchTargets, param_env: ParamEnv<'tcx>, bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>, + discr_ty: Ty<'tcx>, ) -> bool { if targets.iter().len() < 2 || targets.iter().len() > 64 { return false; @@ -355,6 +359,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { return false; } + let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size; let first_stmts = &bbs[first_target].statements; let (second_val, second_target) = target_iter.next().unwrap(); let second_stmts = &bbs[second_target].statements; @@ -362,6 +367,11 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { return false; } + fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool { + l.try_to_int(l.size()).unwrap() + == ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap() + } + let mut compare_types = Vec::new(); for (f, s) in iter::zip(first_stmts, second_stmts) { let compare_type = match (&f.kind, &s.kind) { @@ -382,12 +392,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { ) { (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f), (Some(f), Some(s)) - if Some(f) == ScalarInt::try_from_uint(first_val, f.size()) - && Some(s) == ScalarInt::try_from_uint(second_val, s.size()) => + if ((f_c.const_.ty().is_signed() || discr_ty.is_signed()) + && int_equal(f, first_val, discr_size) + && int_equal(s, second_val, discr_size)) + || (Some(f) == ScalarInt::try_from_uint(first_val, f.size()) + && Some(s) + == ScalarInt::try_from_uint(second_val, s.size())) => { - CompareType::Discr(lhs_f, f_c.const_.ty()) + CompareType::Discr( + lhs_f, + f_c.const_.ty(), + f_c.const_.ty().is_signed() || discr_ty.is_signed(), + ) + } + _ => { + return false; } - _ => return false, } } @@ -413,15 +433,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { && s_c.const_.ty() == f_ty && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {} ( - CompareType::Discr(lhs_f, f_ty), + CompareType::Discr(lhs_f, f_ty, is_signed), StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => { let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else { return false; }; - if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) { - return false; + if is_signed + && s_c.const_.ty().is_signed() + && int_equal(f, other_val, discr_size) + { + continue; + } + if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) { + continue; } + return false; } _ => return false, } diff --git a/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff index 4b435310916..e1b537b1b71 100644 --- a/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff @@ -5,32 +5,37 @@ debug i => _1; let mut _0: i8; let mut _2: i16; ++ let mut _3: i16; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const -3_i8; - goto -> bb5; - } - - bb3: { - _0 = const -1_i8; - goto -> bb5; - } - - bb4: { - _0 = const 2_i8; - goto -> bb5; - } - - bb5: { +- switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const -3_i8; +- goto -> bb5; +- } +- +- bb3: { +- _0 = const -1_i8; +- goto -> bb5; +- } +- +- bb4: { +- _0 = const 2_i8; +- goto -> bb5; +- } +- +- bb5: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i8 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff index 8a390736add..cabc5a44cd8 100644 --- a/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff +++ b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff @@ -5,32 +5,37 @@ debug i => _1; let mut _0: i16; let mut _2: i8; ++ let mut _3: i8; bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1]; - } - - bb1: { - unreachable; - } - - bb2: { - _0 = const -3_i16; - goto -> bb5; - } - - bb3: { - _0 = const -1_i16; - goto -> bb5; - } - - bb4: { - _0 = const 2_i16; - goto -> bb5; - } - - bb5: { +- switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = const -3_i16; +- goto -> bb5; +- } +- +- bb3: { +- _0 = const -1_i16; +- goto -> bb5; +- } +- +- bb4: { +- _0 = const 2_i16; +- goto -> bb5; +- } +- +- bb5: { ++ StorageLive(_3); ++ _3 = move _2; ++ _0 = _3 as i16 (IntToInt); ++ StorageDead(_3); return; } } diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index d51dd7c5873..ca3e5f747d1 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -204,7 +204,9 @@ enum EnumAi8 { // EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff fn match_i8_i16(i: EnumAi8) -> i16 { // CHECK-LABEL: fn match_i8_i16( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as i16 (IntToInt); + // CHECH: return match i { EnumAi8::A => -1, EnumAi8::B => 2, @@ -233,7 +235,9 @@ enum EnumAi16 { // EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff fn match_i16_i8(i: EnumAi16) -> i8 { // CHECK-LABEL: fn match_i16_i8( - // CHECK: switchInt + // CHECK-NOT: switchInt + // CHECK: _0 = _3 as i8 (IntToInt); + // CHECH: return match i { EnumAi16::A => -1, EnumAi16::B => 2, |
