about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
authorJakub Beránek <berykubik@gmail.com>2025-03-18 07:43:12 +0100
committerGitHub <noreply@github.com>2025-03-18 07:43:12 +0100
commite9d50f4c577e966921c09df521ccef2df431c0e3 (patch)
treec87367c4bb57995c3ba78b92e175a3795e76775c /compiler/rustc_mir_transform/src
parentef46ce7b7378b366c345f0f3823cfee9921b69c6 (diff)
parent69ed0232ef33f06b6fefc14f53150a00af9200e8 (diff)
downloadrust-e9d50f4c577e966921c09df521ccef2df431c0e3.tar.gz
rust-e9d50f4c577e966921c09df521ccef2df431c0e3.zip
Merge pull request #2293 from jieyouxu/rustc-pull
Rustc pull
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/coroutine/by_move_body.rs41
-rw-r--r--compiler/rustc_mir_transform/src/cost_checker.rs47
-rw-r--r--compiler/rustc_mir_transform/src/gvn.rs10
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs94
-rw-r--r--compiler/rustc_mir_transform/src/lint_tail_expr_drop_order.rs15
-rw-r--r--compiler/rustc_mir_transform/src/match_branches.rs9
-rw-r--r--compiler/rustc_mir_transform/src/validate.rs3
7 files changed, 142 insertions, 77 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
index 9cd7045a0a2..89a306c6104 100644
--- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
+++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
@@ -178,7 +178,7 @@ pub(crate) fn coroutine_by_move_body_def_id<'tcx>(
                 ),
             };
 
-            (
+            Some((
                 FieldIdx::from_usize(child_field_idx + num_args),
                 (
                     FieldIdx::from_usize(parent_field_idx + num_args),
@@ -186,9 +186,10 @@ pub(crate) fn coroutine_by_move_body_def_id<'tcx>(
                     peel_deref,
                     child_precise_captures,
                 ),
-            )
+            ))
         },
     )
+    .flatten()
     .collect();
 
     if coroutine_kind == ty::ClosureKind::FnOnce {
@@ -312,10 +313,46 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
         self.super_place(place, context, location);
     }
 
+    fn visit_statement(&mut self, statement: &mut mir::Statement<'tcx>, location: mir::Location) {
+        // Remove fake borrows of closure captures if that capture has been
+        // replaced with a by-move version of that capture.
+        //
+        // For example, imagine we capture `Foo` in the parent and `&Foo`
+        // in the child. We will emit two fake borrows like:
+        //
+        // ```
+        //    _2 = &fake shallow (*(_1.0: &Foo));
+        //    _3 = &fake shallow (_1.0: &Foo);
+        // ```
+        //
+        // However, since this transform is responsible for replacing
+        // `_1.0: &Foo` with `_1.0: Foo`, that makes the second fake borrow
+        // obsolete, and we should replace it with a nop.
+        //
+        // As a side-note, we don't actually even care about fake borrows
+        // here at all since they're fully a MIR borrowck artifact, and we
+        // don't need to borrowck by-move MIR bodies. But it's best to preserve
+        // as much as we can between these two bodies :)
+        if let mir::StatementKind::Assign(box (_, rvalue)) = &statement.kind
+            && let mir::Rvalue::Ref(_, mir::BorrowKind::Fake(mir::FakeBorrowKind::Shallow), place) =
+                rvalue
+            && let mir::PlaceRef {
+                local: ty::CAPTURE_STRUCT_LOCAL,
+                projection: [mir::ProjectionElem::Field(idx, _)],
+            } = place.as_ref()
+            && let Some(&(_, _, true, _)) = self.field_remapping.get(&idx)
+        {
+            statement.kind = mir::StatementKind::Nop;
+        }
+
+        self.super_statement(statement, location);
+    }
+
     fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
         // Replace the type of the self arg.
         if local == ty::CAPTURE_STRUCT_LOCAL {
             local_decl.ty = self.by_move_coroutine_ty;
         }
+        self.super_local_decl(local, local_decl);
     }
 }
diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs
index b23d8b9e737..00a8293966b 100644
--- a/compiler/rustc_mir_transform/src/cost_checker.rs
+++ b/compiler/rustc_mir_transform/src/cost_checker.rs
@@ -37,29 +37,11 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
     /// and even the full `Inline` doesn't call `visit_body`, so there's nowhere
     /// to put this logic in the visitor.
     pub(super) fn add_function_level_costs(&mut self) {
-        fn is_call_like(bbd: &BasicBlockData<'_>) -> bool {
-            use TerminatorKind::*;
-            match bbd.terminator().kind {
-                Call { .. } | TailCall { .. } | Drop { .. } | Assert { .. } | InlineAsm { .. } => {
-                    true
-                }
-
-                Goto { .. }
-                | SwitchInt { .. }
-                | UnwindResume
-                | UnwindTerminate(_)
-                | Return
-                | Unreachable => false,
-
-                Yield { .. } | CoroutineDrop | FalseEdge { .. } | FalseUnwind { .. } => {
-                    unreachable!()
-                }
-            }
-        }
-
         // If the only has one Call (or similar), inlining isn't increasing the total
         // number of calls, so give extra encouragement to inlining that.
-        if self.callee_body.basic_blocks.iter().filter(|bbd| is_call_like(bbd)).count() == 1 {
+        if self.callee_body.basic_blocks.iter().filter(|bbd| is_call_like(bbd.terminator())).count()
+            == 1
+        {
             self.bonus += CALL_PENALTY;
         }
     }
@@ -193,3 +175,26 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
         }
     }
 }
+
+/// A terminator that's more call-like (might do a bunch of work, might panic, etc)
+/// than it is goto-/return-like (no side effects, etc).
+///
+/// Used to treat multi-call functions (which could inline exponentially)
+/// different from those that only do one or none of these "complex" things.
+pub(super) fn is_call_like(terminator: &Terminator<'_>) -> bool {
+    use TerminatorKind::*;
+    match terminator.kind {
+        Call { .. } | TailCall { .. } | Drop { .. } | Assert { .. } | InlineAsm { .. } => true,
+
+        Goto { .. }
+        | SwitchInt { .. }
+        | UnwindResume
+        | UnwindTerminate(_)
+        | Return
+        | Unreachable => false,
+
+        Yield { .. } | CoroutineDrop | FalseEdge { .. } | FalseUnwind { .. } => {
+            unreachable!()
+        }
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs
index 981dedd5b5c..0a54c780f31 100644
--- a/compiler/rustc_mir_transform/src/gvn.rs
+++ b/compiler/rustc_mir_transform/src/gvn.rs
@@ -872,8 +872,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
                 self.simplify_place_projection(place, location);
                 return self.new_pointer(*place, AddressKind::Address(mutbl));
             }
-            Rvalue::WrapUnsafeBinder(ref mut op, _) => {
-                return self.simplify_operand(op, location);
+            Rvalue::WrapUnsafeBinder(ref mut op, ty) => {
+                let value = self.simplify_operand(op, location)?;
+                Value::Cast {
+                    kind: CastKind::Transmute,
+                    value,
+                    from: op.ty(self.local_decls, self.tcx),
+                    to: ty,
+                }
             }
 
             // Operations.
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 0183ba19475..0ab24e48d44 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -1,5 +1,6 @@
 //! Inlining pass for MIR functions.
 
+use std::assert_matches::debug_assert_matches;
 use std::iter;
 use std::ops::{Range, RangeFrom};
 
@@ -18,7 +19,7 @@ use rustc_session::config::{DebugInfo, OptLevel};
 use rustc_span::source_map::Spanned;
 use tracing::{debug, instrument, trace, trace_span};
 
-use crate::cost_checker::CostChecker;
+use crate::cost_checker::{CostChecker, is_call_like};
 use crate::deref_separator::deref_finder;
 use crate::simplify::simplify_cfg;
 use crate::validate::validate_types;
@@ -26,6 +27,7 @@ use crate::{check_inline, util};
 
 pub(crate) mod cycle;
 
+const HISTORY_DEPTH_LIMIT: usize = 20;
 const TOP_DOWN_DEPTH_LIMIT: usize = 5;
 
 #[derive(Clone, Debug)]
@@ -117,6 +119,11 @@ trait Inliner<'tcx> {
     /// Should inlining happen for a given callee?
     fn should_inline_for_callee(&self, def_id: DefId) -> bool;
 
+    fn check_codegen_attributes_extra(
+        &self,
+        callee_attrs: &CodegenFnAttrs,
+    ) -> Result<(), &'static str>;
+
     fn check_caller_mir_body(&self, body: &Body<'tcx>) -> bool;
 
     /// Returns inlining decision that is based on the examination of callee MIR body.
@@ -128,10 +135,6 @@ trait Inliner<'tcx> {
         callee_attrs: &CodegenFnAttrs,
     ) -> Result<(), &'static str>;
 
-    // How many callsites in a body are we allowed to inline? We need to limit this in order
-    // to prevent super-linear growth in MIR size.
-    fn inline_limit_for_block(&self) -> Option<usize>;
-
     /// Called when inlining succeeds.
     fn on_inline_success(
         &mut self,
@@ -142,9 +145,6 @@ trait Inliner<'tcx> {
 
     /// Called when inlining failed or was not performed.
     fn on_inline_failure(&self, callsite: &CallSite<'tcx>, reason: &'static str);
-
-    /// Called when the inline limit for a body is reached.
-    fn on_inline_limit_reached(&self) -> bool;
 }
 
 struct ForceInliner<'tcx> {
@@ -191,6 +191,14 @@ impl<'tcx> Inliner<'tcx> for ForceInliner<'tcx> {
         ForceInline::should_run_pass_for_callee(self.tcx(), def_id)
     }
 
+    fn check_codegen_attributes_extra(
+        &self,
+        callee_attrs: &CodegenFnAttrs,
+    ) -> Result<(), &'static str> {
+        debug_assert_matches!(callee_attrs.inline, InlineAttr::Force { .. });
+        Ok(())
+    }
+
     fn check_caller_mir_body(&self, _: &Body<'tcx>) -> bool {
         true
     }
