about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src/cost_checker.rs
diff options
context:
space:
mode:
authorScott McMurray <scottmcm@users.noreply.github.com>2024-06-17 00:36:21 -0700
committerScott McMurray <scottmcm@users.noreply.github.com>2024-06-19 21:35:37 -0700
commit4236da52af3c6eeb4b1c926d932a31e1e53c8e77 (patch)
tree48d4cc108218e6f718558c979bddc423b92148d4 /compiler/rustc_mir_transform/src/cost_checker.rs
parentf3349510303531edaf55dc18ae286bece8c85bdb (diff)
downloadrust-4236da52af3c6eeb4b1c926d932a31e1e53c8e77.tar.gz
rust-4236da52af3c6eeb4b1c926d932a31e1e53c8e77.zip
Give inlining bonuses to things that optimize out
Diffstat (limited to 'compiler/rustc_mir_transform/src/cost_checker.rs')
-rw-r--r--compiler/rustc_mir_transform/src/cost_checker.rs83
1 files changed, 63 insertions, 20 deletions
diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs
index bca4ef5b3d1..32c0d27f635 100644
--- a/compiler/rustc_mir_transform/src/cost_checker.rs
+++ b/compiler/rustc_mir_transform/src/cost_checker.rs
@@ -1,3 +1,4 @@
+use rustc_middle::bug;
 use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
@@ -6,6 +7,8 @@ const INSTR_COST: usize = 5;
 const CALL_PENALTY: usize = 25;
 const LANDINGPAD_PENALTY: usize = 50;
 const RESUME_PENALTY: usize = 45;
+const LARGE_SWITCH_PENALTY: usize = 20;
+const CONST_SWITCH_BONUS: usize = 10;
 
 /// Verify that the callee body is compatible with the caller.
 #[derive(Clone)]
@@ -42,36 +45,49 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
 }
 
 impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
-    fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
-        // Don't count StorageLive/StorageDead in the inlining cost.
+    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+        // Most costs are in rvalues and terminators, not in statements.
         match statement.kind {
-            StatementKind::StorageLive(_)
-            | StatementKind::StorageDead(_)
-            | StatementKind::Deinit(_)
-            | StatementKind::Nop => {}
+            StatementKind::Intrinsic(ref ndi) => {
+                self.penalty += match **ndi {
+                    NonDivergingIntrinsic::Assume(..) => INSTR_COST,
+                    NonDivergingIntrinsic::CopyNonOverlapping(..) => CALL_PENALTY,
+                };
+            }
+            _ => self.super_statement(statement, location),
+        }
+    }
+
+    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) {
+        match rvalue {
+            Rvalue::NullaryOp(NullOp::UbChecks, ..) if !self.tcx.sess.ub_checks() => {
+                // If this is in optimized MIR it's because it's used later,
+                // so if we don't need UB checks this session, give a bonus
+                // here to offset the cost of the call later.
+                self.bonus += CALL_PENALTY;
+            }
+            // These are essentially constants that didn't end up in an Operand,
+            // so treat them as also being free.
+            Rvalue::NullaryOp(..) => {}
             _ => self.penalty += INSTR_COST,
         }
     }
 
     fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
-        let tcx = self.tcx;
-        match terminator.kind {
-            TerminatorKind::Drop { ref place, unwind, .. } => {
+        match &terminator.kind {
+            TerminatorKind::Drop { place, unwind, .. } => {
                 // If the place doesn't actually need dropping, treat it like a regular goto.
-                let ty = self.instantiate_ty(place.ty(self.callee_body, tcx).ty);
-                if ty.needs_drop(tcx, self.param_env) {
+                let ty = self.instantiate_ty(place.ty(self.callee_body, self.tcx).ty);
+                if ty.needs_drop(self.tcx, self.param_env) {
                     self.penalty += CALL_PENALTY;
                     if let UnwindAction::Cleanup(_) = unwind {
                         self.penalty += LANDINGPAD_PENALTY;
                     }
-                } else {
-                    self.penalty += INSTR_COST;
                 }
             }
-            TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
-                let fn_ty = self.instantiate_ty(f.const_.ty());
-                self.penalty += if let ty::FnDef(def_id, _) = *fn_ty.kind()
-                    && tcx.intrinsic(def_id).is_some()
+            TerminatorKind::Call { func, unwind, .. } => {
+                self.penalty += if let Some((def_id, ..)) = func.const_fn_def()
+                    && self.tcx.intrinsic(def_id).is_some()
                 {
                     // Don't give intrinsics the extra penalty for calls
                     INSTR_COST
@@ -82,8 +98,25 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
                     self.penalty += LANDINGPAD_PENALTY;
                 }
             }
-            TerminatorKind::Assert { unwind, .. } => {
-                self.penalty += CALL_PENALTY;
+            TerminatorKind::SwitchInt { discr, targets } => {
+                if discr.constant().is_some() {
+                    // Not only will this become a `Goto`, but likely other
+                    // things will be removable as unreachable.
+                    self.bonus += CONST_SWITCH_BONUS;
+                } else if targets.all_targets().len() > 3 {
+                    // More than false/true/unreachable gets extra cost.
+                    self.penalty += LARGE_SWITCH_PENALTY;
+                } else {
+                    self.penalty += INSTR_COST;
+                }
+            }
+            TerminatorKind::Assert { unwind, msg, .. } => {
+                self.penalty +=
+                    if msg.is_optional_overflow_check() && !self.tcx.sess.overflow_checks() {
+                        INSTR_COST
+                    } else {
+                        CALL_PENALTY
+                    };
                 if let UnwindAction::Cleanup(_) = unwind {
                     self.penalty += LANDINGPAD_PENALTY;
                 }
@@ -95,7 +128,17 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
                     self.penalty += LANDINGPAD_PENALTY;
                 }
             }
-            _ => self.penalty += INSTR_COST,
+            TerminatorKind::Unreachable => {
+                self.bonus += INSTR_COST;
+            }
+            TerminatorKind::Goto { .. } | TerminatorKind::Return => {}
+            TerminatorKind::UnwindTerminate(..) => {}
+            kind @ (TerminatorKind::FalseUnwind { .. }
+            | TerminatorKind::FalseEdge { .. }
+            | TerminatorKind::Yield { .. }
+            | TerminatorKind::CoroutineDrop) => {
+                bug!("{kind:?} should not be in runtime MIR");
+            }
         }
     }
 }