about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDianQK <dianqk@dianqk.net>2024-02-20 21:55:46 +0800
committerDianQK <dianqk@dianqk.net>2024-04-08 19:00:57 +0800
commite752af765ea04ba663d82524cfdcc2b7b6cb58aa (patch)
treee73c40cde72bee2575b9e6f087fc815b1988a0b8
parent1f061f47e2903e90651f63368e3ff0aebac8e3e6 (diff)
downloadrust-e752af765ea04ba663d82524cfdcc2b7b6cb58aa.tar.gz
rust-e752af765ea04ba663d82524cfdcc2b7b6cb58aa.zip
Transforms a match containing negative numbers into an assignment statement as well
-rw-r--r--compiler/rustc_mir_transform/src/match_branches.rs49
-rw-r--r--tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff51
-rw-r--r--tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff51
-rw-r--r--tests/mir-opt/matches_reduce_branches.rs8
4 files changed, 100 insertions, 59 deletions
diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs
index e766c1ae0f6..a444df34048 100644
--- a/compiler/rustc_mir_transform/src/match_branches.rs
+++ b/compiler/rustc_mir_transform/src/match_branches.rs
@@ -1,6 +1,7 @@
 use rustc_index::IndexVec;
 use rustc_middle::mir::*;
 use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
+use rustc_target::abi::Size;
 use std::iter;
 
 use super::simplify::simplify_cfg;
@@ -67,13 +68,13 @@ trait SimplifyMatch<'tcx> {
             _ => unreachable!(),
         };
 
-        if !self.can_simplify(tcx, targets, param_env, bbs) {
+        let discr_ty = discr.ty(local_decls, tcx);
+        if !self.can_simplify(tcx, targets, param_env, bbs, discr_ty) {
             return false;
         }
 
         // Take ownership of items now that we know we can optimize.
         let discr = discr.clone();
-        let discr_ty = discr.ty(local_decls, tcx);
 
         // Introduce a temporary for the discriminant value.
         let source_info = bbs[switch_bb_idx].terminator().source_info;
@@ -104,6 +105,7 @@ trait SimplifyMatch<'tcx> {
         targets: &SwitchTargets,
         param_env: ParamEnv<'tcx>,
         bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+        discr_ty: Ty<'tcx>,
     ) -> bool;
 
     fn new_stmts(
@@ -157,6 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
         targets: &SwitchTargets,
         param_env: ParamEnv<'tcx>,
         bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+        _discr_ty: Ty<'tcx>,
     ) -> bool {
         if targets.iter().len() != 1 {
             return false;
@@ -268,7 +271,7 @@ struct SimplifyToExp {
 enum CompareType<'tcx, 'a> {
     Same(&'a StatementKind<'tcx>),
     Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
-    Discr(&'a Place<'tcx>, Ty<'tcx>),
+    Discr(&'a Place<'tcx>, Ty<'tcx>, bool),
 }
 
 enum TransfromType {
@@ -282,7 +285,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
         match compare_type {
             CompareType::Same(_) => TransfromType::Same,
             CompareType::Eq(_, _, _) => TransfromType::Eq,
-            CompareType::Discr(_, _) => TransfromType::Discr,
+            CompareType::Discr(_, _, _) => TransfromType::Discr,
         }
     }
 }
@@ -333,6 +336,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
         targets: &SwitchTargets,
         param_env: ParamEnv<'tcx>,
         bbs: &IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+        discr_ty: Ty<'tcx>,
     ) -> bool {
         if targets.iter().len() < 2 || targets.iter().len() > 64 {
             return false;
@@ -355,6 +359,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
             return false;
         }
 
+        let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
         let first_stmts = &bbs[first_target].statements;
         let (second_val, second_target) = target_iter.next().unwrap();
         let second_stmts = &bbs[second_target].statements;
@@ -362,6 +367,11 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
             return false;
         }
 
+        fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool {
+            l.try_to_int(l.size()).unwrap()
+                == ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap()
+        }
+
         let mut compare_types = Vec::new();
         for (f, s) in iter::zip(first_stmts, second_stmts) {
             let compare_type = match (&f.kind, &s.kind) {
@@ -382,12 +392,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
                     ) {
                         (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
                         (Some(f), Some(s))
-                            if Some(f) == ScalarInt::try_from_uint(first_val, f.size())
-                                && Some(s) == ScalarInt::try_from_uint(second_val, s.size()) =>
+                            if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
+                                && int_equal(f, first_val, discr_size)
+                                && int_equal(s, second_val, discr_size))
+                                || (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
+                                    && Some(s)
+                                        == ScalarInt::try_from_uint(second_val, s.size())) =>
                         {
-                            CompareType::Discr(lhs_f, f_c.const_.ty())
+                            CompareType::Discr(
+                                lhs_f,
+                                f_c.const_.ty(),
+                                f_c.const_.ty().is_signed() || discr_ty.is_signed(),
+                            )
+                        }
+                        _ => {
+                            return false;
                         }
-                        _ => return false,
                     }
                 }
 
@@ -413,15 +433,22 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
                         && s_c.const_.ty() == f_ty
                         && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
                     (
-                        CompareType::Discr(lhs_f, f_ty),
+                        CompareType::Discr(lhs_f, f_ty, is_signed),
                         StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
                     ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
                         let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
                             return false;
                         };
-                        if Some(f) != ScalarInt::try_from_uint(other_val, f.size()) {
-                            return false;
+                        if is_signed
+                            && s_c.const_.ty().is_signed()
+                            && int_equal(f, other_val, discr_size)
+                        {
+                            continue;
+                        }
+                        if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
+                            continue;
                         }
+                        return false;
                     }
                     _ => return false,
                 }
diff --git a/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff
index 4b435310916..e1b537b1b71 100644
--- a/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff
+++ b/tests/mir-opt/matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff
@@ -5,32 +5,37 @@
       debug i => _1;
       let mut _0: i8;
       let mut _2: i16;
++     let mut _3: i16;
   
       bb0: {
           _2 = discriminant(_1);
-          switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1];
-      }
-  
-      bb1: {
-          unreachable;
-      }
-  
-      bb2: {
-          _0 = const -3_i8;
-          goto -> bb5;
-      }
-  
-      bb3: {
-          _0 = const -1_i8;
-          goto -> bb5;
-      }
-  
-      bb4: {
-          _0 = const 2_i8;
-          goto -> bb5;
-      }
-  
-      bb5: {
+-         switchInt(move _2) -> [65535: bb3, 2: bb4, 65533: bb2, otherwise: bb1];
+-     }
+- 
+-     bb1: {
+-         unreachable;
+-     }
+- 
+-     bb2: {
+-         _0 = const -3_i8;
+-         goto -> bb5;
+-     }
+- 
+-     bb3: {
+-         _0 = const -1_i8;
+-         goto -> bb5;
+-     }
+- 
+-     bb4: {
+-         _0 = const 2_i8;
+-         goto -> bb5;
+-     }
+- 
+-     bb5: {
++         StorageLive(_3);
++         _3 = move _2;
++         _0 = _3 as i8 (IntToInt);
++         StorageDead(_3);
           return;
       }
   }
