about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform')
-rw-r--r--compiler/rustc_mir_transform/src/cost_checker.rs47
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs94
2 files changed, 83 insertions, 58 deletions
diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs
index b23d8b9e737..00a8293966b 100644
--- a/compiler/rustc_mir_transform/src/cost_checker.rs
+++ b/compiler/rustc_mir_transform/src/cost_checker.rs
@@ -37,29 +37,11 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
     /// and even the full `Inline` doesn't call `visit_body`, so there's nowhere
     /// to put this logic in the visitor.
     pub(super) fn add_function_level_costs(&mut self) {
-        fn is_call_like(bbd: &BasicBlockData<'_>) -> bool {
-            use TerminatorKind::*;
-            match bbd.terminator().kind {
-                Call { .. } | TailCall { .. } | Drop { .. } | Assert { .. } | InlineAsm { .. } => {
-                    true
-                }
-
-                Goto { .. }
-                | SwitchInt { .. }
-                | UnwindResume
-                | UnwindTerminate(_)
-                | Return
-                | Unreachable => false,
-
-                Yield { .. } | CoroutineDrop | FalseEdge { .. } | FalseUnwind { .. } => {
-                    unreachable!()
-                }
-            }
-        }
-
         // If the only has one Call (or similar), inlining isn't increasing the total
         // number of calls, so give extra encouragement to inlining that.
-        if self.callee_body.basic_blocks.iter().filter(|bbd| is_call_like(bbd)).count() == 1 {
+        if self.callee_body.basic_blocks.iter().filter(|bbd| is_call_like(bbd.terminator())).count()
+            == 1
+        {
             self.bonus += CALL_PENALTY;
         }
     }
@@ -193,3 +175,26 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
         }
     }
 }
+
+/// A terminator that's more call-like (might do a bunch of work, might panic, etc)
+/// than it is goto-/return-like (no side effects, etc).
+///
+/// Used to treat multi-call functions (which could inline exponentially)
+/// different from those that only do one or none of these "complex" things.
+pub(super) fn is_call_like(terminator: &Terminator<'_>) -> bool {
+    use TerminatorKind::*;
+    match terminator.kind {
+        Call { .. } | TailCall { .. } | Drop { .. } | Assert { .. } | InlineAsm { .. } => true,
+
+        Goto { .. }
+        | SwitchInt { .. }
+        | UnwindResume
+        | UnwindTerminate(_)
+        | Return
+        | Unreachable => false,
+
+        Yield { .. } | CoroutineDrop | FalseEdge { .. } | FalseUnwind { .. } => {
+            unreachable!()
+        }
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 0183ba19475..0ab24e48d44 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -1,5 +1,6 @@
 //! Inlining pass for MIR functions.
 
+use std::assert_matches::debug_assert_matches;
 use std::iter;
 use std::ops::{Range, RangeFrom};
 
@@ -18,7 +19,7 @@ use rustc_session::config::{DebugInfo, OptLevel};
 use rustc_span::source_map::Spanned;
 use tracing::{debug, instrument, trace, trace_span};
 
-use crate::cost_checker::CostChecker;
+use crate::cost_checker::{CostChecker, is_call_like};
 use crate::deref_separator::deref_finder;
 use crate::simplify::simplify_cfg;
 use crate::validate::validate_types;
@@ -26,6 +27,7 @@ use crate::{check_inline, util};
 
 pub(crate) mod cycle;
 
+const HISTORY_DEPTH_LIMIT: usize = 20;
 const TOP_DOWN_DEPTH_LIMIT: usize = 5;
 
 #[derive(Clone, Debug)]
@@ -117,6 +119,11 @@ trait Inliner<'tcx> {
     /// Should inlining happen for a given callee?
     fn should_inline_for_callee(&self, def_id: DefId) -> bool;
 
+    fn check_codegen_attributes_extra(
+        &self,
+        callee_attrs: &CodegenFnAttrs,
+    ) -> Result<(), &'static str>;
+
     fn check_caller_mir_body(&self, body: &Body<'tcx>) -> bool;
 
     /// Returns inlining decision that is based on the examination of callee MIR body.
@@ -128,10 +135,6 @@ trait Inliner<'tcx> {
         callee_attrs: &CodegenFnAttrs,
     ) -> Result<(), &'static str>;
 
-    // How many callsites in a body are we allowed to inline? We need to limit this in order
-    // to prevent super-linear growth in MIR size.
-    fn inline_limit_for_block(&self) -> Option<usize>;
-
     /// Called when inlining succeeds.
     fn on_inline_success(
         &mut self,
@@ -142,9 +145,6 @@ trait Inliner<'tcx> {
 
     /// Called when inlining failed or was not performed.
     fn on_inline_failure(&self, callsite: &CallSite<'tcx>, reason: &'static str);
-
-    /// Called when the inline limit for a body is reached.
-    fn on_inline_limit_reached(&self) -> bool;
 }
 
 struct ForceInliner<'tcx> {
@@ -191,6 +191,14 @@ impl<'tcx> Inliner<'tcx> for ForceInliner<'tcx> {
         ForceInline::should_run_pass_for_callee(self.tcx(), def_id)
     }
 
+    fn check_codegen_attributes_extra(
+        &self,
+        callee_attrs: &CodegenFnAttrs,
+    ) -> Result<(), &'static str> {
+        debug_assert_matches!(callee_attrs.inline, InlineAttr::Force { .. });
+        Ok(())
+    }
+
     fn check_caller_mir_body(&self, _: &Body<'tcx>) -> bool {
         true
     }