@@ -224,10 +232,6 @@ impl<'tcx> Inliner<'tcx> for ForceInliner<'tcx> {
         }
     }
 
-    fn inline_limit_for_block(&self) -> Option<usize> {
-        Some(usize::MAX)
-    }
-
     fn on_inline_success(
         &mut self,
         callsite: &CallSite<'tcx>,
@@ -261,10 +265,6 @@ impl<'tcx> Inliner<'tcx> for ForceInliner<'tcx> {
             justification: justification.map(|sym| crate::errors::ForceInlineJustification { sym }),
         });
     }
-
-    fn on_inline_limit_reached(&self) -> bool {
-        false
-    }
 }
 
 struct NormalInliner<'tcx> {
@@ -278,6 +278,10 @@ struct NormalInliner<'tcx> {
     /// The number of `DefId`s is finite, so checking history is enough
     /// to ensure that we do not loop endlessly while inlining.
     history: Vec<DefId>,
+    /// How many (multi-call) callsites have we inlined for the top-level call?
+    ///
+    /// We need to limit this in order to prevent super-linear growth in MIR size.
+    top_down_counter: usize,
     /// Indicates that the caller body has been modified.
     changed: bool,
     /// Indicates that the caller is #[inline] and just calls another function,
@@ -285,6 +289,12 @@ struct NormalInliner<'tcx> {
     caller_is_inline_forwarder: bool,
 }
 
+impl<'tcx> NormalInliner<'tcx> {
+    fn past_depth_limit(&self) -> bool {
+        self.history.len() > HISTORY_DEPTH_LIMIT || self.top_down_counter > TOP_DOWN_DEPTH_LIMIT
+    }
+}
+
 impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
     fn new(tcx: TyCtxt<'tcx>, def_id: DefId, body: &Body<'tcx>) -> Self {
         let typing_env = body.typing_env(tcx);
@@ -295,6 +305,7 @@ impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
             typing_env,
             def_id,
             history: Vec::new(),
+            top_down_counter: 0,
             changed: false,
             caller_is_inline_forwarder: matches!(
                 codegen_fn_attrs.inline,
@@ -327,6 +338,17 @@ impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
         true
     }
 
+    fn check_codegen_attributes_extra(
+        &self,
+        callee_attrs: &CodegenFnAttrs,
+    ) -> Result<(), &'static str> {
+        if self.past_depth_limit() && matches!(callee_attrs.inline, InlineAttr::None) {
+            Err("Past depth limit so not inspecting unmarked callee")
+        } else {
+            Ok(())
+        }
+    }
+
     fn check_caller_mir_body(&self, body: &Body<'tcx>) -> bool {
         // Avoid inlining into coroutines, since their `optimized_mir` is used for layout computation,
         // which can create a cycle, even when no attempt is made to inline the function in the other
@@ -351,7 +373,11 @@ impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
             return Err("body has errors");
         }
 
