about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src/instsimplify.rs
diff options
context:
space:
mode:
authorYotam Ofek <yotam.ofek@gmail.com>2025-04-10 15:26:27 +0000
committerYotam Ofek <yotam.ofek@gmail.com>2025-04-10 18:40:25 +0000
commit9491242ff70618571d3b0e102e800c04f8603ba9 (patch)
treea3306ea33640ccd0439deeded67ea836e3206720 /compiler/rustc_mir_transform/src/instsimplify.rs
parent69b3959afec9b5468d5de15133b199553f6e55d2 (diff)
downloadrust-9491242ff70618571d3b0e102e800c04f8603ba9.tar.gz
rust-9491242ff70618571d3b0e102e800c04f8603ba9.zip
Cleanup the `InstSimplify` MIR transformation
Diffstat (limited to 'compiler/rustc_mir_transform/src/instsimplify.rs')
-rw-r--r--compiler/rustc_mir_transform/src/instsimplify.rs192
1 files changed, 86 insertions, 106 deletions
diff --git a/compiler/rustc_mir_transform/src/instsimplify.rs b/compiler/rustc_mir_transform/src/instsimplify.rs
index 2eff6b31372..a8d6aaa50a2 100644
--- a/compiler/rustc_mir_transform/src/instsimplify.rs
+++ b/compiler/rustc_mir_transform/src/instsimplify.rs
@@ -39,26 +39,26 @@ impl<'tcx> crate::MirPass<'tcx> for InstSimplify {
             attr::contains_name(tcx.hir_krate_attrs(), sym::rustc_preserve_ub_checks);
         for block in body.basic_blocks.as_mut() {
             for statement in block.statements.iter_mut() {
-                match statement.kind {
-                    StatementKind::Assign(box (_place, ref mut rvalue)) => {
-                        if !preserve_ub_checks {
-                            ctx.simplify_ub_check(rvalue);
-                        }
-                        ctx.simplify_bool_cmp(rvalue);
-                        ctx.simplify_ref_deref(rvalue);
-                        ctx.simplify_ptr_aggregate(rvalue);
-                        ctx.simplify_cast(rvalue);
-                        ctx.simplify_repeated_aggregate(rvalue);
-                        ctx.simplify_repeat_once(rvalue);
-                    }
-                    _ => {}
+                let StatementKind::Assign(box (.., rvalue)) = &mut statement.kind else {
+                    continue;
+                };
+
+                if !preserve_ub_checks {
+                    ctx.simplify_ub_check(rvalue);
                 }
+                ctx.simplify_bool_cmp(rvalue);
+                ctx.simplify_ref_deref(rvalue);
+                ctx.simplify_ptr_aggregate(rvalue);
+                ctx.simplify_cast(rvalue);
+                ctx.simplify_repeated_aggregate(rvalue);
+                ctx.simplify_repeat_once(rvalue);
             }
 
-            ctx.simplify_primitive_clone(block.terminator.as_mut().unwrap(), &mut block.statements);
-            ctx.simplify_intrinsic_assert(block.terminator.as_mut().unwrap());
-            ctx.simplify_nounwind_call(block.terminator.as_mut().unwrap());
-            simplify_duplicate_switch_targets(block.terminator.as_mut().unwrap());
+            let terminator = block.terminator.as_mut().unwrap();
+            ctx.simplify_primitive_clone(terminator, &mut block.statements);
+            ctx.simplify_intrinsic_assert(terminator);
+            ctx.simplify_nounwind_call(terminator);
+            simplify_duplicate_switch_targets(terminator);
         }
     }
 
@@ -105,43 +105,34 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
 
     /// Transform boolean comparisons into logical operations.
     fn simplify_bool_cmp(&self, rvalue: &mut Rvalue<'tcx>) {
-        match rvalue {
-            Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), box (a, b)) => {
-                let new = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) {
-                    // Transform "Eq(a, true)" ==> "a"
-                    (BinOp::Eq, _, Some(true)) => Some(Rvalue::Use(a.clone())),
+        let Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), box (a, b)) = &*rvalue else { return };
+        *rvalue = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) {
+            // Transform "Eq(a, true)" ==> "a"
+            (BinOp::Eq, _, Some(true)) => Rvalue::Use(a.clone()),
 
-                    // Transform "Ne(a, false)" ==> "a"
-                    (BinOp::Ne, _, Some(false)) => Some(Rvalue::Use(a.clone())),
+            // Transform "Ne(a, false)" ==> "a"
+            (BinOp::Ne, _, Some(false)) => Rvalue::Use(a.clone()),
 
-                    // Transform "Eq(true, b)" ==> "b"
-                    (BinOp::Eq, Some(true), _) => Some(Rvalue::Use(b.clone())),
+            // Transform "Eq(true, b)" ==> "b"
+            (BinOp::Eq, Some(true), _) => Rvalue::Use(b.clone()),
 
-                    // Transform "Ne(false, b)" ==> "b"
-                    (BinOp::Ne, Some(false), _) => Some(Rvalue::Use(b.clone())),
+            // Transform "Ne(false, b)" ==> "b"
+            (BinOp::Ne, Some(false), _) => Rvalue::Use(b.clone()),
 
-                    // Transform "Eq(false, b)" ==> "Not(b)"
-                    (BinOp::Eq, Some(false), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())),
+            // Transform "Eq(false, b)" ==> "Not(b)"
+            (BinOp::Eq, Some(false), _) => Rvalue::UnaryOp(UnOp::Not, b.clone()),
 
