diff options
Diffstat (limited to 'compiler/rustc_mir_transform/src')
17 files changed, 325 insertions, 187 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index 05674792426..9d44001f915 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -58,7 +58,7 @@ use crate::deref_separator::deref_finder; use crate::errors; use crate::pass_manager as pm; use crate::simplify; -use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_data_structures::fx::FxHashSet; use rustc_errors::pluralize; use rustc_hir as hir; use rustc_hir::lang_items::LangItem; @@ -208,11 +208,8 @@ const UNRESUMED: usize = CoroutineArgs::UNRESUMED; const RETURNED: usize = CoroutineArgs::RETURNED; /// Coroutine has panicked and is poisoned. const POISONED: usize = CoroutineArgs::POISONED; - -/// Number of variants to reserve in coroutine state. Corresponds to -/// `UNRESUMED` (beginning of a coroutine) and `RETURNED`/`POISONED` -/// (end of a coroutine) states. -const RESERVED_VARIANTS: usize = 3; +/// Number of reserved variants of coroutine state. +const RESERVED_VARIANTS: usize = CoroutineArgs::RESERVED_VARIANTS; /// A `yield` point in the coroutine. struct SuspensionPoint<'tcx> { @@ -236,8 +233,7 @@ struct TransformVisitor<'tcx> { discr_ty: Ty<'tcx>, // Mapping from Local to (type of local, coroutine struct index) - // FIXME(eddyb) This should use `IndexVec<Local, Option<_>>`. - remap: FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>, + remap: IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>, // A map from a suspension point in a block to the locals which have live storage at that point storage_liveness: IndexVec<BasicBlock, Option<BitSet<Local>>>, @@ -485,7 +481,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { } fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { - assert_eq!(self.remap.get(local), None); + assert!(!self.remap.contains(*local)); } fn visit_place( @@ -495,7 +491,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { _location: Location, ) { // Replace an Local in the remap with a coroutine struct access - if let Some(&(ty, variant_index, idx)) = self.remap.get(&place.local) { + if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) { replace_base(place, self.make_field(variant_index, idx, ty), self.tcx); } } @@ -504,7 +500,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { // Remove StorageLive and StorageDead statements for remapped locals data.retain_statements(|s| match s.kind { StatementKind::StorageLive(l) | StatementKind::StorageDead(l) => { - !self.remap.contains_key(&l) + !self.remap.contains(l) } _ => true, }); @@ -529,13 +525,9 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { // The resume arg target location might itself be remapped if its base local is // live across a yield. - let resume_arg = - if let Some(&(ty, variant, idx)) = self.remap.get(&resume_arg.local) { - replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx); - resume_arg - } else { - resume_arg - }; + if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) { + replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx); + } let storage_liveness: GrowableBitSet<Local> = self.storage_liveness[block].clone().unwrap().into(); @@ -543,7 +535,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { for i in 0..self.always_live_locals.domain_size() { let l = Local::new(i); let needs_storage_dead = storage_liveness.contains(l) - && !self.remap.contains_key(&l) + && !self.remap.contains(l) && !self.always_live_locals.contains(l); if needs_storage_dead { data.statements @@ -1037,7 +1029,7 @@ fn compute_layout<'tcx>( liveness: LivenessInfo, body: &Body<'tcx>, ) -> ( - FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>, + IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>, CoroutineLayout<'tcx>, IndexVec<BasicBlock, Option<BitSet<Local>>>, ) { @@ -1098,7 +1090,7 @@ fn compute_layout<'tcx>( // Create a map from local indices to coroutine struct indices. let mut variant_fields: IndexVec<VariantIdx, IndexVec<FieldIdx, CoroutineSavedLocal>> = iter::repeat(IndexVec::new()).take(RESERVED_VARIANTS).collect(); - let mut remap = FxHashMap::default(); + let mut remap = IndexVec::from_elem_n(None, saved_locals.domain_size()); for (suspension_point_idx, live_locals) in live_locals_at_suspension_points.iter().enumerate() { let variant_index = VariantIdx::from(RESERVED_VARIANTS + suspension_point_idx); let mut fields = IndexVec::new(); @@ -1109,7 +1101,7 @@ fn compute_layout<'tcx>( // around inside coroutines, so it doesn't matter which variant // index we access them by. let idx = FieldIdx::from_usize(idx); - remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx)); + remap[locals[saved_local]] = Some((tys[saved_local].ty, variant_index, idx)); } variant_fields.push(fields); variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]); @@ -1121,7 +1113,9 @@ fn compute_layout<'tcx>( for var in &body.var_debug_info { let VarDebugInfoContents::Place(place) = &var.value else { continue }; let Some(local) = place.as_local() else { continue }; - let Some(&(_, variant, field)) = remap.get(&local) else { continue }; + let Some(&Some((_, variant, field))) = remap.get(local) else { + continue; + }; let saved_local = variant_fields[variant][field]; field_names.get_or_insert_with(saved_local, || var.name); @@ -1524,7 +1518,7 @@ fn create_cases<'tcx>( for i in 0..(body.local_decls.len()) { let l = Local::new(i); let needs_storage_live = point.storage_liveness.contains(l) - && !transform.remap.contains_key(&l) + && !transform.remap.contains(l) && !transform.always_live_locals.contains(l); if needs_storage_live { statements diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs index 7e401b5482f..3333bebff3a 100644 --- a/compiler/rustc_mir_transform/src/cost_checker.rs +++ b/compiler/rustc_mir_transform/src/cost_checker.rs @@ -31,6 +31,37 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> { CostChecker { tcx, param_env, callee_body, instance, penalty: 0, bonus: 0 } } + /// Add function-level costs not well-represented by the block-level costs. + /// + /// Needed because the `CostChecker` is used sometimes for just blocks, + /// and even the full `Inline` doesn't call `visit_body`, so there's nowhere + /// to put this logic in the visitor. + pub fn add_function_level_costs(&mut self) { + fn is_call_like(bbd: &BasicBlockData<'_>) -> bool { + use TerminatorKind::*; + match bbd.terminator().kind { + Call { .. } | Drop { .. } | Assert { .. } | InlineAsm { .. } => true, + + Goto { .. } + | SwitchInt { .. } + | UnwindResume + | UnwindTerminate(_) + | Return + | Unreachable => false, + + Yield { .. } | CoroutineDrop | FalseEdge { .. } | FalseUnwind { .. } => { + unreachable!() + } + } + } + + // If the only has one Call (or similar), inlining isn't increasing the total + // number of calls, so give extra encouragement to inlining that. + if self.callee_body.basic_blocks.iter().filter(|bbd| is_call_like(bbd)).count() == 1 { + self.bonus += CALL_PENALTY; + } + } + pub fn cost(&self) -> usize { usize::saturating_sub(self.penalty, self.bonus) } diff --git a/compiler/rustc_mir_transform/src/coverage/mappings.rs b/compiler/rustc_mir_transform/src/coverage/mappings.rs index 759bb7c1f9d..235992ac547 100644 --- a/compiler/rustc_mir_transform/src/coverage/mappings.rs +++ b/compiler/rustc_mir_transform/src/coverage/mappings.rs @@ -9,9 +9,8 @@ use rustc_middle::ty::TyCtxt; use rustc_span::Span; use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, START_BCB}; -use crate::coverage::spans::{ - extract_refined_covspans, unexpand_into_body_span_with_visible_macro, -}; +use crate::coverage::spans::extract_refined_covspans; +use crate::coverage::unexpand::unexpand_into_body_span; use crate::coverage::ExtractedHirInfo; /// Associates an ordinary executable code span with its corresponding BCB. @@ -202,8 +201,7 @@ pub(super) fn extract_branch_pairs( if !raw_span.ctxt().outer_expn_data().is_root() { return None; } - let (span, _) = - unexpand_into_body_span_with_visible_macro(raw_span, hir_info.body_span)?; + let span = unexpand_into_body_span(raw_span, hir_info.body_span)?; let bcb_from_marker = |marker: BlockMarkerId| basic_coverage_blocks.bcb_from_bb(block_markers[marker]?); @@ -238,7 +236,7 @@ pub(super) fn extract_mcdc_mappings( if !raw_span.ctxt().outer_expn_data().is_root() { return None; } - let (span, _) = unexpand_into_body_span_with_visible_macro(raw_span, body_span)?; + let span = unexpand_into_body_span(raw_span, body_span)?; let true_bcb = bcb_from_marker(true_marker)?; let false_bcb = bcb_from_marker(false_marker)?; @@ -261,7 +259,7 @@ pub(super) fn extract_mcdc_mappings( mcdc_decisions.extend(branch_info.mcdc_decision_spans.iter().filter_map( |decision: &mir::coverage::MCDCDecisionSpan| { - let (span, _) = unexpand_into_body_span_with_visible_macro(decision.span, body_span)?; + let span = unexpand_into_body_span(decision.span, body_span)?; let end_bcbs = decision .end_markers diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs index 4a64d21f3d1..d55bde311c1 100644 --- a/compiler/rustc_mir_transform/src/coverage/mod.rs +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -6,6 +6,7 @@ mod mappings; mod spans; #[cfg(test)] mod tests; +mod unexpand; use rustc_middle::mir::coverage::{ CodeRegion, CoverageKind, DecisionInfo, FunctionCoverageInfo, Mapping, MappingKind, diff --git a/compiler/rustc_mir_transform/src/coverage/query.rs b/compiler/rustc_mir_transform/src/coverage/query.rs index 25744009be8..1fce2abbbbf 100644 --- a/compiler/rustc_mir_transform/src/coverage/query.rs +++ b/compiler/rustc_mir_transform/src/coverage/query.rs @@ -6,11 +6,13 @@ use rustc_middle::query::TyCtxtAt; use rustc_middle::ty::{self, TyCtxt}; use rustc_middle::util::Providers; use rustc_span::def_id::LocalDefId; +use rustc_span::sym; /// Registers query/hook implementations related to coverage. pub(crate) fn provide(providers: &mut Providers) { providers.hooks.is_eligible_for_coverage = |TyCtxtAt { tcx, .. }, def_id| is_eligible_for_coverage(tcx, def_id); + providers.queries.coverage_attr_on = coverage_attr_on; providers.queries.coverage_ids_info = coverage_ids_info; } @@ -38,7 +40,12 @@ fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { return false; } - if tcx.codegen_fn_attrs(def_id).flags.contains(CodegenFnAttrFlags::NO_COVERAGE) { + if tcx.codegen_fn_attrs(def_id).flags.contains(CodegenFnAttrFlags::NAKED) { + trace!("InstrumentCoverage skipped for {def_id:?} (`#[naked]`)"); + return false; + } + + if !tcx.coverage_attr_on(def_id) { trace!("InstrumentCoverage skipped for {def_id:?} (`#[coverage(off)]`)"); return false; } @@ -46,6 +53,30 @@ fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { true } +/// Query implementation for `coverage_attr_on`. +fn coverage_attr_on(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { + // Check for annotations directly on this def. + if let Some(attr) = tcx.get_attr(def_id, sym::coverage) { + match attr.meta_item_list().as_deref() { + Some([item]) if item.has_name(sym::off) => return false, + Some([item]) if item.has_name(sym::on) => return true, + Some(_) | None => { + // Other possibilities should have been rejected by `rustc_parse::validate_attr`. + tcx.dcx().span_bug(attr.span, "unexpected value of coverage attribute"); + } + } + } + + match tcx.opt_local_parent(def_id) { + // Check the parent def (and so on recursively) until we find an + // enclosing attribute or reach the crate root. + Some(parent) => tcx.coverage_attr_on(parent), + // We reached the crate root without seeing a coverage attribute, so + // allow coverage instrumentation by default. + None => true, + } +} + /// Query implementation for `coverage_ids_info`. fn coverage_ids_info<'tcx>( tcx: TyCtxt<'tcx>, diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs index 84a70d1f02d..7612c01c52e 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans.rs @@ -14,11 +14,6 @@ use crate::coverage::ExtractedHirInfo; mod from_mir; -// FIXME(#124545) It's awkward that we have to re-export this, because it's an -// internal detail of `from_mir` that is also needed when handling branch and -// MC/DC spans. Ideally we would find a more natural home for it. -pub(super) use from_mir::unexpand_into_body_span_with_visible_macro; - pub(super) fn extract_refined_covspans( mir_body: &mir::Body<'_>, hir_info: &ExtractedHirInfo, diff --git a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs index 09deb7534bf..2ca166929ee 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs @@ -4,12 +4,13 @@ use rustc_middle::mir::{ self, AggregateKind, FakeReadCause, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, }; -use rustc_span::{ExpnKind, MacroKind, Span, Symbol}; +use rustc_span::{Span, Symbol}; use crate::coverage::graph::{ BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph, START_BCB, }; use crate::coverage::spans::Covspan; +use crate::coverage::unexpand::unexpand_into_body_span_with_visible_macro; use crate::coverage::ExtractedHirInfo; pub(crate) struct ExtractedCovspans { @@ -215,59 +216,6 @@ fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Span> { } } -/// Returns an extrapolated span (pre-expansion[^1]) corresponding to a range -/// within the function's body source. This span is guaranteed to be contained -/// within, or equal to, the `body_span`. If the extrapolated span is not -/// contained within the `body_span`, `None` is returned. -/// -/// [^1]Expansions result from Rust syntax including macros, syntactic sugar, -/// etc.). -pub(crate) fn unexpand_into_body_span_with_visible_macro( - original_span: Span, - body_span: Span, -) -> Option<(Span, Option<Symbol>)> { - let (span, prev) = unexpand_into_body_span_with_prev(original_span, body_span)?; - - let visible_macro = prev - .map(|prev| match prev.ctxt().outer_expn_data().kind { - ExpnKind::Macro(MacroKind::Bang, name) => Some(name), - _ => None, - }) - .flatten(); - - Some((span, visible_macro)) -} - -/// Walks through the expansion ancestors of `original_span` to find a span that -/// is contained in `body_span` and has the same [`SyntaxContext`] as `body_span`. -/// The ancestor that was traversed just before the matching span (if any) is -/// also returned. -/// -/// For example, a return value of `Some((ancestor, Some(prev))` means that: -/// - `ancestor == original_span.find_ancestor_inside_same_ctxt(body_span)` -/// - `ancestor == prev.parent_callsite()` -/// -/// [`SyntaxContext`]: rustc_span::SyntaxContext -fn unexpand_into_body_span_with_prev( - original_span: Span, - body_span: Span, -) -> Option<(Span, Option<Span>)> { - let mut prev = None; - let mut curr = original_span; - - while !body_span.contains(curr) || !curr.eq_ctxt(body_span) { - prev = Some(curr); - curr = curr.parent_callsite()?; - } - - debug_assert_eq!(Some(curr), original_span.find_ancestor_in_same_ctxt(body_span)); - if let Some(prev) = prev { - debug_assert_eq!(Some(curr), prev.parent_callsite()); - } - - Some((curr, prev)) -} - #[derive(Debug)] pub(crate) struct Hole { pub(crate) span: Span, diff --git a/compiler/rustc_mir_transform/src/coverage/unexpand.rs b/compiler/rustc_mir_transform/src/coverage/unexpand.rs new file mode 100644 index 00000000000..8cde291b907 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/unexpand.rs @@ -0,0 +1,60 @@ +use rustc_span::{ExpnKind, MacroKind, Span, Symbol}; + +/// Walks through the expansion ancestors of `original_span` to find a span that +/// is contained in `body_span` and has the same [syntax context] as `body_span`. +pub(crate) fn unexpand_into_body_span(original_span: Span, body_span: Span) -> Option<Span> { + // Because we don't need to return any extra ancestor information, + // we can just delegate directly to `find_ancestor_inside_same_ctxt`. + original_span.find_ancestor_inside_same_ctxt(body_span) +} + +/// Walks through the expansion ancestors of `original_span` to find a span that +/// is contained in `body_span` and has the same [syntax context] as `body_span`. +/// +/// If the returned span represents a bang-macro invocation (e.g. `foo!(..)`), +/// the returned symbol will be the name of that macro (e.g. `foo`). +pub(crate) fn unexpand_into_body_span_with_visible_macro( + original_span: Span, + body_span: Span, +) -> Option<(Span, Option<Symbol>)> { + let (span, prev) = unexpand_into_body_span_with_prev(original_span, body_span)?; + + let visible_macro = prev + .map(|prev| match prev.ctxt().outer_expn_data().kind { + ExpnKind::Macro(MacroKind::Bang, name) => Some(name), + _ => None, + }) + .flatten(); + + Some((span, visible_macro)) +} + +/// Walks through the expansion ancestors of `original_span` to find a span that +/// is contained in `body_span` and has the same [syntax context] as `body_span`. +/// The ancestor that was traversed just before the matching span (if any) is +/// also returned. +/// +/// For example, a return value of `Some((ancestor, Some(prev)))` means that: +/// - `ancestor == original_span.find_ancestor_inside_same_ctxt(body_span)` +/// - `prev.parent_callsite() == ancestor` +/// +/// [syntax context]: rustc_span::SyntaxContext +fn unexpand_into_body_span_with_prev( + original_span: Span, + body_span: Span, +) -> Option<(Span, Option<Span>)> { + let mut prev = None; + let mut curr = original_span; + + while !body_span.contains(curr) || !curr.eq_ctxt(body_span) { + prev = Some(curr); + curr = curr.parent_callsite()?; + } + + debug_assert_eq!(Some(curr), original_span.find_ancestor_inside_same_ctxt(body_span)); + if let Some(prev) = prev { + debug_assert_eq!(Some(curr), prev.parent_callsite()); + } + + Some((curr, prev)) +} diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs index 0fd85eb345d..8b965f4d18e 100644 --- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -181,11 +181,6 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { state.insert_value_idx(value_target, val, self.map()); } if let Some(overflow_target) = overflow_target { - let overflow = match overflow { - FlatSet::Top => FlatSet::Top, - FlatSet::Elem(overflow) => FlatSet::Elem(overflow), - FlatSet::Bottom => FlatSet::Bottom, - }; // We have flooded `target` earlier. state.insert_value_idx(overflow_target, overflow, self.map()); } diff --git a/compiler/rustc_mir_transform/src/elaborate_drops.rs b/compiler/rustc_mir_transform/src/elaborate_drops.rs index 665b2260294..fbbb8c5e472 100644 --- a/compiler/rustc_mir_transform/src/elaborate_drops.rs +++ b/compiler/rustc_mir_transform/src/elaborate_drops.rs @@ -97,7 +97,7 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops { #[instrument(level = "trace", skip(body, flow_inits), ret)] fn compute_dead_unwinds<'mir, 'tcx>( body: &'mir Body<'tcx>, - flow_inits: &mut ResultsCursor<'mir, 'tcx, MaybeInitializedPlaces<'mir, 'tcx>>, + flow_inits: &mut ResultsCursor<'mir, 'tcx, MaybeInitializedPlaces<'_, 'mir, 'tcx>>, ) -> BitSet<BasicBlock> { // We only need to do this pass once, because unwind edges can only // reach cleanup blocks, which can't have unwind edges themselves. @@ -118,12 +118,12 @@ fn compute_dead_unwinds<'mir, 'tcx>( dead_unwinds } -struct InitializationData<'mir, 'tcx> { - inits: ResultsCursor<'mir, 'tcx, MaybeInitializedPlaces<'mir, 'tcx>>, - uninits: ResultsCursor<'mir, 'tcx, MaybeUninitializedPlaces<'mir, 'tcx>>, +struct InitializationData<'a, 'mir, 'tcx> { + inits: ResultsCursor<'mir, 'tcx, MaybeInitializedPlaces<'a, 'mir, 'tcx>>, + uninits: ResultsCursor<'mir, 'tcx, MaybeUninitializedPlaces<'a, 'mir, 'tcx>>, } -impl InitializationData<'_, '_> { +impl InitializationData<'_, '_, '_> { fn seek_before(&mut self, loc: Location) { self.inits.seek_before_primary_effect(loc); self.uninits.seek_before_primary_effect(loc); @@ -134,17 +134,17 @@ impl InitializationData<'_, '_> { } } -struct Elaborator<'a, 'b, 'tcx> { - ctxt: &'a mut ElaborateDropsCtxt<'b, 'tcx>, +struct Elaborator<'a, 'b, 'mir, 'tcx> { + ctxt: &'a mut ElaborateDropsCtxt<'b, 'mir, 'tcx>, } -impl fmt::Debug for Elaborator<'_, '_, '_> { +impl fmt::Debug for Elaborator<'_, '_, '_, '_> { fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { Ok(()) } } -impl<'a, 'tcx> DropElaborator<'a, 'tcx> for Elaborator<'a, '_, 'tcx> { +impl<'a, 'tcx> DropElaborator<'a, 'tcx> for Elaborator<'a, '_, '_, 'tcx> { type Path = MovePathIndex; fn patch(&mut self) -> &mut MirPatch<'tcx> { @@ -238,16 +238,16 @@ impl<'a, 'tcx> DropElaborator<'a, 'tcx> for Elaborator<'a, '_, 'tcx> { } } -struct ElaborateDropsCtxt<'a, 'tcx> { +struct ElaborateDropsCtxt<'a, 'mir, 'tcx> { tcx: TyCtxt<'tcx>, - body: &'a Body<'tcx>, + body: &'mir Body<'tcx>, env: &'a MoveDataParamEnv<'tcx>, - init_data: InitializationData<'a, 'tcx>, + init_data: InitializationData<'a, 'mir, 'tcx>, drop_flags: IndexVec<MovePathIndex, Option<Local>>, patch: MirPatch<'tcx>, } -impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { +impl<'b, 'mir, 'tcx> ElaborateDropsCtxt<'b, 'mir, 'tcx> { fn move_data(&self) -> &'b MoveData<'tcx> { &self.env.move_data } diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs index 936a7e2d9de..2b7d9be6d35 100644 --- a/compiler/rustc_mir_transform/src/gvn.rs +++ b/compiler/rustc_mir_transform/src/gvn.rs @@ -8,7 +8,7 @@ //! `Value` is interned as a `VnIndex`, which allows us to cheaply compute identical values. //! //! From those assignments, we construct a mapping `VnIndex -> Vec<(Local, Location)>` of available -//! values, the locals in which they are stored, and a the assignment location. +//! values, the locals in which they are stored, and the assignment location. //! //! In a second pass, we traverse all (non SSA) assignments `x = rvalue` and operands. For each //! one, we compute the `VnIndex` of the rvalue. If this `VnIndex` is associated to a constant, we @@ -1074,11 +1074,11 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { { lhs = *lhs_value; rhs = *rhs_value; - if let Some(op) = self.try_as_operand(lhs, location) { - *lhs_operand = op; - } - if let Some(op) = self.try_as_operand(rhs, location) { - *rhs_operand = op; + if let Some(lhs_op) = self.try_as_operand(lhs, location) + && let Some(rhs_op) = self.try_as_operand(rhs, location) + { + *lhs_operand = lhs_op; + *rhs_operand = rhs_op; } } diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 0a5fc697d03..5075e072754 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -11,7 +11,7 @@ use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs} use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::TypeVisitableExt; -use rustc_middle::ty::{self, Instance, InstanceKind, ParamEnv, Ty, TyCtxt}; +use rustc_middle::ty::{self, Instance, InstanceKind, ParamEnv, Ty, TyCtxt, TypeFlags}; use rustc_session::config::{DebugInfo, OptLevel}; use rustc_span::source_map::Spanned; use rustc_span::sym; @@ -85,13 +85,18 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { } let param_env = tcx.param_env_reveal_all_normalized(def_id); + let codegen_fn_attrs = tcx.codegen_fn_attrs(def_id); let mut this = Inliner { tcx, param_env, - codegen_fn_attrs: tcx.codegen_fn_attrs(def_id), + codegen_fn_attrs, history: Vec::new(), changed: false, + caller_is_inline_forwarder: matches!( + codegen_fn_attrs.inline, + InlineAttr::Hint | InlineAttr::Always + ) && body_is_forwarder(body), }; let blocks = START_BLOCK..body.basic_blocks.next_index(); this.process_blocks(body, blocks); @@ -111,6 +116,9 @@ struct Inliner<'tcx> { history: Vec<DefId>, /// Indicates that the caller body has been modified. changed: bool, + /// Indicates that the caller is #[inline] and just calls another function, + /// and thus we can inline less into it as it'll be inlined itself. + caller_is_inline_forwarder: bool, } impl<'tcx> Inliner<'tcx> { @@ -306,6 +314,16 @@ impl<'tcx> Inliner<'tcx> { InstanceKind::Intrinsic(_) | InstanceKind::Virtual(..) => { return Err("instance without MIR (intrinsic / virtual)"); } + + // FIXME(#127030): `ConstParamHasTy` has bad interactions with + // the drop shim builder, which does not evaluate predicates in + // the correct param-env for types being dropped. Stall resolving + // the MIR for this instance until all of its const params are + // substituted. + InstanceKind::DropGlue(_, Some(ty)) if ty.has_type_flags(TypeFlags::HAS_CT_PARAM) => { + return Err("still needs substitution"); + } + // This cannot result in an immediate cycle since the callee MIR is a shim, which does // not get any optimizations run on it. Any subsequent inlining may cause cycles, but we // do not need to catch this here, we can wait until the inliner decides to continue @@ -371,7 +389,7 @@ impl<'tcx> Inliner<'tcx> { // To resolve an instance its args have to be fully normalized. let args = self.tcx.try_normalize_erasing_regions(self.param_env, args).ok()?; let callee = - Instance::resolve(self.tcx, self.param_env, def_id, args).ok().flatten()?; + Instance::try_resolve(self.tcx, self.param_env, def_id, args).ok().flatten()?; if let InstanceKind::Virtual(..) | InstanceKind::Intrinsic(_) = callee.def { return None; @@ -475,7 +493,9 @@ impl<'tcx> Inliner<'tcx> { ) -> Result<(), &'static str> { let tcx = self.tcx; - let mut threshold = if cross_crate_inlinable { + let mut threshold = if self.caller_is_inline_forwarder { + self.tcx.sess.opts.unstable_opts.inline_mir_forwarder_threshold.unwrap_or(30) + } else if cross_crate_inlinable { self.tcx.sess.opts.unstable_opts.inline_mir_hint_threshold.unwrap_or(100) } else { self.tcx.sess.opts.unstable_opts.inline_mir_threshold.unwrap_or(50) @@ -494,6 +514,8 @@ impl<'tcx> Inliner<'tcx> { let mut checker = CostChecker::new(self.tcx, self.param_env, Some(callsite.callee), callee_body); + checker.add_function_level_costs(); + // Traverse the MIR manually so we can account for the effects of inlining on the CFG. let mut work_list = vec![START_BLOCK]; let mut visited = BitSet::new_empty(callee_body.basic_blocks.len()); @@ -1081,3 +1103,37 @@ fn try_instance_mir<'tcx>( } Ok(tcx.instance_mir(instance)) } + +fn body_is_forwarder(body: &Body<'_>) -> bool { + let TerminatorKind::Call { target, .. } = body.basic_blocks[START_BLOCK].terminator().kind + else { + return false; + }; + if let Some(target) = target { + let TerminatorKind::Return = body.basic_blocks[target].terminator().kind else { + return false; + }; + } + + let max_blocks = if !body.is_polymorphic { + 2 + } else if target.is_none() { + 3 + } else { + 4 + }; + if body.basic_blocks.len() > max_blocks { + return false; + } + + body.basic_blocks.iter_enumerated().all(|(bb, bb_data)| { + bb == START_BLOCK + || matches!( + bb_data.terminator().kind, + TerminatorKind::Return + | TerminatorKind::Drop { .. } + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + ) + }) +} diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index 35bcd24ce95..d4477563e3a 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -53,7 +53,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( trace!(?caller, ?param_env, ?args, "cannot normalize, skipping"); continue; }; - let Ok(Some(callee)) = ty::Instance::resolve(tcx, param_env, callee, args) else { + let Ok(Some(callee)) = ty::Instance::try_resolve(tcx, param_env, callee, args) else { trace!(?callee, "cannot resolve, skipping"); continue; }; diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs index 23cc0c46e73..85a61f95ca3 100644 --- a/compiler/rustc_mir_transform/src/jump_threading.rs +++ b/compiler/rustc_mir_transform/src/jump_threading.rs @@ -47,6 +47,7 @@ use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; use rustc_middle::ty::layout::LayoutOf; use rustc_middle::ty::{self, ScalarInt, TyCtxt}; +use rustc_mir_dataflow::lattice::HasBottom; use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem}; use rustc_span::DUMMY_SP; use rustc_target::abi::{TagEncoding, Variants}; @@ -91,43 +92,8 @@ impl<'tcx> MirPass<'tcx> for JumpThreading { opportunities: Vec::new(), }; - for (bb, bbdata) in body.basic_blocks.iter_enumerated() { - debug!(?bb, term = ?bbdata.terminator()); - if bbdata.is_cleanup || loop_headers.contains(bb) { - continue; - } - let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { continue }; - let Some(discr) = discr.place() else { continue }; - debug!(?discr, ?bb); - - let discr_ty = discr.ty(body, tcx).ty; - let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue }; - - let Some(discr) = finder.map.find(discr.as_ref()) else { continue }; - debug!(?discr); - - let cost = CostChecker::new(tcx, param_env, None, body); - - let mut state = State::new(ConditionSet::default(), finder.map); - - let conds = if let Some((value, then, else_)) = targets.as_static_if() { - let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { - continue; - }; - arena.alloc_from_iter([ - Condition { value, polarity: Polarity::Eq, target: then }, - Condition { value, polarity: Polarity::Ne, target: else_ }, - ]) - } else { - arena.alloc_from_iter(targets.iter().filter_map(|(value, target)| { - let value = ScalarInt::try_from_uint(value, discr_layout.size)?; - Some(Condition { value, polarity: Polarity::Eq, target }) - })) - }; - let conds = ConditionSet(conds); - state.insert_value_idx(discr, conds, finder.map); - - finder.find_opportunity(bb, state, cost, 0); + for bb in body.basic_blocks.indices() { + finder.start_from_switch(bb); } let opportunities = finder.opportunities; @@ -193,9 +159,17 @@ impl Condition { } } -#[derive(Copy, Clone, Debug, Default)] +#[derive(Copy, Clone, Debug)] struct ConditionSet<'a>(&'a [Condition]); +impl HasBottom for ConditionSet<'_> { + const BOTTOM: Self = ConditionSet(&[]); + + fn is_bottom(&self) -> bool { + self.0.is_empty() + } +} + impl<'a> ConditionSet<'a> { fn iter(self) -> impl Iterator<Item = Condition> + 'a { self.0.iter().copied() @@ -212,10 +186,50 @@ impl<'a> ConditionSet<'a> { impl<'tcx, 'a> TOFinder<'tcx, 'a> { fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool { - state.all(|cs| cs.0.is_empty()) + state.all_bottom() } /// Recursion entry point to find threading opportunities. + #[instrument(level = "trace", skip(self))] + fn start_from_switch(&mut self, bb: BasicBlock) -> Option<!> { + let bbdata = &self.body[bb]; + if bbdata.is_cleanup || self.loop_headers.contains(bb) { + return None; + } + let (discr, targets) = bbdata.terminator().kind.as_switch()?; + let discr = discr.place()?; + debug!(?discr, ?bb); + + let discr_ty = discr.ty(self.body, self.tcx).ty; + let discr_layout = self.ecx.layout_of(discr_ty).ok()?; + + let discr = self.map.find(discr.as_ref())?; + debug!(?discr); + + let cost = CostChecker::new(self.tcx, self.param_env, None, self.body); + let mut state = State::new_reachable(); + + let conds = if let Some((value, then, else_)) = targets.as_static_if() { + let value = ScalarInt::try_from_uint(value, discr_layout.size)?; + self.arena.alloc_from_iter([ + Condition { value, polarity: Polarity::Eq, target: then }, + Condition { value, polarity: Polarity::Ne, target: else_ }, + ]) + } else { + self.arena.alloc_from_iter(targets.iter().filter_map(|(value, target)| { + let value = ScalarInt::try_from_uint(value, discr_layout.size)?; + Some(Condition { value, polarity: Polarity::Eq, target }) + })) + }; + let conds = ConditionSet(conds); + state.insert_value_idx(discr, conds, self.map); + + self.find_opportunity(bb, state, cost, 0); + None + } + + /// Recursively walk statements backwards from this bb's terminator to find threading + /// opportunities. #[instrument(level = "trace", skip(self, cost), ret)] fn find_opportunity( &mut self, @@ -250,7 +264,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`. // _1 = 6 if let Some((lhs, tail)) = self.mutated_statement(stmt) { - state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::default()); + state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::BOTTOM); } } @@ -270,12 +284,13 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { self.process_switch_int(discr, targets, bb, &mut state); self.find_opportunity(pred, state, cost, depth + 1); } - _ => self.recurse_through_terminator(pred, &state, &cost, depth), + _ => self.recurse_through_terminator(pred, || state, &cost, depth), } - } else { + } else if let &[ref predecessors @ .., last_pred] = &predecessors[..] { for &pred in predecessors { - self.recurse_through_terminator(pred, &state, &cost, depth); + self.recurse_through_terminator(pred, || state.clone(), &cost, depth); } + self.recurse_through_terminator(last_pred, || state, &cost, depth); } let new_tos = &mut self.opportunities[last_non_rec..]; @@ -566,11 +581,12 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { None } - #[instrument(level = "trace", skip(self, cost))] + #[instrument(level = "trace", skip(self, state, cost))] fn recurse_through_terminator( &mut self, bb: BasicBlock, - state: &State<ConditionSet<'a>>, + // Pass a closure that may clone the state, as we don't want to do it each time. + state: impl FnOnce() -> State<ConditionSet<'a>>, cost: &CostChecker<'_, 'tcx>, depth: usize, ) { @@ -600,9 +616,9 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { }; // We can recurse through this terminator. - let mut state = state.clone(); + let mut state = state(); if let Some(place_to_flood) = place_to_flood { - state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::default()); + state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::BOTTOM); } self.find_opportunity(bb, state, cost.clone(), depth + 1); } diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index f7056702cb4..5d253d7384d 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -519,7 +519,7 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &add_subtyping_projections::Subtyper, // calling this after reveal_all ensures that we don't deal with opaque types &elaborate_drops::ElaborateDrops, // This will remove extraneous landing pads which are no longer - // necessary as well as well as forcing any call in a non-unwinding + // necessary as well as forcing any call in a non-unwinding // function calling a possibly-unwinding function to abort the process. &abort_unwinding_calls::AbortUnwindingCalls, // AddMovesForPackedDrops needs to run after drop diff --git a/compiler/rustc_mir_transform/src/promote_consts.rs b/compiler/rustc_mir_transform/src/promote_consts.rs index 3f4d2b65ff2..736647fb64b 100644 --- a/compiler/rustc_mir_transform/src/promote_consts.rs +++ b/compiler/rustc_mir_transform/src/promote_consts.rs @@ -816,7 +816,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> { mut func, mut args, call_source: desugar, fn_span, .. } => { // This promoted involves a function call, so it may fail to evaluate. - // Let's make sure it is added to `required_consts` so that that failure cannot get lost. + // Let's make sure it is added to `required_consts` so that failure cannot get lost. self.add_to_required = true; self.visit_operand(&mut func, loc); diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 25577e88e28..6835a39cf36 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -1,18 +1,17 @@ use rustc_hir as hir; use rustc_hir::def_id::DefId; use rustc_hir::lang_items::LangItem; +use rustc_index::{Idx, IndexVec}; use rustc_middle::mir::*; use rustc_middle::query::Providers; use rustc_middle::ty::GenericArgs; use rustc_middle::ty::{self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; -use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT}; - -use rustc_index::{Idx, IndexVec}; - use rustc_span::{source_map::Spanned, Span, DUMMY_SP}; +use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT}; use rustc_target::spec::abi::Abi; +use std::assert_matches::assert_matches; use std::fmt; use std::iter; @@ -1020,21 +1019,19 @@ fn build_construct_coroutine_by_move_shim<'tcx>( receiver_by_ref: bool, ) -> Body<'tcx> { let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); + let mut self_local: Place<'tcx> = Local::from_usize(1).into(); let ty::CoroutineClosure(_, args) = *self_ty.kind() else { bug!(); }; - // We use `&mut Self` here because we only need to emit an ABI-compatible shim body, - // rather than match the signature exactly (which might take `&self` instead). + // We use `&Self` here because we only need to emit an ABI-compatible shim body, + // rather than match the signature exactly (which might take `&mut self` instead). // - // The self type here is a coroutine-closure, not a coroutine, and we never read from - // it because it never has any captures, because this is only true in the Fn/FnMut - // implementation, not the AsyncFn/AsyncFnMut implementation, which is implemented only - // if the coroutine-closure has no captures. + // We adjust the `self_local` to be a deref since we want to copy fields out of + // a reference to the closure. if receiver_by_ref { - // Triple-check that there's no captures here. - assert_eq!(args.as_coroutine_closure().tupled_upvars_ty(), tcx.types.unit); - self_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, self_ty); + self_local = tcx.mk_place_deref(self_local); + self_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, self_ty); } let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { @@ -1067,11 +1064,27 @@ fn build_construct_coroutine_by_move_shim<'tcx>( fields.push(Operand::Move(Local::from_usize(idx + 1).into())); } for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() { - fields.push(Operand::Move(tcx.mk_place_field( - Local::from_usize(1).into(), - FieldIdx::from_usize(idx), - ty, - ))); + if receiver_by_ref { + // The only situation where it's possible is when we capture immuatable references, + // since those don't need to be reborrowed with the closure's env lifetime. Since + // references are always `Copy`, just emit a copy. + assert_matches!( + ty.kind(), + ty::Ref(_, _, hir::Mutability::Not), + "field should be captured by immutable ref if we have an `Fn` instance" + ); + fields.push(Operand::Copy(tcx.mk_place_field( + self_local, + FieldIdx::from_usize(idx), + ty, + ))); + } else { + fields.push(Operand::Move(tcx.mk_place_field( + self_local, + FieldIdx::from_usize(idx), + ty, + ))); + } } let source_info = SourceInfo::outermost(span); |