-        let mut threshold = if self.caller_is_inline_forwarder {
+        if self.past_depth_limit() && callee_body.basic_blocks.len() > 1 {
+            return Err("Not inlining multi-block body as we're past a depth limit");
+        }
+
+        let mut threshold = if self.caller_is_inline_forwarder || self.past_depth_limit() {
             tcx.sess.opts.unstable_opts.inline_mir_forwarder_threshold.unwrap_or(30)
         } else if tcx.cross_crate_inlinable(callsite.callee.def_id()) {
             tcx.sess.opts.unstable_opts.inline_mir_hint_threshold.unwrap_or(100)
@@ -431,14 +457,6 @@ impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
         }
     }
 
-    fn inline_limit_for_block(&self) -> Option<usize> {
-        match self.history.len() {
-            0 => Some(usize::MAX),
-            1..=TOP_DOWN_DEPTH_LIMIT => Some(1),
-            _ => None,
-        }
-    }
-
     fn on_inline_success(
         &mut self,
         callsite: &CallSite<'tcx>,
@@ -447,13 +465,21 @@ impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> {
     ) {
         self.changed = true;
 
+        let new_calls_count = new_blocks
+            .clone()
+            .filter(|&bb| is_call_like(caller_body.basic_blocks[bb].terminator()))
+            .count();
+        if new_calls_count > 1 {
+            self.top_down_counter += 1;
+        }
+
         self.history.push(callsite.callee.def_id());
         process_blocks(self, caller_body, new_blocks);
         self.history.pop();
-    }
 
-    fn on_inline_limit_reached(&self) -> bool {
-        true
+        if self.history.is_empty() {
+            self.top_down_counter = 0;
+        }
     }
 
     fn on_inline_failure(&self, _: &CallSite<'tcx>, _: &'static str) {}
@@ -482,8 +508,6 @@ fn process_blocks<'tcx, I: Inliner<'tcx>>(
     caller_body: &mut Body<'tcx>,
     blocks: Range<BasicBlock>,
 ) {
-    let Some(inline_limit) = inliner.inline_limit_for_block() else { return };
-    let mut inlined_count = 0;
     for bb in blocks {
         let bb_data = &caller_body[bb];
         if bb_data.is_cleanup {
@@ -505,13 +529,6 @@ fn process_blocks<'tcx, I: Inliner<'tcx>>(
             Ok(new_blocks) => {
                 debug!("inlined {}", callsite.callee);
                 inliner.on_inline_success(&callsite, caller_body, new_blocks);
-
-                inlined_count += 1;
-                if inlined_count == inline_limit {
-                    if inliner.on_inline_limit_reached() {
-                        return;
-                    }
-                }
             }
         }
     }
@@ -584,6 +601,7 @@ fn try_inlining<'tcx, I: Inliner<'tcx>>(
     let callee_attrs = tcx.codegen_fn_attrs(callsite.callee.def_id());
     check_inline::is_inline_valid_on_fn(tcx, callsite.callee.def_id())?;
     check_codegen_attributes(inliner, callsite, callee_attrs)?;
+    inliner.check_codegen_attributes_extra(callee_attrs)?;
 
     let terminator = caller_body[callsite.block].terminator.as_ref().unwrap();
     let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() };
@@ -770,6 +788,8 @@ fn check_codegen_attributes<'tcx, I: Inliner<'tcx>>(
         return Err("has DoNotOptimize attribute");
     }
 
+    inliner.check_codegen_attributes_extra(callee_attrs)?;
+
     // Reachability pass defines which functions are eligible for inlining. Generally inlining
     // other functions is incorrect because they could reference symbols that aren't exported.
     let is_generic = callsite.callee.args.non_erasable_generics().next().is_some();
diff --git a/compiler/rustc_mir_transform/src/lint_tail_expr_drop_order.rs b/compiler/rustc_mir_transform/src/lint_tail_expr_drop_order.rs
index 7d77fffa83f..29a9133abe9 100644
--- a/compiler/rustc_mir_transform/src/lint_tail_expr_drop_order.rs
+++ b/compiler/rustc_mir_transform/src/lint_tail_expr_drop_order.rs
@@ -2,7 +2,7 @@ use std::cell::RefCell;
 use std::collections::hash_map;
 use std::rc::Rc;
 
-use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap};
 use rustc_data_structures::unord::{UnordMap, UnordSet};
 use rustc_errors::Subdiagnostic;
 use rustc_hir::CRATE_HIR_ID;
@@ -25,7 +25,6 @@ use rustc_mir_dataflow::{Analysis, MaybeReachable, ResultsCursor};
 use rustc_session::lint::builtin::TAIL_EXPR_DROP_ORDER;
 use rustc_session::lint::{self};
 use rustc_span::{DUMMY_SP, Span, Symbol};
-use rustc_type_ir::data_structures::IndexMap;
 use tracing::debug;
 
 fn place_has_common_prefix<'tcx>(left: &Place<'tcx>, right: &Place<'tcx>) -> bool {
@@ -199,7 +198,7 @@ pub(crate) fn run_lint<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId, body: &Body<
     // and, for each block, the vector of locations.
     //
     // We group them per-block because they tend to scheduled in the same drop ladder block.
-    let mut bid_per_block = IndexMap::default();
+    let mut bid_per_block = FxIndexMap::default();
     let mut bid_places = UnordSet::new();
 
     let mut ty_dropped_components = UnordMap::default();
@@ -455,8 +454,8 @@ pub(crate) fn run_lint<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId, body: &Body<
 }
 
 /// Extract binding names if available for diagnosis
-fn collect_user_names(body: &Body<'_>) -> IndexMap<Local, Symbol> {
-    let mut names = IndexMap::default();
+fn collect_user_names(body: &Body<'_>) -> FxIndexMap<Local, Symbol> {
+    let mut names = FxIndexMap::default();
     for var_debug_info in &body.var_debug_info {
         if let mir::VarDebugInfoContents::Place(place) = &var_debug_info.value
             && let Some(local) = place.local_or_deref_local()
@@ -470,9 +469,9 @@ fn collect_user_names(body: &Body<'_>) -> IndexMap<Local, Symbol> {
 /// Assign names for anonymous or temporary values for diagnosis
 fn assign_observables_names(
     locals: impl IntoIterator<Item = Local>,
-    user_names: &IndexMap<Local, Symbol>,
-) -> IndexMap<Local, (String, bool)> {
-    let mut names = IndexMap::default();
+    user_names: &FxIndexMap<Local, Symbol>,
+) -> FxIndexMap<Local, (String, bool)> {
+    let mut names = FxIndexMap::default();
     let mut assigned_names = FxHashSet::default();
     let mut idx = 0u64;
     let mut fresh_name = || {
diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs
index 9db37bf5a07..0d9d0368d37 100644
--- a/compiler/rustc_mir_transform/src/match_branches.rs
+++ b/compiler/rustc_mir_transform/src/match_branches.rs
@@ -5,7 +5,6 @@ use rustc_index::IndexSlice;
 use rustc_middle::mir::*;
 use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
 use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
-use rustc_type_ir::TyKind::*;
 use tracing::instrument;
 
 use super::simplify::simplify_cfg;
@@ -293,13 +292,13 @@ fn can_cast(
 ) -> bool {
     let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
     let v = match src_layout.ty.kind() {
-        Uint(_) => from_scalar.to_uint(src_layout.size),
-        Int(_) => from_scalar.to_int(src_layout.size) as u128,
+        ty::Uint(_) => from_scalar.to_uint(src_layout.size),
+        ty::Int(_) => from_scalar.to_int(src_layout.size) as u128,
         _ => unreachable!("invalid int"),
     };
     let size = match *cast_ty.kind() {
-        Int(t) => Integer::from_int_ty(&tcx, t).size(),
-        Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
+        ty::Int(t) => Integer::from_int_ty(&tcx, t).size(),
+        ty::Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
         _ => unreachable!("invalid int"),
     };
     let v = size.truncate(v);
diff --git a/compiler/rustc_mir_transform/src/validate.rs b/compiler/rustc_mir_transform/src/validate.rs
index 4ac3a268c9c..231d7c2ef02 100644
--- a/compiler/rustc_mir_transform/src/validate.rs
+++ b/compiler/rustc_mir_transform/src/validate.rs
@@ -13,11 +13,10 @@ use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
 use rustc_middle::mir::*;
 use rustc_middle::ty::adjustment::PointerCoercion;
 use rustc_middle::ty::{
-    self, CoroutineArgsExt, InstanceKind, ScalarInt, Ty, TyCtxt, TypeVisitableExt, Variance,
+    self, CoroutineArgsExt, InstanceKind, ScalarInt, Ty, TyCtxt, TypeVisitableExt, Upcast, Variance,
 };
 use rustc_middle::{bug, span_bug};
 use rustc_trait_selection::traits::ObligationCtxt;
-use rustc_type_ir::Upcast;
 
 use crate::util::{self, is_within_packed};