about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_mir_transform/Cargo.toml1
-rw-r--r--compiler/rustc_mir_transform/src/match_branches.rs141
2 files changed, 80 insertions, 62 deletions
diff --git a/compiler/rustc_mir_transform/Cargo.toml b/compiler/rustc_mir_transform/Cargo.toml
index f864a13a31b..07ca51a67ae 100644
--- a/compiler/rustc_mir_transform/Cargo.toml
+++ b/compiler/rustc_mir_transform/Cargo.toml
@@ -25,6 +25,7 @@ rustc_session = { path = "../rustc_session" }
 rustc_span = { path = "../rustc_span" }
 rustc_target = { path = "../rustc_target" }
 rustc_trait_selection = { path = "../rustc_trait_selection" }
+rustc_type_ir = { path = "../rustc_type_ir" }
 smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
 tracing = "0.1"
 # tidy-alphabetical-end
diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs
index df4f3ccb9b5..68000fe0ef8 100644
--- a/compiler/rustc_mir_transform/src/match_branches.rs
+++ b/compiler/rustc_mir_transform/src/match_branches.rs
@@ -3,8 +3,10 @@ use std::iter;
 use rustc_index::IndexSlice;
 use rustc_middle::mir::patch::MirPatch;
 use rustc_middle::mir::*;
+use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
 use rustc_middle::ty::{ParamEnv, ScalarInt, Ty, TyCtxt};
-use rustc_target::abi::Size;
+use rustc_target::abi::Integer;
+use rustc_type_ir::TyKind::*;
 
 use super::simplify::simplify_cfg;
 
@@ -264,33 +266,56 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
     }
 }
 
+/// Check if the cast constant using `IntToInt` is equal to the target constant.
+fn can_cast(
+    tcx: TyCtxt<'_>,
+    src_val: impl Into<u128>,
+    src_layout: TyAndLayout<'_>,
+    cast_ty: Ty<'_>,
+    target_scalar: ScalarInt,
+) -> bool {
+    let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
+    let v = match src_layout.ty.kind() {
+        Uint(_) => from_scalar.to_uint(src_layout.size),
+        Int(_) => from_scalar.to_int(src_layout.size) as u128,
+        _ => unreachable!("invalid int"),
+    };
+    let size = match *cast_ty.kind() {
+        Int(t) => Integer::from_int_ty(&tcx, t).size(),
+        Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
+        _ => unreachable!("invalid int"),
+    };
+    let v = size.truncate(v);
+    let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap();
+    cast_scalar == target_scalar
+}
+
 #[derive(Default)]
 struct SimplifyToExp {
-    transfrom_types: Vec<TransfromType>,
+    transfrom_kinds: Vec<TransfromKind>,
 }
 
 #[derive(Clone, Copy)]