diff --git a/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
index 8a390736add..cabc5a44cd8 100644
--- a/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
+++ b/tests/mir-opt/matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
@@ -5,32 +5,37 @@
       debug i => _1;
       let mut _0: i16;
       let mut _2: i8;
++     let mut _3: i8;
   
       bb0: {
           _2 = discriminant(_1);
-          switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1];
-      }
-  
-      bb1: {
-          unreachable;
-      }
-  
-      bb2: {
-          _0 = const -3_i16;
-          goto -> bb5;
-      }
-  
-      bb3: {
-          _0 = const -1_i16;
-          goto -> bb5;
-      }
-  
-      bb4: {
-          _0 = const 2_i16;
-          goto -> bb5;
-      }
-  
-      bb5: {
+-         switchInt(move _2) -> [255: bb3, 2: bb4, 253: bb2, otherwise: bb1];
+-     }
+- 
+-     bb1: {
+-         unreachable;
+-     }
+- 
+-     bb2: {
+-         _0 = const -3_i16;
+-         goto -> bb5;
+-     }
+- 
+-     bb3: {
+-         _0 = const -1_i16;
+-         goto -> bb5;
+-     }
+- 
+-     bb4: {
+-         _0 = const 2_i16;
+-         goto -> bb5;
+-     }
+- 
+-     bb5: {
++         StorageLive(_3);
++         _3 = move _2;
++         _0 = _3 as i16 (IntToInt);
++         StorageDead(_3);
           return;
       }
   }
diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs
index d51dd7c5873..ca3e5f747d1 100644
--- a/tests/mir-opt/matches_reduce_branches.rs
+++ b/tests/mir-opt/matches_reduce_branches.rs
@@ -204,7 +204,9 @@ enum EnumAi8 {
 // EMIT_MIR matches_reduce_branches.match_i8_i16.MatchBranchSimplification.diff
 fn match_i8_i16(i: EnumAi8) -> i16 {
     // CHECK-LABEL: fn match_i8_i16(
-    // CHECK: switchInt
+    // CHECK-NOT: switchInt
+    // CHECK: _0 = _3 as i16 (IntToInt);
+    // CHECH: return
     match i {
         EnumAi8::A => -1,
         EnumAi8::B => 2,
@@ -233,7 +235,9 @@ enum EnumAi16 {
 // EMIT_MIR matches_reduce_branches.match_i16_i8.MatchBranchSimplification.diff
 fn match_i16_i8(i: EnumAi16) -> i8 {
     // CHECK-LABEL: fn match_i16_i8(
-    // CHECK: switchInt
+    // CHECK-NOT: switchInt
+    // CHECK: _0 = _3 as i8 (IntToInt);
+    // CHECH: return
     match i {
         EnumAi16::A => -1,
         EnumAi16::B => 2,