diff options
Diffstat (limited to 'compiler/rustc_mir_transform/src')
| -rw-r--r-- | compiler/rustc_mir_transform/src/cost_checker.rs | 31 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coverage/mappings.rs | 12 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coverage/mod.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coverage/spans.rs | 5 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs | 56 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/coverage/unexpand.rs | 60 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/elaborate_drops.rs | 26 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/gvn.rs | 12 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/inline.rs | 52 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/inline/cycle.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/lib.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/promote_consts.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/shim.rs | 49 |
13 files changed, 201 insertions, 109 deletions
diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs index 7e401b5482f..3333bebff3a 100644 --- a/compiler/rustc_mir_transform/src/cost_checker.rs +++ b/compiler/rustc_mir_transform/src/cost_checker.rs @@ -31,6 +31,37 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> { CostChecker { tcx, param_env, callee_body, instance, penalty: 0, bonus: 0 } } + /// Add function-level costs not well-represented by the block-level costs. + /// + /// Needed because the `CostChecker` is used sometimes for just blocks, + /// and even the full `Inline` doesn't call `visit_body`, so there's nowhere + /// to put this logic in the visitor. + pub fn add_function_level_costs(&mut self) { + fn is_call_like(bbd: &BasicBlockData<'_>) -> bool { + use TerminatorKind::*; + match bbd.terminator().kind { + Call { .. } | Drop { .. } | Assert { .. } | InlineAsm { .. } => true, + + Goto { .. } + | SwitchInt { .. } + | UnwindResume + | UnwindTerminate(_) + | Return + | Unreachable => false, + + Yield { .. } | CoroutineDrop | FalseEdge { .. } | FalseUnwind { .. } => { + unreachable!() + } + } + } + + // If the only has one Call (or similar), inlining isn't increasing the total + // number of calls, so give extra encouragement to inlining that. + if self.callee_body.basic_blocks.iter().filter(|bbd| is_call_like(bbd)).count() == 1 { + self.bonus += CALL_PENALTY; + } + } + pub fn cost(&self) -> usize { usize::saturating_sub(self.penalty, self.bonus) } diff --git a/compiler/rustc_mir_transform/src/coverage/mappings.rs b/compiler/rustc_mir_transform/src/coverage/mappings.rs index 759bb7c1f9d..235992ac547 100644 --- a/compiler/rustc_mir_transform/src/coverage/mappings.rs +++ b/compiler/rustc_mir_transform/src/coverage/mappings.rs @@ -9,9 +9,8 @@ use rustc_middle::ty::TyCtxt; use rustc_span::Span; use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, START_BCB}; -use crate::coverage::spans::{ - extract_refined_covspans, unexpand_into_body_span_with_visible_macro, -}; +use crate::coverage::spans::extract_refined_covspans; +use crate::coverage::unexpand::unexpand_into_body_span; use crate::coverage::ExtractedHirInfo; /// Associates an ordinary executable code span with its corresponding BCB. @@ -202,8 +201,7 @@ pub(super) fn extract_branch_pairs( if !raw_span.ctxt().outer_expn_data().is_root() { return None; } - let (span, _) = - unexpand_into_body_span_with_visible_macro(raw_span, hir_info.body_span)?; + let span = unexpand_into_body_span(raw_span, hir_info.body_span)?; let bcb_from_marker = |marker: BlockMarkerId| basic_coverage_blocks.bcb_from_bb(block_markers[marker]?); @@ -238,7 +236,7 @@ pub(super) fn extract_mcdc_mappings( if !raw_span.ctxt().outer_expn_data().is_root() { return None; } - let (span, _) = unexpand_into_body_span_with_visible_macro(raw_span, body_span)?; + let span = unexpand_into_body_span(raw_span, body_span)?; let true_bcb = bcb_from_marker(true_marker)?; let false_bcb = bcb_from_marker(false_marker)?; @@ -261,7 +259,7 @@ pub(super) fn extract_mcdc_mappings( mcdc_decisions.extend(branch_info.mcdc_decision_spans.iter().filter_map( |decision: &mir::coverage::MCDCDecisionSpan| { - let (span, _) = unexpand_into_body_span_with_visible_macro(decision.span, body_span)?; + let span = unexpand_into_body_span(decision.span, body_span)?; let end_bcbs = decision .end_markers diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs index 4a64d21f3d1..d55bde311c1 100644 --- a/compiler/rustc_mir_transform/src/coverage/mod.rs +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -6,6 +6,7 @@ mod mappings; mod spans; #[cfg(test)] mod tests; +mod unexpand; use rustc_middle::mir::coverage::{ CodeRegion, CoverageKind, DecisionInfo, FunctionCoverageInfo, Mapping, MappingKind, diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs index 84a70d1f02d..7612c01c52e 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans.rs @@ -14,11 +14,6 @@ use crate::coverage::ExtractedHirInfo; mod from_mir; -// FIXME(#124545) It's awkward that we have to re-export this, because it's an -// internal detail of `from_mir` that is also needed when handling branch and -// MC/DC spans. Ideally we would find a more natural home for it. -pub(super) use from_mir::unexpand_into_body_span_with_visible_macro; - pub(super) fn extract_refined_covspans( mir_body: &mir::Body<'_>, hir_info: &ExtractedHirInfo, diff --git a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs index 09deb7534bf..2ca166929ee 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs @@ -4,12 +4,13 @@ use rustc_middle::mir::{ self, AggregateKind, FakeReadCause, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, }; -use rustc_span::{ExpnKind, MacroKind, Span, Symbol}; +use rustc_span::{Span, Symbol}; use crate::coverage::graph::{ BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph, START_BCB, }; use crate::coverage::spans::Covspan; +use crate::coverage::unexpand::unexpand_into_body_span_with_visible_macro; use crate::coverage::ExtractedHirInfo; pub(crate) struct ExtractedCovspans { @@ -215,59 +216,6 @@ fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Span> { } } -/// Returns an extrapolated span (pre-expansion[^1]) corresponding to a range -/// within the function's body source. This span is guaranteed to be contained -/// within, or equal to, the `body_span`. If the extrapolated span is not -/// contained within the `body_span`, `None` is returned. -/// -/// [^1]Expansions result from Rust syntax including macros, syntactic sugar, -/// etc.). -pub(crate) fn unexpand_into_body_span_with_visible_macro( - original_span: Span, - body_span: Span, -) -> Option<(Span, Option<Symbol>)> { - let (span, prev) = unexpand_into_body_span_with_prev(original_span, body_span)?; - - let visible_macro = prev - .map(|prev| match prev.ctxt().outer_expn_data().kind { - ExpnKind::Macro(MacroKind::Bang, name) => Some(name), - _ => None, - }) - .flatten(); - - Some((span, visible_macro)) -} - -/// Walks through the expansion ancestors of `original_span` to find a span that -/// is contained in `body_span` and has the same [`SyntaxContext`] as `body_span`. -/// The ancestor that was traversed just before the matching span (if any) is -/// also returned. -/// -/// For example, a return value of `Some((ancestor, Some(prev))` means that: -/// - `ancestor == original_span.find_ancestor_inside_same_ctxt(body_span)` -/// - `ancestor == prev.parent_callsite()` -/// -/// [`SyntaxContext`]: rustc_span::SyntaxContext -fn unexpand_into_body_span_with_prev( - original_span: Span, - body_span: Span, -) -> Option<(Span, Option<Span>)> { - let mut prev = None; - let mut curr = original_span; - - while !body_span.contains(curr) || !curr.eq_ctxt(body_span) { - prev = Some(curr); - curr = curr.parent_callsite()?; - } - - debug_assert_eq!(Some(curr), original_span.find_ancestor_in_same_ctxt(body_span)); - if let Some(prev) = prev { - debug_assert_eq!(Some(curr), prev.parent_callsite()); - } - - Some((curr, prev)) -} - #[derive(Debug)] pub(crate) struct Hole { pub(crate) span: Span, diff --git a/compiler/rustc_mir_transform/src/coverage/unexpand.rs b/compiler/rustc_mir_transform/src/coverage/unexpand.rs new file mode 100644 index 00000000000..8cde291b907 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/unexpand.rs @@ -0,0 +1,60 @@ +use rustc_span::{ExpnKind, MacroKind, Span, Symbol}; + +/// Walks through the expansion ancestors of `original_span` to find a span that +/// is contained in `body_span` and has the same [syntax context] as `body_span`. +pub(crate) fn unexpand_into_body_span(original_span: Span, body_span: Span) -> Option<Span> { + // Because we don't need to return any extra ancestor information, + // we can just delegate directly to `find_ancestor_inside_same_ctxt`. + original_span.find_ancestor_inside_same_ctxt(body_span) +} + +/// Walks through the expansion ancestors of `original_span` to find a span that +/// is contained in `body_span` and has the same [syntax context] as `body_span`. +/// +/// If the returned span represents a bang-macro invocation (e.g. `foo!(..)`), +/// the returned symbol will be the name of that macro (e.g. `foo`). +pub(crate) fn unexpand_into_body_span_with_visible_macro( + original_span: Span, + body_span: Span, +) -> Option<(Span, Option<Symbol>)> { + let (span, prev) = unexpand_into_body_span_with_prev(original_span, body_span)?; + + let visible_macro = prev + .map(|prev| match prev.ctxt().outer_expn_data().kind { + ExpnKind::Macro(MacroKind::Bang, name) => Some(name), + _ => None, + }) + .flatten(); + + Some((span, visible_macro)) +} + +/// Walks through the expansion ancestors of `original_span` to find a span that +/// is contained in `body_span` and has the same [syntax context] as `body_span`. +/// The ancestor that was traversed just before the matching span (if any) is +/// also returned. +/// +/// For example, a return value of `Some((ancestor, Some(prev)))` means that: +/// - `ancestor == original_span.find_ancestor_inside_same_ctxt(body_span)` +/// - `prev.parent_callsite() == ancestor` +/// +/// [syntax context]: rustc_span::SyntaxContext +fn unexpand_into_body_span_with_prev( + original_span: Span, + body_span: Span, +) -> Option<(Span, Option<Span>)> { + let mut prev = None; + let mut curr = original_span; + + while !body_span.contains(curr) || !curr.eq_ctxt(body_span) { + prev = Some(curr); + curr = curr.parent_callsite()?; + } + + debug_assert_eq!(Some(curr), original_span.find_ancestor_inside_same_ctxt(body_span)); + if let Some(prev) = prev { + debug_assert_eq!(Some(curr), prev.parent_callsite()); + } + + Some((curr, prev)) +} diff --git a/compiler/rustc_mir_transform/src/elaborate_drops.rs b/compiler/rustc_mir_transform/src/elaborate_drops.rs index 665b2260294..fbbb8c5e472 100644 --- a/compiler/rustc_mir_transform/src/elaborate_drops.rs +++ b/compiler/rustc_mir_transform/src/elaborate_drops.rs @@ -97,7 +97,7 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops { #[instrument(level = "trace", skip(body, flow_inits), ret)] fn compute_dead_unwinds<'mir, 'tcx>( body: &'mir Body<'tcx>, - flow_inits: &mut ResultsCursor<'mir, 'tcx, MaybeInitializedPlaces<'mir, 'tcx>>, + flow_inits: &mut ResultsCursor<'mir, 'tcx, MaybeInitializedPlaces<'_, 'mir, 'tcx>>, ) -> BitSet<BasicBlock> { // We only need to do this pass once, because unwind edges can only // reach cleanup blocks, which can't have unwind edges themselves. @@ -118,12 +118,12 @@ fn compute_dead_unwinds<'mir, 'tcx>( dead_unwinds } -struct InitializationData<'mir, 'tcx> { - inits: ResultsCursor<'mir, 'tcx, MaybeInitializedPlaces<'mir, 'tcx>>, - uninits: ResultsCursor<'mir, 'tcx, MaybeUninitializedPlaces<'mir, 'tcx>>, +struct InitializationData<'a, 'mir, 'tcx> { + inits: ResultsCursor<'mir, 'tcx, MaybeInitializedPlaces<'a, 'mir, 'tcx>>, + uninits: ResultsCursor<'mir, 'tcx, MaybeUninitializedPlaces<'a, 'mir, 'tcx>>, } -impl InitializationData<'_, '_> { +impl InitializationData<'_, '_, '_> { fn seek_before(&mut self, loc: Location) { self.inits.seek_before_primary_effect(loc); self.uninits.seek_before_primary_effect(loc); @@ -134,17 +134,17 @@ impl InitializationData<'_, '_> { } } -struct Elaborator<'a, 'b, 'tcx> { - ctxt: &'a mut ElaborateDropsCtxt<'b, 'tcx>, +struct Elaborator<'a, 'b, 'mir, 'tcx> { + ctxt: &'a mut ElaborateDropsCtxt<'b, 'mir, 'tcx>, } -impl fmt::Debug for Elaborator<'_, '_, '_> { +impl fmt::Debug for Elaborator<'_, '_, '_, '_> { fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { Ok(()) } } -impl<'a, 'tcx> DropElaborator<'a, 'tcx> for Elaborator<'a, '_, 'tcx> { +impl<'a, 'tcx> DropElaborator<'a, 'tcx> for Elaborator<'a, '_, '_, 'tcx> { type Path = MovePathIndex; fn patch(&mut self) -> &mut MirPatch<'tcx> { @@ -238,16 +238,16 @@ impl<'a, 'tcx> DropElaborator<'a, 'tcx> for Elaborator<'a, '_, 'tcx> { } } -struct ElaborateDropsCtxt<'a, 'tcx> { +struct ElaborateDropsCtxt<'a, 'mir, 'tcx> { tcx: TyCtxt<'tcx>, - body: &'a Body<'tcx>, + body: &'mir Body<'tcx>, env: &'a MoveDataParamEnv<'tcx>, - init_data: InitializationData<'a, 'tcx>, + init_data: InitializationData<'a, 'mir, 'tcx>, drop_flags: IndexVec<MovePathIndex, Option<Local>>, patch: MirPatch<'tcx>, } -impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { +impl<'b, 'mir, 'tcx> ElaborateDropsCtxt<'b, 'mir, 'tcx> { fn move_data(&self) -> &'b MoveData<'tcx> { &self.env.move_data } diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs index 936a7e2d9de..2b7d9be6d35 100644 --- a/compiler/rustc_mir_transform/src/gvn.rs +++ b/compiler/rustc_mir_transform/src/gvn.rs @@ -8,7 +8,7 @@ //! `Value` is interned as a `VnIndex`, which allows us to cheaply compute identical values. //! //! From those assignments, we construct a mapping `VnIndex -> Vec<(Local, Location)>` of available -//! values, the locals in which they are stored, and a the assignment location. +//! values, the locals in which they are stored, and the assignment location. //! //! In a second pass, we traverse all (non SSA) assignments `x = rvalue` and operands. For each //! one, we compute the `VnIndex` of the rvalue. If this `VnIndex` is associated to a constant, we @@ -1074,11 +1074,11 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { { lhs = *lhs_value; rhs = *rhs_value; - if let Some(op) = self.try_as_operand(lhs, location) { - *lhs_operand = op; - } - if let Some(op) = self.try_as_operand(rhs, location) { - *rhs_operand = op; + if let Some(lhs_op) = self.try_as_operand(lhs, location) + && let Some(rhs_op) = self.try_as_operand(rhs, location) + { + *lhs_operand = lhs_op; + *rhs_operand = rhs_op; } } diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 07482d0571a..5075e072754 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -85,13 +85,18 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { } let param_env = tcx.param_env_reveal_all_normalized(def_id); + let codegen_fn_attrs = tcx.codegen_fn_attrs(def_id); let mut this = Inliner { tcx, param_env, - codegen_fn_attrs: tcx.codegen_fn_attrs(def_id), + codegen_fn_attrs, history: Vec::new(), changed: false, + caller_is_inline_forwarder: matches!( + codegen_fn_attrs.inline, + InlineAttr::Hint | InlineAttr::Always + ) && body_is_forwarder(body), }; let blocks = START_BLOCK..body.basic_blocks.next_index(); this.process_blocks(body, blocks); @@ -111,6 +116,9 @@ struct Inliner<'tcx> { history: Vec<DefId>, /// Indicates that the caller body has been modified. changed: bool, + /// Indicates that the caller is #[inline] and just calls another function, + /// and thus we can inline less into it as it'll be inlined itself. + caller_is_inline_forwarder: bool, } impl<'tcx> Inliner<'tcx> { @@ -381,7 +389,7 @@ impl<'tcx> Inliner<'tcx> { // To resolve an instance its args have to be fully normalized. let args = self.tcx.try_normalize_erasing_regions(self.param_env, args).ok()?; let callee = - Instance::resolve(self.tcx, self.param_env, def_id, args).ok().flatten()?; + Instance::try_resolve(self.tcx, self.param_env, def_id, args).ok().flatten()?; if let InstanceKind::Virtual(..) | InstanceKind::Intrinsic(_) = callee.def { return None; @@ -485,7 +493,9 @@ impl<'tcx> Inliner<'tcx> { ) -> Result<(), &'static str> { let tcx = self.tcx; - let mut threshold = if cross_crate_inlinable { + let mut threshold = if self.caller_is_inline_forwarder { + self.tcx.sess.opts.unstable_opts.inline_mir_forwarder_threshold.unwrap_or(30) + } else if cross_crate_inlinable { self.tcx.sess.opts.unstable_opts.inline_mir_hint_threshold.unwrap_or(100) } else { self.tcx.sess.opts.unstable_opts.inline_mir_threshold.unwrap_or(50) @@ -504,6 +514,8 @@ impl<'tcx> Inliner<'tcx> { let mut checker = CostChecker::new(self.tcx, self.param_env, Some(callsite.callee), callee_body); + checker.add_function_level_costs(); + // Traverse the MIR manually so we can account for the effects of inlining on the CFG. let mut work_list = vec![START_BLOCK]; let mut visited = BitSet::new_empty(callee_body.basic_blocks.len()); @@ -1091,3 +1103,37 @@ fn try_instance_mir<'tcx>( } Ok(tcx.instance_mir(instance)) } + +fn body_is_forwarder(body: &Body<'_>) -> bool { + let TerminatorKind::Call { target, .. } = body.basic_blocks[START_BLOCK].terminator().kind + else { + return false; + }; + if let Some(target) = target { + let TerminatorKind::Return = body.basic_blocks[target].terminator().kind else { + return false; + }; + } + + let max_blocks = if !body.is_polymorphic { + 2 + } else if target.is_none() { + 3 + } else { + 4 + }; + if body.basic_blocks.len() > max_blocks { + return false; + } + + body.basic_blocks.iter_enumerated().all(|(bb, bb_data)| { + bb == START_BLOCK + || matches!( + bb_data.terminator().kind, + TerminatorKind::Return + | TerminatorKind::Drop { .. } + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + ) + }) +} diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index 35bcd24ce95..d4477563e3a 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -53,7 +53,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( trace!(?caller, ?param_env, ?args, "cannot normalize, skipping"); continue; }; - let Ok(Some(callee)) = ty::Instance::resolve(tcx, param_env, callee, args) else { + let Ok(Some(callee)) = ty::Instance::try_resolve(tcx, param_env, callee, args) else { trace!(?callee, "cannot resolve, skipping"); continue; }; diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index f7056702cb4..5d253d7384d 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -519,7 +519,7 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &add_subtyping_projections::Subtyper, // calling this after reveal_all ensures that we don't deal with opaque types &elaborate_drops::ElaborateDrops, // This will remove extraneous landing pads which are no longer - // necessary as well as well as forcing any call in a non-unwinding + // necessary as well as forcing any call in a non-unwinding // function calling a possibly-unwinding function to abort the process. &abort_unwinding_calls::AbortUnwindingCalls, // AddMovesForPackedDrops needs to run after drop diff --git a/compiler/rustc_mir_transform/src/promote_consts.rs b/compiler/rustc_mir_transform/src/promote_consts.rs index 3f4d2b65ff2..736647fb64b 100644 --- a/compiler/rustc_mir_transform/src/promote_consts.rs +++ b/compiler/rustc_mir_transform/src/promote_consts.rs @@ -816,7 +816,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { mut func, mut args, call_source: desugar, fn_span, .. } => { // This promoted involves a function call, so it may fail to evaluate. - // Let's make sure it is added to `required_consts` so that that failure cannot get lost. + // Let's make sure it is added to `required_consts` so that failure cannot get lost. self.add_to_required = true; self.visit_operand(&mut func, loc); diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 25577e88e28..6835a39cf36 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -1,18 +1,17 @@ use rustc_hir as hir; use rustc_hir::def_id::DefId; use rustc_hir::lang_items::LangItem; +use rustc_index::{Idx, IndexVec}; use rustc_middle::mir::*; use rustc_middle::query::Providers; use rustc_middle::ty::GenericArgs; use rustc_middle::ty::{self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; -use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT}; - -use rustc_index::{Idx, IndexVec}; - use rustc_span::{source_map::Spanned, Span, DUMMY_SP}; +use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT}; use rustc_target::spec::abi::Abi; +use std::assert_matches::assert_matches; use std::fmt; use std::iter; @@ -1020,21 +1019,19 @@ fn build_construct_coroutine_by_move_shim<'tcx>( receiver_by_ref: bool, ) -> Body<'tcx> { let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); + let mut self_local: Place<'tcx> = Local::from_usize(1).into(); let ty::CoroutineClosure(_, args) = *self_ty.kind() else { bug!(); }; - // We use `&mut Self` here because we only need to emit an ABI-compatible shim body, - // rather than match the signature exactly (which might take `&self` instead). + // We use `&Self` here because we only need to emit an ABI-compatible shim body, + // rather than match the signature exactly (which might take `&mut self` instead). // - // The self type here is a coroutine-closure, not a coroutine, and we never read from - // it because it never has any captures, because this is only true in the Fn/FnMut - // implementation, not the AsyncFn/AsyncFnMut implementation, which is implemented only - // if the coroutine-closure has no captures. + // We adjust the `self_local` to be a deref since we want to copy fields out of + // a reference to the closure. if receiver_by_ref { - // Triple-check that there's no captures here. - assert_eq!(args.as_coroutine_closure().tupled_upvars_ty(), tcx.types.unit); - self_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, self_ty); + self_local = tcx.mk_place_deref(self_local); + self_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, self_ty); } let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { @@ -1067,11 +1064,27 @@ fn build_construct_coroutine_by_move_shim<'tcx>( fields.push(Operand::Move(Local::from_usize(idx + 1).into())); } for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() { - fields.push(Operand::Move(tcx.mk_place_field( - Local::from_usize(1).into(), - FieldIdx::from_usize(idx), - ty, - ))); + if receiver_by_ref { + // The only situation where it's possible is when we capture immuatable references, + // since those don't need to be reborrowed with the closure's env lifetime. Since + // references are always `Copy`, just emit a copy. + assert_matches!( + ty.kind(), + ty::Ref(_, _, hir::Mutability::Not), + "field should be captured by immutable ref if we have an `Fn` instance" + ); + fields.push(Operand::Copy(tcx.mk_place_field( + self_local, + FieldIdx::from_usize(idx), + ty, + ))); + } else { + fields.push(Operand::Move(tcx.mk_place_field( + self_local, + FieldIdx::from_usize(idx), + ty, + ))); + } } let source_info = SourceInfo::outermost(span); |