-enum CompareType<'tcx, 'a> {
+enum ExpectedTransformKind<'tcx, 'a> {
     /// Identical statements.
     Same(&'a StatementKind<'tcx>),
     /// Assignment statements have the same value.
-    Eq(&'a Place<'tcx>, Ty<'tcx>, ScalarInt),
+    SameByEq { place: &'a Place<'tcx>, ty: Ty<'tcx>, scalar: ScalarInt },
     /// Enum variant comparison type.
-    Discr { place: &'a Place<'tcx>, ty: Ty<'tcx>, is_signed: bool },
+    Cast { place: &'a Place<'tcx>, ty: Ty<'tcx> },
 }
 
-enum TransfromType {
+enum TransfromKind {
     Same,
-    Eq,
-    Discr,
+    Cast,
 }
 
-impl From<CompareType<'_, '_>> for TransfromType {
-    fn from(compare_type: CompareType<'_, '_>) -> Self {
+impl From<ExpectedTransformKind<'_, '_>> for TransfromKind {
+    fn from(compare_type: ExpectedTransformKind<'_, '_>) -> Self {
         match compare_type {
-            CompareType::Same(_) => TransfromType::Same,
-            CompareType::Eq(_, _, _) => TransfromType::Eq,
-            CompareType::Discr { .. } => TransfromType::Discr,
+            ExpectedTransformKind::Same(_) => TransfromKind::Same,
+            ExpectedTransformKind::SameByEq { .. } => TransfromKind::Same,
+            ExpectedTransformKind::Cast { .. } => TransfromKind::Cast,
         }
     }
 }
@@ -354,7 +379,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
             return None;
         }
         let mut target_iter = targets.iter();
-        let (first_val, first_target) = target_iter.next().unwrap();
+        let (first_case_val, first_target) = target_iter.next().unwrap();
         let first_terminator_kind = &bbs[first_target].terminator().kind;
         // Check that destinations are identical, and if not, then don't optimize this block
         if !targets
@@ -364,24 +389,20 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
             return None;
         }
 
-        let discr_size = tcx.layout_of(param_env.and(discr_ty)).unwrap().size;
+        let discr_layout = tcx.layout_of(param_env.and(discr_ty)).unwrap();
         let first_stmts = &bbs[first_target].statements;
-        let (second_val, second_target) = target_iter.next().unwrap();
+        let (second_case_val, second_target) = target_iter.next().unwrap();
         let second_stmts = &bbs[second_target].statements;
         if first_stmts.len() != second_stmts.len() {
             return None;
         }
 
-        fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool {
-            l.to_bits_unchecked() == ScalarInt::try_from_uint(r, size).unwrap().to_bits_unchecked()
-        }
-
         // We first compare the two branches, and then the other branches need to fulfill the same conditions.
-        let mut compare_types = Vec::new();
+        let mut expected_transform_kinds = Vec::new();
         for (f, s) in iter::zip(first_stmts, second_stmts) {
             let compare_type = match (&f.kind, &s.kind) {
                 // If two statements are exactly the same, we can optimize.
-                (f_s, s_s) if f_s == s_s => CompareType::Same(f_s),
+                (f_s, s_s) if f_s == s_s => ExpectedTransformKind::Same(f_s),
 
                 // If two statements are assignments with the match values to the same place, we can optimize.
                 (
@@ -395,22 +416,29 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
                         f_c.const_.try_eval_scalar_int(tcx, param_env),
                         s_c.const_.try_eval_scalar_int(tcx, param_env),
                     ) {
-                        (Some(f), Some(s)) if f == s => CompareType::Eq(lhs_f, f_c.const_.ty(), f),
-                        // Enum variants can also be simplified to an assignment statement if their values are equal.
-                        // We need to consider both unsigned and signed scenarios here.
+                        (Some(f), Some(s)) if f == s => ExpectedTransformKind::SameByEq {
+                            place: lhs_f,
+                            ty: f_c.const_.ty(),
+                            scalar: f,
+                        },
+                        // Enum variants can also be simplified to an assignment statement,
+                        // if we can use `IntToInt` cast to get an equal value.
                         (Some(f), Some(s))
-                            if ((f_c.const_.ty().is_signed() || discr_ty.is_signed())
-                                && int_equal(f, first_val, discr_size)
-                                && int_equal(s, second_val, discr_size))
-                                || (Some(f) == ScalarInt::try_from_uint(first_val, f.size())
-                                    && Some(s)
-                                        == ScalarInt::try_from_uint(second_val, s.size())) =>
+                            if (can_cast(
+                                tcx,
+                                first_case_val,
+                                discr_layout,
+                                f_c.const_.ty(),
+                                f,
+                            ) && can_cast(
+                                tcx,
+                                second_case_val,
+                                discr_layout,
+                                f_c.const_.ty(),
+                                s,
+                            )) =>
                         {
-                            CompareType::Discr {
-                                place: lhs_f,
-                                ty: f_c.const_.ty(),
-                                is_signed: f_c.const_.ty().is_signed() || discr_ty.is_signed(),
-                            }
+                            ExpectedTransformKind::Cast { place: lhs_f, ty: f_c.const_.ty() }
                         }
                         _ => {
                             return None;
@@ -421,47 +449,36 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
                 // Otherwise we cannot optimize. Try another block.
                 _ => return None,
             };
-            compare_types.push(compare_type);
+            expected_transform_kinds.push(compare_type);
         }
 
         // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
         for (other_val, other_target) in target_iter {
             let other_stmts = &bbs[other_target].statements;
-            if compare_types.len() != other_stmts.len() {
+            if expected_transform_kinds.len() != other_stmts.len() {
                 return None;
             }
-            for (f, s) in iter::zip(&compare_types, other_stmts) {
+            for (f, s) in iter::zip(&expected_transform_kinds, other_stmts) {
                 match (*f, &s.kind) {
-                    (CompareType::Same(f_s), s_s) if f_s == s_s => {}
+                    (ExpectedTransformKind::Same(f_s), s_s) if f_s == s_s => {}
                     (
-                        CompareType::Eq(lhs_f, f_ty, val),
+                        ExpectedTransformKind::SameByEq { place: lhs_f, ty: f_ty, scalar },
                         StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
                     ) if lhs_f == lhs_s
                         && s_c.const_.ty() == f_ty
-                        && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(val) => {}
+                        && s_c.const_.try_eval_scalar_int(tcx, param_env) == Some(scalar) => {}
                     (
-                        CompareType::Discr { place: lhs_f, ty: f_ty, is_signed },
+                        ExpectedTransformKind::Cast { place: lhs_f, ty: f_ty },
                         StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
-                    ) if lhs_f == lhs_s && s_c.const_.ty() == f_ty => {
-                        let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env) else {
-                            return None;
-                        };
-                        if is_signed
-                            && s_c.const_.ty().is_signed()
-                            && int_equal(f, other_val, discr_size)
-                        {
-                            continue;
-                        }
-                        if Some(f) == ScalarInt::try_from_uint(other_val, f.size()) {
-                            continue;
-                        }
-                        return None;
-                    }
+                    ) if let Some(f) = s_c.const_.try_eval_scalar_int(tcx, param_env)
+                        && lhs_f == lhs_s
+                        && s_c.const_.ty() == f_ty
+                        && can_cast(tcx, other_val, discr_layout, f_ty, f) => {}
                     _ => return None,
                 }
             }
         }
-        self.transfrom_types = compare_types.into_iter().map(|c| c.into()).collect();
+        self.transfrom_kinds = expected_transform_kinds.into_iter().map(|c| c.into()).collect();
         Some(())
     }
 
@@ -479,13 +496,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
         let (_, first) = targets.iter().next().unwrap();
         let first = &bbs[first];
 
-        for (t, s) in iter::zip(&self.transfrom_types, &first.statements) {
+        for (t, s) in iter::zip(&self.transfrom_kinds, &first.statements) {
             match (t, &s.kind) {
-                (TransfromType::Same, _) | (TransfromType::Eq, _) => {
+                (TransfromKind::Same, _) => {
                     patch.add_statement(parent_end, s.kind.clone());
                 }
                 (
-                    TransfromType::Discr,
+                    TransfromKind::Cast,
                     StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
                 ) => {
                     let operand = Operand::Copy(Place::from(discr_local));