@@ -224,10 +232,6 @@ impl<'tcx> Inliner<'tcx> for ForceInliner<'tcx> {
         }
     }
 
-    fn inline_limit_for_block(&self) -> Option<usize> {
-        Some(usize::MAX)
-    }
-
     fn on_inline_success(
         &mut self,
         callsite: &CallSite<'tcx>,
@@ -261,10 +265,6 @@ impl<'tcx> Inliner<'tcx> for ForceInliner<'tcx> {
             justification: justification.map(|sym| crate::errors::ForceInlineJustification { sym }),
         });
     }
-
-    fn on_inline_limit_reached(&self) -> bool {
-        false
-    }
 }
 
 struct NormalInliner<'tcx> {
@@ -278,6 +278,10 @@ struct NormalInliner<'tcx> {
     /// The number of `DefId`s is finite, so checking history is enough
     /// to ensure that we do not loop endlessly while inlining.
     history: Vec<DefId>,
+    /// How many (multi-call) callsites have we inlined for the top-level call?
+    ///
+    /// We need to limit this in order to prevent super-linear growth in MIR size.
+    top_down_counter: usize,
     /// Indicates that the caller body has been modified.
     changed: bool,
     /// Indicates that the caller is #[inline] and just calls another function,
@@ -285,6 +289,12 @@ struct NormalInliner<'tcx> {
     caller_is_inline_forwarder: bool,
 }
 
+impl<'tcx> NormalInliner<'tcx> {
+    fn past_depth_limit(&self) -> bool {
+        self.history.len() > HISTORY_DEPTH_LIMIT || self.top_down_counter > TOP_DOWN_DEPTH_LIMIT
+    }
+}
+
 impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
     fn new(tcx: TyCtxt<'tcx>, def_id: DefId, body: &Body<'tcx>) -> Self {
         let typing_env = body.typing_env(tcx);
@@ -295,6 +305,7 @@ impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
             typing_env,
             def_id,
             history: Vec::new(),
+            top_down_counter: 0,
             changed: false,
             caller_is_inline_forwarder: matches!(
                 codegen_fn_attrs.inline,
@@ -327,6 +338,17 @@ impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
         true
     }
 
+    fn check_codegen_attributes_extra(
+        &self,
+        callee_attrs: &CodegenFnAttrs,
+    ) -> Result<(), &'static str> {
+        if self.past_depth_limit() && matches!(callee_attrs.inline, InlineAttr::None) {
+            Err("Past depth limit so not inspecting unmarked callee")
+        } else {
+            Ok(())
+        }
+    }
+
     fn check_caller_mir_body(&self, body: &Body<'tcx>) -> bool {
         // Avoid inlining into coroutines, since their `optimized_mir` is used for layout computation,
         // which can create a cycle, even when no attempt is made to inline the function in the other
@@ -351,7 +373,11 @@ impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
             return Err("body has errors");
         }
 