-                    // Transform "Ne(true, b)" ==> "Not(b)"
-                    (BinOp::Ne, Some(true), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())),
+            // Transform "Ne(true, b)" ==> "Not(b)"
+            (BinOp::Ne, Some(true), _) => Rvalue::UnaryOp(UnOp::Not, b.clone()),
 
-                    // Transform "Eq(a, false)" ==> "Not(a)"
-                    (BinOp::Eq, _, Some(false)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())),
+            // Transform "Eq(a, false)" ==> "Not(a)"
+            (BinOp::Eq, _, Some(false)) => Rvalue::UnaryOp(UnOp::Not, a.clone()),
 
-                    // Transform "Ne(a, true)" ==> "Not(a)"
-                    (BinOp::Ne, _, Some(true)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())),
-
-                    _ => None,
-                };
-
-                if let Some(new) = new {
-                    *rvalue = new;
-                }
-            }
+            // Transform "Ne(a, true)" ==> "Not(a)"
+            (BinOp::Ne, _, Some(true)) => Rvalue::UnaryOp(UnOp::Not, a.clone()),
 
-            _ => {}
-        }
+            _ => return,
+        };
     }
 
     fn try_eval_bool(&self, a: &Operand<'_>) -> Option<bool> {
@@ -151,64 +142,58 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
 
     /// Transform `&(*a)` ==> `a`.
     fn simplify_ref_deref(&self, rvalue: &mut Rvalue<'tcx>) {
-        if let Rvalue::Ref(_, _, place) | Rvalue::RawPtr(_, place) = rvalue {
-            if let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() {
-                if rvalue.ty(self.local_decls, self.tcx) != base.ty(self.local_decls, self.tcx).ty {
-                    return;
-                }
-
-                *rvalue = Rvalue::Use(Operand::Copy(Place {
-                    local: base.local,
-                    projection: self.tcx.mk_place_elems(base.projection),
-                }));
-            }
+        if let Rvalue::Ref(_, _, place) | Rvalue::RawPtr(_, place) = rvalue
+            && let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection()
+            && rvalue.ty(self.local_decls, self.tcx) == base.ty(self.local_decls, self.tcx).ty
+        {
+            *rvalue = Rvalue::Use(Operand::Copy(Place {
+                local: base.local,
+                projection: self.tcx.mk_place_elems(base.projection),
+            }));
         }
     }
 
     /// Transform `Aggregate(RawPtr, [p, ()])` ==> `Cast(PtrToPtr, p)`.
     fn simplify_ptr_aggregate(&self, rvalue: &mut Rvalue<'tcx>) {
         if let Rvalue::Aggregate(box AggregateKind::RawPtr(pointee_ty, mutability), fields) = rvalue
+            && let meta_ty = fields.raw[1].ty(self.local_decls, self.tcx)
+            && meta_ty.is_unit()
         {
-            let meta_ty = fields.raw[1].ty(self.local_decls, self.tcx);
-            if meta_ty.is_unit() {
-                // The mutable borrows we're holding prevent printing `rvalue` here
-                let mut fields = std::mem::take(fields);
-                let _meta = fields.pop().unwrap();
-                let data = fields.pop().unwrap();
-                let ptr_ty = Ty::new_ptr(self.tcx, *pointee_ty, *mutability);
-                *rvalue = Rvalue::Cast(CastKind::PtrToPtr, data, ptr_ty);
-            }
+            // The mutable borrows we're holding prevent printing `rvalue` here
+            let mut fields = std::mem::take(fields);
+            let _meta = fields.pop().unwrap();
+            let data = fields.pop().unwrap();
+            let ptr_ty = Ty::new_ptr(self.tcx, *pointee_ty, *mutability);
+            *rvalue = Rvalue::Cast(CastKind::PtrToPtr, data, ptr_ty);
         }
     }
 
     fn simplify_ub_check(&self, rvalue: &mut Rvalue<'tcx>) {
-        if let Rvalue::NullaryOp(NullOp::UbChecks, _) = *rvalue {
-            let const_ = Const::from_bool(self.tcx, self.tcx.sess.ub_checks());
-            let constant = ConstOperand { span: DUMMY_SP, const_, user_ty: None };
-            *rvalue = Rvalue::Use(Operand::Constant(Box::new(constant)));
-        }
+        let Rvalue::NullaryOp(NullOp::UbChecks, _) = *rvalue else { return };
+
+        let const_ = Const::from_bool(self.tcx, self.tcx.sess.ub_checks());
+        let constant = ConstOperand { span: DUMMY_SP, const_, user_ty: None };
+        *rvalue = Rvalue::Use(Operand::Constant(Box::new(constant)));
     }
 
     fn simplify_cast(&self, rvalue: &mut Rvalue<'tcx>) {
-        if let Rvalue::Cast(kind, operand, cast_ty) = rvalue {
-            let operand_ty = operand.ty(self.local_decls, self.tcx);
-            if operand_ty == *cast_ty {
-                *rvalue = Rvalue::Use(operand.clone());
-            } else if *kind == CastKind::Transmute {
-                // Transmuting an integer to another integer is just a signedness cast
-                if let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) =
-                    (operand_ty.kind(), cast_ty.kind())
-                    && int.bit_width() == uint.bit_width()
-                {
-                    // The width check isn't strictly necessary, as different widths
-                    // are UB and thus we'd be allowed to turn it into a cast anyway.
-                    // But let's keep the UB around for codegen to exploit later.
-                    // (If `CastKind::Transmute` ever becomes *not* UB for mismatched sizes,
-                    // then the width check is necessary for big-endian correctness.)
-                    *kind = CastKind::IntToInt;
-                    return;
-                }
-            }
+        let Rvalue::Cast(kind, operand, cast_ty) = rvalue else { return };
+
+        let operand_ty = operand.ty(self.local_decls, self.tcx);
+        if operand_ty == *cast_ty {
+            *rvalue = Rvalue::Use(operand.clone());
+        } else if *kind == CastKind::Transmute
+            // Transmuting an integer to another integer is just a signedness cast
+            && let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) =
+                (operand_ty.kind(), cast_ty.kind())
+            && int.bit_width() == uint.bit_width()
+        {
+            // The width check isn't strictly necessary, as different widths
+            // are UB and thus we'd be allowed to turn it into a cast anyway.
+            // But let's keep the UB around for codegen to exploit later.
+            // (If `CastKind::Transmute` ever becomes *not* UB for mismatched sizes,
+            // then the width check is necessary for big-endian correctness.)
+            *kind = CastKind::IntToInt;
         }
     }
 
@@ -277,7 +262,7 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
     }
 
     fn simplify_nounwind_call(&self, terminator: &mut Terminator<'tcx>) {
-        let TerminatorKind::Call { func, unwind, .. } = &mut terminator.kind else {
+        let TerminatorKind::Call { ref func, ref mut unwind, .. } = terminator.kind else {
             return;
         };
 
@@ -290,7 +275,7 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
             ty::FnDef(..) => body_ty.fn_sig(self.tcx).abi(),
             ty::Closure(..) => ExternAbi::RustCall,
             ty::Coroutine(..) => ExternAbi::Rust,
-            _ => bug!("unexpected body ty: {:?}", body_ty),
+            _ => bug!("unexpected body ty: {body_ty:?}"),
         };
 
         if !layout::fn_can_unwind(self.tcx, Some(def_id), body_abi) {
@@ -299,10 +284,9 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
     }
 
     fn simplify_intrinsic_assert(&self, terminator: &mut Terminator<'tcx>) {
-        let TerminatorKind::Call { func, target, .. } = &mut terminator.kind else {
-            return;
-        };
-        let Some(target_block) = target else {
+        let TerminatorKind::Call { ref func, target: ref mut target @ Some(target_block), .. } =
+            terminator.kind
+        else {
             return;
         };
         let func_ty = func.ty(self.local_decls, self.tcx);
@@ -310,12 +294,10 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
             return;
         };
         // The intrinsics we are interested in have one generic parameter
-        if args.is_empty() {
-            return;
-        }
+        let [arg, ..] = args[..] else { return };
 
         let known_is_valid =
-            intrinsic_assert_panics(self.tcx, self.typing_env, args[0], intrinsic_name);
+            intrinsic_assert_panics(self.tcx, self.typing_env, arg, intrinsic_name);
         match known_is_valid {
             // We don't know the layout or it's not validity assertion at all, don't touch it
             None => {}
@@ -325,7 +307,7 @@ impl<'tcx> InstSimplifyContext<'_, 'tcx> {
             }
             Some(false) => {
                 // If we know the assert does not panic, turn the call into a Goto
-                terminator.kind = TerminatorKind::Goto { target: *target_block };
+                terminator.kind = TerminatorKind::Goto { target: target_block };
             }
         }
     }
@@ -346,9 +328,7 @@ fn resolve_rust_intrinsic<'tcx>(
     tcx: TyCtxt<'tcx>,
     func_ty: Ty<'tcx>,
 ) -> Option<(Symbol, GenericArgsRef<'tcx>)> {
-    if let ty::FnDef(def_id, args) = *func_ty.kind() {
-        let intrinsic = tcx.intrinsic(def_id)?;
-        return Some((intrinsic.name, args));
-    }
-    None
+    let ty::FnDef(def_id, args) = *func_ty.kind() else { return None };
+    let intrinsic = tcx.intrinsic(def_id)?;
+    Some((intrinsic.name, args))
 }