-        let mut threshold = if self.caller_is_inline_forwarder {
+        if self.past_depth_limit() && callee_body.basic_blocks.len() > 1 {
+            return Err("Not inlining multi-block body as we're past a depth limit");
+        }
+
+        let mut threshold = if self.caller_is_inline_forwarder || self.past_depth_limit() {
             tcx.sess.opts.unstable_opts.inline_mir_forwarder_threshold.unwrap_or(30)
         } else if tcx.cross_crate_inlinable(callsite.callee.def_id()) {
             tcx.sess.opts.unstable_opts.inline_mir_hint_threshold.unwrap_or(100)
@@ -431,14 +457,6 @@ impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
         }
     }
 
-    fn inline_limit_for_block(&self) -> Option<usize> {
-        match self.history.len() {
-            0 => Some(usize::MAX),
-            1..=TOP_DOWN_DEPTH_LIMIT => Some(1),
-            _ => None,
-        }
-    }
-
     fn on_inline_success(
         &mut self,
         callsite: &CallSite<'tcx>,
@@ -447,13 +465,21 @@ impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
     ) {
         self.changed = true;
 
+        let new_calls_count = new_blocks
+            .clone()
+            .filter(|&bb| is_call_like(caller_body.basic_blocks[bb].terminator()))
+            .count();
+        if new_calls_count > 1 {
+            self.top_down_counter += 1;
+        }
+
         self.history.push(callsite.callee.def_id());
         process_blocks(self, caller_body, new_blocks);
         self.history.pop();
-    }
 
-    fn on_inline_limit_reached(&self) -> bool {
-        true
+        if self.history.is_empty() {
+            self.top_down_counter = 0;
+        }
     }
 
     fn on_inline_failure(&self, _: &CallSite<'tcx>, _: &'static str) {}
@@ -482,8 +508,6 @@ fn process_blocks<'tcx, I: Inliner<'tcx>>(
     caller_body: &mut Body<'tcx>,
     blocks: Range<BasicBlock>,
 ) {
-    let Some(inline_limit) = inliner.inline_limit_for_block() else { return };
-    let mut inlined_count = 0;
     for bb in blocks {
         let bb_data = &caller_body[bb];
         if bb_data.is_cleanup {
@@ -505,13 +529,6 @@ fn process_blocks<'tcx, I: Inliner<'tcx>>(
             Ok(new_blocks) => {
                 debug!("inlined {}", callsite.callee);
                 inliner.on_inline_success(&callsite, caller_body, new_blocks);
-
-                inlined_count += 1;
-                if inlined_count == inline_limit {
-                    if inliner.on_inline_limit_reached() {
-                        return;
-                    }
-                }
             }
         }
     }
@@ -584,6 +601,7 @@ fn try_inlining<'tcx, I: Inliner<'tcx>>(
     let callee_attrs = tcx.codegen_fn_attrs(callsite.callee.def_id());
     check_inline::is_inline_valid_on_fn(tcx, callsite.callee.def_id())?;
     check_codegen_attributes(inliner, callsite, callee_attrs)?;
+    inliner.check_codegen_attributes_extra(callee_attrs)?;
 
     let terminator = caller_body[callsite.block].terminator.as_ref().unwrap();
     let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() };
@@ -770,6 +788,8 @@ fn check_codegen_attributes<'tcx, I: Inliner<'tcx>>(
         return Err("has DoNotOptimize attribute");
     }
 
+    inliner.check_codegen_attributes_extra(callee_attrs)?;
+
     // Reachability pass defines which functions are eligible for inlining. Generally inlining
     // other functions is incorrect because they could reference symbols that aren't exported.
     let is_generic = callsite.callee.args.non_erasable_generics().next().is_some();