Diffstat (limited to 'compiler/rustc_mir_transform/src')
39 files changed, 2265 insertions, 2356 deletions
diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs index dfc7a9891f9..ba70a4453d6 100644 --- a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs +++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs @@ -39,7 +39,9 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls { let body_abi = match body_ty.kind() { ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), ty::Closure(..) => Abi::RustCall, + ty::CoroutineClosure(..) => Abi::RustCall, ty::Coroutine(..) => Abi::Rust, + ty::Error(_) => return, _ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty), }; let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi); diff --git a/compiler/rustc_mir_transform/src/add_retag.rs b/compiler/rustc_mir_transform/src/add_retag.rs index 94077c63057..430d9572e75 100644 --- a/compiler/rustc_mir_transform/src/add_retag.rs +++ b/compiler/rustc_mir_transform/src/add_retag.rs @@ -36,10 +36,10 @@ fn may_contain_reference<'tcx>(ty: Ty<'tcx>, depth: u32, tcx: TyCtxt<'tcx>) -> b ty::Tuple(tys) => { depth == 0 || tys.iter().any(|ty| may_contain_reference(ty, depth - 1, tcx)) } - ty::Adt(adt, subst) => { + ty::Adt(adt, args) => { depth == 0 || adt.variants().iter().any(|v| { - v.fields.iter().any(|f| may_contain_reference(f.ty(tcx, subst), depth - 1, tcx)) + v.fields.iter().any(|f| may_contain_reference(f.ty(tcx, args), depth - 1, tcx)) }) } // Conservative fallback diff --git a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs index 3195cd3622d..1f615c9d8d1 100644 --- a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs +++ b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs @@ -100,7 +100,7 @@ impl<'tcx> Visitor<'tcx> for ConstMutationChecker<'_, 'tcx> { && let Some((lint_root, span, item)) = self.should_lint_const_item_usage(lhs, def_id, loc) { - self.tcx.emit_spanned_lint( + self.tcx.emit_node_span_lint( CONST_ITEM_MUTATION, lint_root, span, @@ -145,7 +145,7 @@ impl<'tcx> Visitor<'tcx> for ConstMutationChecker<'_, 'tcx> { if let Some((lint_root, span, item)) = self.should_lint_const_item_usage(place, def_id, lint_loc) { - self.tcx.emit_spanned_lint( + self.tcx.emit_node_span_lint( CONST_ITEM_MUTATION, lint_root, span, diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs index 582c2c0c6b6..a0c3de3af58 100644 --- a/compiler/rustc_mir_transform/src/check_unsafety.rs +++ b/compiler/rustc_mir_transform/src/check_unsafety.rs @@ -128,7 +128,9 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { ), } } - &AggregateKind::Closure(def_id, _) | &AggregateKind::Coroutine(def_id, _) => { + &AggregateKind::Closure(def_id, _) + | &AggregateKind::CoroutineClosure(def_id, _) + | &AggregateKind::Coroutine(def_id, _) => { let def_id = def_id.expect_local(); let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } = self.tcx.mir_unsafety_check_result(def_id); @@ -241,10 +243,11 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { // old value is being dropped. let assigned_ty = place.ty(&self.body.local_decls, self.tcx).ty; if assigned_ty.needs_drop(self.tcx, self.param_env) { - // This would be unsafe, but should be outright impossible since we reject such unions. 
- self.tcx.dcx().span_delayed_bug( - self.source_info.span, - format!("union fields that need dropping should be impossible: {assigned_ty}") + // This would be unsafe, but should be outright impossible since we reject + // such unions. + assert!( + self.tcx.dcx().has_errors().is_some(), + "union fields that need dropping should be impossible: {assigned_ty}" ); } } else { @@ -527,7 +530,7 @@ fn report_unused_unsafe(tcx: TyCtxt<'_>, kind: UnusedUnsafe, id: HirId) { } else { None }; - tcx.emit_spanned_lint(UNUSED_UNSAFE, id, span, errors::UnusedUnsafe { span, nested_parent }); + tcx.emit_node_span_lint(UNUSED_UNSAFE, id, span, errors::UnusedUnsafe { span, nested_parent }); } pub fn check_unsafety(tcx: TyCtxt<'_>, def_id: LocalDefId) { @@ -577,7 +580,7 @@ pub fn check_unsafety(tcx: TyCtxt<'_>, def_id: LocalDefId) { }); } UnsafetyViolationKind::UnsafeFn => { - tcx.emit_spanned_lint( + tcx.emit_node_span_lint( UNSAFE_OP_IN_UNSAFE_FN, lint_root, source_info.span, diff --git a/compiler/rustc_mir_transform/src/const_goto.rs b/compiler/rustc_mir_transform/src/const_goto.rs deleted file mode 100644 index cb5b66b314d..00000000000 --- a/compiler/rustc_mir_transform/src/const_goto.rs +++ /dev/null @@ -1,128 +0,0 @@ -//! This pass optimizes the following sequence -//! ```rust,ignore (example) -//! bb2: { -//! _2 = const true; -//! goto -> bb3; -//! } -//! -//! bb3: { -//! switchInt(_2) -> [false: bb4, otherwise: bb5]; -//! } -//! ``` -//! into -//! ```rust,ignore (example) -//! bb2: { -//! _2 = const true; -//! goto -> bb5; -//! } -//! ``` - -use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; -use rustc_middle::{mir::visit::Visitor, ty::ParamEnv}; - -use super::simplify::{simplify_cfg, simplify_locals}; - -pub struct ConstGoto; - -impl<'tcx> MirPass<'tcx> for ConstGoto { - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - // This pass participates in some as-of-yet untested unsoundness found - // in https://github.com/rust-lang/rust/issues/112460 - sess.mir_opt_level() >= 2 && sess.opts.unstable_opts.unsound_mir_opts - } - - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - trace!("Running ConstGoto on {:?}", body.source); - let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); - let mut opt_finder = - ConstGotoOptimizationFinder { tcx, body, optimizations: vec![], param_env }; - opt_finder.visit_body(body); - let should_simplify = !opt_finder.optimizations.is_empty(); - for opt in opt_finder.optimizations { - let block = &mut body.basic_blocks_mut()[opt.bb_with_goto]; - block.statements.extend(opt.stmts_move_up); - let terminator = block.terminator_mut(); - let new_goto = TerminatorKind::Goto { target: opt.target_to_use_in_goto }; - debug!("SUCCESS: replacing `{:?}` with `{:?}`", terminator.kind, new_goto); - terminator.kind = new_goto; - } - - // if we applied optimizations, we potentially have some cfg to cleanup to - // make it easier for further passes - if should_simplify { - simplify_cfg(body); - simplify_locals(body, tcx); - } - } -} - -impl<'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'_, 'tcx> { - fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) { - if data.is_cleanup { - // Because of the restrictions around control flow in cleanup blocks, we don't perform - // this optimization at all in such blocks. 
- return; - } - self.super_basic_block_data(block, data); - } - - fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { - let _: Option<_> = try { - let target = terminator.kind.as_goto()?; - // We only apply this optimization if the last statement is a const assignment - let last_statement = self.body.basic_blocks[location.block].statements.last()?; - - if let (place, Rvalue::Use(Operand::Constant(_const))) = - last_statement.kind.as_assign()? - { - // We found a constant being assigned to `place`. - // Now check that the target of this Goto switches on this place. - let target_bb = &self.body.basic_blocks[target]; - - // The `StorageDead(..)` statement does not affect the functionality of mir. - // We can move this part of the statement up to the predecessor. - let mut stmts_move_up = Vec::new(); - for stmt in &target_bb.statements { - if let StatementKind::StorageDead(..) = stmt.kind { - stmts_move_up.push(stmt.clone()) - } else { - None?; - } - } - - let target_bb_terminator = target_bb.terminator(); - let (discr, targets) = target_bb_terminator.kind.as_switch()?; - if discr.place() == Some(*place) { - let switch_ty = place.ty(self.body.local_decls(), self.tcx).ty; - debug_assert_eq!(switch_ty, _const.ty()); - // We now know that the Switch matches on the const place, and it is statementless - // Now find which value in the Switch matches the const value. - let const_value = _const.const_.try_eval_bits(self.tcx, self.param_env)?; - let target_to_use_in_goto = targets.target_for_value(const_value); - self.optimizations.push(OptimizationToApply { - bb_with_goto: location.block, - target_to_use_in_goto, - stmts_move_up, - }); - } - } - Some(()) - }; - - self.super_terminator(terminator, location); - } -} - -struct OptimizationToApply<'tcx> { - bb_with_goto: BasicBlock, - target_to_use_in_goto: BasicBlock, - stmts_move_up: Vec<Statement<'tcx>>, -} - -pub struct ConstGotoOptimizationFinder<'a, 'tcx> { - tcx: TyCtxt<'tcx>, - body: &'a Body<'tcx>, - param_env: ParamEnv<'tcx>, - optimizations: Vec<OptimizationToApply<'tcx>>, -} diff --git a/compiler/rustc_mir_transform/src/const_prop.rs b/compiler/rustc_mir_transform/src/const_prop.rs deleted file mode 100644 index c5824c30770..00000000000 --- a/compiler/rustc_mir_transform/src/const_prop.rs +++ /dev/null @@ -1,326 +0,0 @@ -//! Propagates constants for early reporting of statically known -//! assertion failures - -use rustc_const_eval::interpret::{ - self, compile_time_machine, AllocId, ConstAllocation, FnArg, Frame, ImmTy, InterpCx, - InterpResult, OpTy, PlaceTy, Pointer, -}; -use rustc_data_structures::fx::FxHashSet; -use rustc_index::bit_set::BitSet; -use rustc_index::IndexVec; -use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor}; -use rustc_middle::mir::*; -use rustc_middle::query::TyCtxtAt; -use rustc_middle::ty::layout::TyAndLayout; -use rustc_middle::ty::{self, ParamEnv, TyCtxt}; -use rustc_span::def_id::DefId; -use rustc_target::abi::Size; -use rustc_target::spec::abi::Abi as CallAbi; - -/// The maximum number of bytes that we'll allocate space for a local or the return value. -/// Needed for #66397, because otherwise we eval into large places and that can cause OOM or just -/// Severely regress performance. -const MAX_ALLOC_LIMIT: u64 = 1024; - -/// Macro for machine-specific `InterpError` without allocation. -/// (These will never be shown to the user, but they help diagnose ICEs.) 
-pub(crate) macro throw_machine_stop_str($($tt:tt)*) {{ - // We make a new local type for it. The type itself does not carry any information, - // but its vtable (for the `MachineStopType` trait) does. - #[derive(Debug)] - struct Zst; - // Printing this type shows the desired string. - impl std::fmt::Display for Zst { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, $($tt)*) - } - } - - impl rustc_middle::mir::interpret::MachineStopType for Zst { - fn diagnostic_message(&self) -> rustc_errors::DiagnosticMessage { - self.to_string().into() - } - - fn add_args( - self: Box<Self>, - _: &mut dyn FnMut(std::borrow::Cow<'static, str>, rustc_errors::DiagnosticArgValue<'static>), - ) {} - } - throw_machine_stop!(Zst) -}} - -pub(crate) struct ConstPropMachine<'mir, 'tcx> { - /// The virtual call stack. - stack: Vec<Frame<'mir, 'tcx>>, - pub written_only_inside_own_block_locals: FxHashSet<Local>, - pub can_const_prop: IndexVec<Local, ConstPropMode>, -} - -impl ConstPropMachine<'_, '_> { - pub fn new(can_const_prop: IndexVec<Local, ConstPropMode>) -> Self { - Self { - stack: Vec::new(), - written_only_inside_own_block_locals: Default::default(), - can_const_prop, - } - } -} - -impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> { - compile_time_machine!(<'mir, 'tcx>); - - const PANIC_ON_ALLOC_FAIL: bool = true; // all allocations are small (see `MAX_ALLOC_LIMIT`) - - const POST_MONO_CHECKS: bool = false; // this MIR is still generic! - - type MemoryKind = !; - - #[inline(always)] - fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool { - false // no reason to enforce alignment - } - - #[inline(always)] - fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool { - false // for now, we don't enforce validity - } - - fn load_mir( - _ecx: &InterpCx<'mir, 'tcx, Self>, - _instance: ty::InstanceDef<'tcx>, - ) -> InterpResult<'tcx, &'tcx Body<'tcx>> { - throw_machine_stop_str!("calling functions isn't supported in ConstProp") - } - - fn panic_nounwind(_ecx: &mut InterpCx<'mir, 'tcx, Self>, _msg: &str) -> InterpResult<'tcx> { - throw_machine_stop_str!("panicking isn't supported in ConstProp") - } - - fn find_mir_or_eval_fn( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _instance: ty::Instance<'tcx>, - _abi: CallAbi, - _args: &[FnArg<'tcx>], - _destination: &PlaceTy<'tcx>, - _target: Option<BasicBlock>, - _unwind: UnwindAction, - ) -> InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> { - Ok(None) - } - - fn call_intrinsic( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _instance: ty::Instance<'tcx>, - _args: &[OpTy<'tcx>], - _destination: &PlaceTy<'tcx>, - _target: Option<BasicBlock>, - _unwind: UnwindAction, - ) -> InterpResult<'tcx> { - throw_machine_stop_str!("calling intrinsics isn't supported in ConstProp") - } - - fn assert_panic( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _msg: &rustc_middle::mir::AssertMessage<'tcx>, - _unwind: rustc_middle::mir::UnwindAction, - ) -> InterpResult<'tcx> { - bug!("panics terminators are not evaluated in ConstProp") - } - - fn binary_ptr_op( - _ecx: &InterpCx<'mir, 'tcx, Self>, - _bin_op: BinOp, - _left: &ImmTy<'tcx>, - _right: &ImmTy<'tcx>, - ) -> InterpResult<'tcx, (ImmTy<'tcx>, bool)> { - // We can't do this because aliasing of memory can differ between const eval and llvm - throw_machine_stop_str!("pointer arithmetic or comparisons aren't supported in ConstProp") - } - - fn before_access_local_mut<'a>( - ecx: &'a mut InterpCx<'mir, 'tcx, Self>, - 
frame: usize, - local: Local, - ) -> InterpResult<'tcx> { - assert_eq!(frame, 0); - match ecx.machine.can_const_prop[local] { - ConstPropMode::NoPropagation => { - throw_machine_stop_str!( - "tried to write to a local that is marked as not propagatable" - ) - } - ConstPropMode::OnlyInsideOwnBlock => { - ecx.machine.written_only_inside_own_block_locals.insert(local); - } - ConstPropMode::FullConstProp => {} - } - Ok(()) - } - - fn before_access_global( - _tcx: TyCtxtAt<'tcx>, - _machine: &Self, - _alloc_id: AllocId, - alloc: ConstAllocation<'tcx>, - _static_def_id: Option<DefId>, - is_write: bool, - ) -> InterpResult<'tcx> { - if is_write { - throw_machine_stop_str!("can't write to global"); - } - // If the static allocation is mutable, then we can't const prop it as its content - // might be different at runtime. - if alloc.inner().mutability.is_mut() { - throw_machine_stop_str!("can't access mutable globals in ConstProp"); - } - - Ok(()) - } - - #[inline(always)] - fn expose_ptr(_ecx: &mut InterpCx<'mir, 'tcx, Self>, _ptr: Pointer) -> InterpResult<'tcx> { - throw_machine_stop_str!("exposing pointers isn't supported in ConstProp") - } - - #[inline(always)] - fn init_frame_extra( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - frame: Frame<'mir, 'tcx>, - ) -> InterpResult<'tcx, Frame<'mir, 'tcx>> { - Ok(frame) - } - - #[inline(always)] - fn stack<'a>( - ecx: &'a InterpCx<'mir, 'tcx, Self>, - ) -> &'a [Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>] { - &ecx.machine.stack - } - - #[inline(always)] - fn stack_mut<'a>( - ecx: &'a mut InterpCx<'mir, 'tcx, Self>, - ) -> &'a mut Vec<Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>> { - &mut ecx.machine.stack - } -} - -/// The mode that `ConstProp` is allowed to run in for a given `Local`. -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum ConstPropMode { - /// The `Local` can be propagated into and reads of this `Local` can also be propagated. - FullConstProp, - /// The `Local` can only be propagated into and from its own block. - OnlyInsideOwnBlock, - /// The `Local` cannot be part of propagation at all. Any statement - /// referencing it either for reading or writing will not get propagated. - NoPropagation, -} - -pub struct CanConstProp { - can_const_prop: IndexVec<Local, ConstPropMode>, - // False at the beginning. Once set, no more assignments are allowed to that local. - found_assignment: BitSet<Local>, -} - -impl CanConstProp { - /// Returns true if `local` can be propagated - pub fn check<'tcx>( - tcx: TyCtxt<'tcx>, - param_env: ParamEnv<'tcx>, - body: &Body<'tcx>, - ) -> IndexVec<Local, ConstPropMode> { - let mut cpv = CanConstProp { - can_const_prop: IndexVec::from_elem(ConstPropMode::FullConstProp, &body.local_decls), - found_assignment: BitSet::new_empty(body.local_decls.len()), - }; - for (local, val) in cpv.can_const_prop.iter_enumerated_mut() { - let ty = body.local_decls[local].ty; - match tcx.layout_of(param_env.and(ty)) { - Ok(layout) if layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) => {} - // Either the layout fails to compute, then we can't use this local anyway - // or the local is too large, then we don't want to. - _ => { - *val = ConstPropMode::NoPropagation; - continue; - } - } - } - // Consider that arguments are assigned on entry. 
- for arg in body.args_iter() { - cpv.found_assignment.insert(arg); - } - cpv.visit_body(body); - cpv.can_const_prop - } -} - -impl<'tcx> Visitor<'tcx> for CanConstProp { - fn visit_place(&mut self, place: &Place<'tcx>, mut context: PlaceContext, loc: Location) { - use rustc_middle::mir::visit::PlaceContext::*; - - // Dereferencing just read the addess of `place.local`. - if place.projection.first() == Some(&PlaceElem::Deref) { - context = NonMutatingUse(NonMutatingUseContext::Copy); - } - - self.visit_local(place.local, context, loc); - self.visit_projection(place.as_ref(), context, loc); - } - - fn visit_local(&mut self, local: Local, context: PlaceContext, _: Location) { - use rustc_middle::mir::visit::PlaceContext::*; - match context { - // These are just stores, where the storing is not propagatable, but there may be later - // mutations of the same local via `Store` - | MutatingUse(MutatingUseContext::Call) - | MutatingUse(MutatingUseContext::AsmOutput) - | MutatingUse(MutatingUseContext::Deinit) - // Actual store that can possibly even propagate a value - | MutatingUse(MutatingUseContext::Store) - | MutatingUse(MutatingUseContext::SetDiscriminant) => { - if !self.found_assignment.insert(local) { - match &mut self.can_const_prop[local] { - // If the local can only get propagated in its own block, then we don't have - // to worry about multiple assignments, as we'll nuke the const state at the - // end of the block anyway, and inside the block we overwrite previous - // states as applicable. - ConstPropMode::OnlyInsideOwnBlock => {} - ConstPropMode::NoPropagation => {} - other @ ConstPropMode::FullConstProp => { - trace!( - "local {:?} can't be propagated because of multiple assignments. Previous state: {:?}", - local, other, - ); - *other = ConstPropMode::OnlyInsideOwnBlock; - } - } - } - } - // Reading constants is allowed an arbitrary number of times - NonMutatingUse(NonMutatingUseContext::Copy) - | NonMutatingUse(NonMutatingUseContext::Move) - | NonMutatingUse(NonMutatingUseContext::Inspect) - | NonMutatingUse(NonMutatingUseContext::PlaceMention) - | NonUse(_) => {} - - // These could be propagated with a smarter analysis or just some careful thinking about - // whether they'd be fine right now. - MutatingUse(MutatingUseContext::Yield) - | MutatingUse(MutatingUseContext::Drop) - | MutatingUse(MutatingUseContext::Retag) - // These can't ever be propagated under any scheme, as we can't reason about indirect - // mutation. - | NonMutatingUse(NonMutatingUseContext::SharedBorrow) - | NonMutatingUse(NonMutatingUseContext::FakeBorrow) - | NonMutatingUse(NonMutatingUseContext::AddressOf) - | MutatingUse(MutatingUseContext::Borrow) - | MutatingUse(MutatingUseContext::AddressOf) => { - trace!("local {:?} can't be propagated because it's used: {:?}", local, context); - self.can_const_prop[local] = ConstPropMode::NoPropagation; - } - MutatingUse(MutatingUseContext::Projection) - | NonMutatingUse(NonMutatingUseContext::Projection) => bug!("visit_place should not pass {context:?} for {local:?}"), - } - } -} diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs deleted file mode 100644 index d0bbca08a40..00000000000 --- a/compiler/rustc_mir_transform/src/const_prop_lint.rs +++ /dev/null @@ -1,696 +0,0 @@ -//! Propagates constants for early reporting of statically known -//! 
assertion failures - -use std::fmt::Debug; - -use either::Left; - -use rustc_const_eval::interpret::Immediate; -use rustc_const_eval::interpret::{ - InterpCx, InterpResult, MemoryKind, OpTy, Scalar, StackPopCleanup, -}; -use rustc_const_eval::ReportErrorExt; -use rustc_hir::def::DefKind; -use rustc_hir::HirId; -use rustc_index::bit_set::BitSet; -use rustc_middle::mir::visit::Visitor; -use rustc_middle::mir::*; -use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout}; -use rustc_middle::ty::GenericArgs; -use rustc_middle::ty::{ - self, ConstInt, Instance, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisitableExt, -}; -use rustc_span::Span; -use rustc_target::abi::{HasDataLayout, Size, TargetDataLayout}; - -use crate::const_prop::CanConstProp; -use crate::const_prop::ConstPropMachine; -use crate::const_prop::ConstPropMode; -use crate::errors::AssertLint; -use crate::MirLint; - -/// The maximum number of bytes that we'll allocate space for a local or the return value. -/// Needed for #66397, because otherwise we eval into large places and that can cause OOM or just -/// Severely regress performance. -const MAX_ALLOC_LIMIT: u64 = 1024; - -pub struct ConstPropLint; - -impl<'tcx> MirLint<'tcx> for ConstPropLint { - fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { - if body.tainted_by_errors.is_some() { - return; - } - - // will be evaluated by miri and produce its errors there - if body.source.promoted.is_some() { - return; - } - - let def_id = body.source.def_id().expect_local(); - let def_kind = tcx.def_kind(def_id); - let is_fn_like = def_kind.is_fn_like(); - let is_assoc_const = def_kind == DefKind::AssocConst; - - // Only run const prop on functions, methods, closures and associated constants - if !is_fn_like && !is_assoc_const { - // skip anon_const/statics/consts because they'll be evaluated by miri anyway - trace!("ConstPropLint skipped for {:?}", def_id); - return; - } - - // FIXME(welseywiser) const prop doesn't work on coroutines because of query cycles - // computing their layout. - if tcx.is_coroutine(def_id.to_def_id()) { - trace!("ConstPropLint skipped for coroutine {:?}", def_id); - return; - } - - trace!("ConstPropLint starting for {:?}", def_id); - - // FIXME(oli-obk, eddyb) Optimize locals (or even local paths) to hold - // constants, instead of just checking for const-folding succeeding. - // That would require a uniform one-def no-mutation analysis - // and RPO (or recursing when needing the value of a local). - let mut linter = ConstPropagator::new(body, tcx); - linter.visit_body(body); - - trace!("ConstPropLint done for {:?}", def_id); - } -} - -/// Finds optimization opportunities on the MIR. 
-struct ConstPropagator<'mir, 'tcx> { - ecx: InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, - tcx: TyCtxt<'tcx>, - param_env: ParamEnv<'tcx>, - worklist: Vec<BasicBlock>, - visited_blocks: BitSet<BasicBlock>, -} - -impl<'tcx> LayoutOfHelpers<'tcx> for ConstPropagator<'_, 'tcx> { - type LayoutOfResult = Result<TyAndLayout<'tcx>, LayoutError<'tcx>>; - - #[inline] - fn handle_layout_err(&self, err: LayoutError<'tcx>, _: Span, _: Ty<'tcx>) -> LayoutError<'tcx> { - err - } -} - -impl HasDataLayout for ConstPropagator<'_, '_> { - #[inline] - fn data_layout(&self) -> &TargetDataLayout { - &self.tcx.data_layout - } -} - -impl<'tcx> ty::layout::HasTyCtxt<'tcx> for ConstPropagator<'_, 'tcx> { - #[inline] - fn tcx(&self) -> TyCtxt<'tcx> { - self.tcx - } -} - -impl<'tcx> ty::layout::HasParamEnv<'tcx> for ConstPropagator<'_, 'tcx> { - #[inline] - fn param_env(&self) -> ty::ParamEnv<'tcx> { - self.param_env - } -} - -impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { - fn new(body: &'mir Body<'tcx>, tcx: TyCtxt<'tcx>) -> ConstPropagator<'mir, 'tcx> { - let def_id = body.source.def_id(); - let args = &GenericArgs::identity_for_item(tcx, def_id); - let param_env = tcx.param_env_reveal_all_normalized(def_id); - - let can_const_prop = CanConstProp::check(tcx, param_env, body); - let mut ecx = InterpCx::new( - tcx, - tcx.def_span(def_id), - param_env, - ConstPropMachine::new(can_const_prop), - ); - - let ret_layout = ecx - .layout_of(body.bound_return_ty().instantiate(tcx, args)) - .ok() - // Don't bother allocating memory for large values. - // I don't know how return types can seem to be unsized but this happens in the - // `type/type-unsatisfiable.rs` test. - .filter(|ret_layout| { - ret_layout.is_sized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) - }) - .unwrap_or_else(|| ecx.layout_of(tcx.types.unit).unwrap()); - - let ret = ecx - .allocate(ret_layout, MemoryKind::Stack) - .expect("couldn't perform small allocation") - .into(); - - ecx.push_stack_frame( - Instance::new(def_id, args), - body, - &ret, - StackPopCleanup::Root { cleanup: false }, - ) - .expect("failed to push initial stack frame"); - - for local in body.local_decls.indices() { - // Mark everything initially live. - // This is somewhat dicey since some of them might be unsized and it is incoherent to - // mark those as live... We rely on `local_to_place`/`local_to_op` in the interpreter - // stopping us before those unsized immediates can cause issues deeper in the - // interpreter. - ecx.frame_mut().locals[local].make_live_uninit(); - } - - ConstPropagator { - ecx, - tcx, - param_env, - worklist: vec![START_BLOCK], - visited_blocks: BitSet::new_empty(body.basic_blocks.len()), - } - } - - fn body(&self) -> &'mir Body<'tcx> { - self.ecx.frame().body - } - - fn local_decls(&self) -> &'mir LocalDecls<'tcx> { - &self.body().local_decls - } - - fn get_const(&self, place: Place<'tcx>) -> Option<OpTy<'tcx>> { - let op = match self.ecx.eval_place_to_op(place, None) { - Ok(op) => { - if op - .as_mplace_or_imm() - .right() - .is_some_and(|imm| matches!(*imm, Immediate::Uninit)) - { - // Make sure nobody accidentally uses this value. - return None; - } - op - } - Err(e) => { - trace!("get_const failed: {:?}", e.into_kind().debug()); - return None; - } - }; - - // Try to read the local as an immediate so that if it is representable as a scalar, we can - // handle it as such, but otherwise, just return the value as is. 
- Some(match self.ecx.read_immediate_raw(&op) { - Ok(Left(imm)) => imm.into(), - _ => op, - }) - } - - /// Remove `local` from the pool of `Locals`. Allows writing to them, - /// but not reading from them anymore. - fn remove_const(ecx: &mut InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, local: Local) { - ecx.frame_mut().locals[local].make_live_uninit(); - ecx.machine.written_only_inside_own_block_locals.remove(&local); - } - - fn lint_root(&self, source_info: SourceInfo) -> Option<HirId> { - source_info.scope.lint_root(&self.body().source_scopes) - } - - fn use_ecx<F, T>(&mut self, location: Location, f: F) -> Option<T> - where - F: FnOnce(&mut Self) -> InterpResult<'tcx, T>, - { - // Overwrite the PC -- whatever the interpreter does to it does not make any sense anyway. - self.ecx.frame_mut().loc = Left(location); - match f(self) { - Ok(val) => Some(val), - Err(error) => { - trace!("InterpCx operation failed: {:?}", error); - // Some errors shouldn't come up because creating them causes - // an allocation, which we should avoid. When that happens, - // dedicated error variants should be introduced instead. - assert!( - !error.kind().formatted_string(), - "const-prop encountered formatting error: {}", - self.ecx.format_error(error), - ); - None - } - } - } - - /// Returns the value, if any, of evaluating `c`. - fn eval_constant(&mut self, c: &ConstOperand<'tcx>, location: Location) -> Option<OpTy<'tcx>> { - // FIXME we need to revisit this for #67176 - if c.has_param() { - return None; - } - - // Normalization needed b/c const prop lint runs in - // `mir_drops_elaborated_and_const_checked`, which happens before - // optimized MIR. Only after optimizing the MIR can we guarantee - // that the `RevealAll` pass has happened and that the body's consts - // are normalized, so any call to resolve before that needs to be - // manually normalized. - let val = self.tcx.try_normalize_erasing_regions(self.param_env, c.const_).ok()?; - - self.use_ecx(location, |this| this.ecx.eval_mir_constant(&val, Some(c.span), None)) - } - - /// Returns the value, if any, of evaluating `place`. - fn eval_place(&mut self, place: Place<'tcx>, location: Location) -> Option<OpTy<'tcx>> { - trace!("eval_place(place={:?})", place); - self.use_ecx(location, |this| this.ecx.eval_place_to_op(place, None)) - } - - /// Returns the value, if any, of evaluating `op`. Calls upon `eval_constant` - /// or `eval_place`, depending on the variant of `Operand` used. - fn eval_operand(&mut self, op: &Operand<'tcx>, location: Location) -> Option<OpTy<'tcx>> { - match *op { - Operand::Constant(ref c) => self.eval_constant(c, location), - Operand::Move(place) | Operand::Copy(place) => self.eval_place(place, location), - } - } - - fn report_assert_as_lint(&self, source_info: &SourceInfo, lint: AssertLint<impl Debug>) { - if let Some(lint_root) = self.lint_root(*source_info) { - self.tcx.emit_spanned_lint(lint.lint(), lint_root, source_info.span, lint); - } - } - - fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>, location: Location) -> Option<()> { - if let (val, true) = self.use_ecx(location, |this| { - let val = this.ecx.read_immediate(&this.ecx.eval_operand(arg, None)?)?; - let (_res, overflow) = this.ecx.overflowing_unary_op(op, &val)?; - Ok((val, overflow)) - })? { - // `AssertKind` only has an `OverflowNeg` variant, so make sure that is - // appropriate to use. 
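As context for the assertion that follows: negating a signed minimum is the only unary operation that can overflow, which is what makes `OverflowNeg` the right `AssertKind` here. A minimal illustrative example of user code this lint path reaches (not part of this patch):

```rust
// Illustrative only: `-i8::MIN` does not fit in `i8`, so once `x` is
// const-propagated the lint reports the overflow at compile time.
fn main() {
    let x = i8::MIN;
    let _y = -x; // error: this arithmetic operation will overflow
}
```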
- assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow"); - let source_info = self.body().source_info(location); - self.report_assert_as_lint( - source_info, - AssertLint::ArithmeticOverflow( - source_info.span, - AssertKind::OverflowNeg(val.to_const_int()), - ), - ); - return None; - } - - Some(()) - } - - fn check_binary_op( - &mut self, - op: BinOp, - left: &Operand<'tcx>, - right: &Operand<'tcx>, - location: Location, - ) -> Option<()> { - let r = self.use_ecx(location, |this| { - this.ecx.read_immediate(&this.ecx.eval_operand(right, None)?) - }); - let l = self - .use_ecx(location, |this| this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?)); - // Check for exceeding shifts *even if* we cannot evaluate the LHS. - if matches!(op, BinOp::Shr | BinOp::Shl) { - let r = r.clone()?; - // We need the type of the LHS. We cannot use `place_layout` as that is the type - // of the result, which for checked binops is not the same! - let left_ty = left.ty(self.local_decls(), self.tcx); - let left_size = self.ecx.layout_of(left_ty).ok()?.size; - let right_size = r.layout.size; - let r_bits = r.to_scalar().to_bits(right_size).ok(); - if r_bits.is_some_and(|b| b >= left_size.bits() as u128) { - debug!("check_binary_op: reporting assert for {:?}", location); - let source_info = self.body().source_info(location); - let panic = AssertKind::Overflow( - op, - match l { - Some(l) => l.to_const_int(), - // Invent a dummy value, the diagnostic ignores it anyway - None => ConstInt::new( - ScalarInt::try_from_uint(1_u8, left_size).unwrap(), - left_ty.is_signed(), - left_ty.is_ptr_sized_integral(), - ), - }, - r.to_const_int(), - ); - self.report_assert_as_lint( - source_info, - AssertLint::ArithmeticOverflow(source_info.span, panic), - ); - return None; - } - } - - if let (Some(l), Some(r)) = (l, r) { - // The remaining operators are handled through `overflowing_binary_op`. - if self.use_ecx(location, |this| { - let (_res, overflow) = this.ecx.overflowing_binary_op(op, &l, &r)?; - Ok(overflow) - })? { - let source_info = self.body().source_info(location); - self.report_assert_as_lint( - source_info, - AssertLint::ArithmeticOverflow( - source_info.span, - AssertKind::Overflow(op, l.to_const_int(), r.to_const_int()), - ), - ); - return None; - } - } - Some(()) - } - - fn check_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) -> Option<()> { - // Perform any special handling for specific Rvalue types. - // Generally, checks here fall into one of two categories: - // 1. Additional checking to provide useful lints to the user - // - In this case, we will do some validation and then fall through to the - // end of the function which evals the assignment. - // 2. Working around bugs in other parts of the compiler - // - In this case, we'll return `None` from this function to stop evaluation. - match rvalue { - // Additional checking: give lints to the user if an overflow would occur. - // We do this here and not in the `Assert` terminator as that terminator is - // only sometimes emitted (overflow checks can be disabled), but we want to always - // lint. 
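The shift special-case above fires even when the left-hand side cannot be evaluated, since only the shift amount matters. A hedged user-level example (not from this patch) that exercises it:

```rust
// Illustrative only: the shift amount exceeds the 32-bit width of `u32`,
// so the deny-by-default `arithmetic_overflow` lint fires at compile time
// even though the LHS needs no evaluation.
fn main() {
    let s = 40;
    let _x = 1_u32 << s; // error: this arithmetic operation will overflow
}
```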
- Rvalue::UnaryOp(op, arg) => { - trace!("checking UnaryOp(op = {:?}, arg = {:?})", op, arg); - self.check_unary_op(*op, arg, location)?; - } - Rvalue::BinaryOp(op, box (left, right)) => { - trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right); - self.check_binary_op(*op, left, right, location)?; - } - Rvalue::CheckedBinaryOp(op, box (left, right)) => { - trace!( - "checking CheckedBinaryOp(op = {:?}, left = {:?}, right = {:?})", - op, - left, - right - ); - self.check_binary_op(*op, left, right, location)?; - } - - // Do not try creating references (#67862) - Rvalue::AddressOf(_, place) | Rvalue::Ref(_, _, place) => { - trace!("skipping AddressOf | Ref for {:?}", place); - - // This may be creating mutable references or immutable references to cells. - // If that happens, the pointed to value could be mutated via that reference. - // Since we aren't tracking references, the const propagator loses track of what - // value the local has right now. - // Thus, all locals that have their reference taken - // must not take part in propagation. - Self::remove_const(&mut self.ecx, place.local); - - return None; - } - Rvalue::ThreadLocalRef(def_id) => { - trace!("skipping ThreadLocalRef({:?})", def_id); - - return None; - } - - // There's no other checking to do at this time. - Rvalue::Aggregate(..) - | Rvalue::Use(..) - | Rvalue::CopyForDeref(..) - | Rvalue::Repeat(..) - | Rvalue::Len(..) - | Rvalue::Cast(..) - | Rvalue::ShallowInitBox(..) - | Rvalue::Discriminant(..) - | Rvalue::NullaryOp(..) => {} - } - - // FIXME we need to revisit this for #67176 - if rvalue.has_param() { - return None; - } - if !rvalue.ty(self.local_decls(), self.tcx).is_sized(self.tcx, self.param_env) { - // the interpreter doesn't support unsized locals (only unsized arguments), - // but rustc does (in a kinda broken way), so we have to skip them here - return None; - } - - Some(()) - } - - fn check_assertion( - &mut self, - expected: bool, - msg: &AssertKind<Operand<'tcx>>, - cond: &Operand<'tcx>, - location: Location, - ) -> Option<!> { - let value = &self.eval_operand(cond, location)?; - trace!("assertion on {:?} should be {:?}", value, expected); - - let expected = Scalar::from_bool(expected); - let value_const = self.use_ecx(location, |this| this.ecx.read_scalar(value))?; - - if expected != value_const { - // Poison all places this operand references so that further code - // doesn't use the invalid value - if let Some(place) = cond.place() { - Self::remove_const(&mut self.ecx, place.local); - } - - enum DbgVal<T> { - Val(T), - Underscore, - } - impl<T: std::fmt::Debug> std::fmt::Debug for DbgVal<T> { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Val(val) => val.fmt(fmt), - Self::Underscore => fmt.write_str("_"), - } - } - } - let mut eval_to_int = |op| { - // This can be `None` if the lhs wasn't const propagated and we just - // triggered the assert on the value of the rhs. - self.eval_operand(op, location) - .and_then(|op| self.ecx.read_immediate(&op).ok()) - .map_or(DbgVal::Underscore, |op| DbgVal::Val(op.to_const_int())) - }; - let msg = match msg { - AssertKind::DivisionByZero(op) => AssertKind::DivisionByZero(eval_to_int(op)), - AssertKind::RemainderByZero(op) => AssertKind::RemainderByZero(eval_to_int(op)), - AssertKind::Overflow(bin_op @ (BinOp::Div | BinOp::Rem), op1, op2) => { - // Division overflow is *UB* in the MIR, and different than the - // other overflow checks. 
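In two's complement the one division that can overflow is `MIN / -1`; a minimal illustrative example that reaches this path once both operands are known (not from this patch):

```rust
// Illustrative only: `i32::MIN / -1` does not fit in `i32`, so the lint
// reports an unconditional panic at compile time.
fn main() {
    let d = -1;
    let _q = i32::MIN / d; // error: this operation will panic at runtime
}
```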
- AssertKind::Overflow(*bin_op, eval_to_int(op1), eval_to_int(op2)) - } - AssertKind::BoundsCheck { ref len, ref index } => { - let len = eval_to_int(len); - let index = eval_to_int(index); - AssertKind::BoundsCheck { len, index } - } - // Remaining overflow errors are already covered by checks on the binary operators. - AssertKind::Overflow(..) | AssertKind::OverflowNeg(_) => return None, - // Need proper const propagator for these. - _ => return None, - }; - let source_info = self.body().source_info(location); - self.report_assert_as_lint( - source_info, - AssertLint::UnconditionalPanic(source_info.span, msg), - ); - } - - None - } - - fn ensure_not_propagated(&self, local: Local) { - if cfg!(debug_assertions) { - assert!( - self.get_const(local.into()).is_none() - || self - .layout_of(self.local_decls()[local].ty) - .map_or(true, |layout| layout.is_zst()), - "failed to remove values for `{local:?}`, value={:?}", - self.get_const(local.into()), - ) - } - } -} - -impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { - fn visit_body(&mut self, body: &Body<'tcx>) { - while let Some(bb) = self.worklist.pop() { - if !self.visited_blocks.insert(bb) { - continue; - } - - let data = &body.basic_blocks[bb]; - self.visit_basic_block_data(bb, data); - } - } - - fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { - self.super_operand(operand, location); - } - - fn visit_constant(&mut self, constant: &ConstOperand<'tcx>, location: Location) { - trace!("visit_constant: {:?}", constant); - self.super_constant(constant, location); - self.eval_constant(constant, location); - } - - fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) { - self.super_assign(place, rvalue, location); - - let Some(()) = self.check_rvalue(rvalue, location) else { return }; - - match self.ecx.machine.can_const_prop[place.local] { - // Do nothing if the place is indirect. - _ if place.is_indirect() => {} - ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), - ConstPropMode::OnlyInsideOwnBlock | ConstPropMode::FullConstProp => { - if self - .use_ecx(location, |this| this.ecx.eval_rvalue_into_place(rvalue, *place)) - .is_none() - { - // Const prop failed, so erase the destination, ensuring that whatever happens - // from here on, does not know about the previous value. - // This is important in case we have - // ```rust - // let mut x = 42; - // x = SOME_MUTABLE_STATIC; - // // x must now be uninit - // ``` - // FIXME: we overzealously erase the entire local, because that's easier to - // implement. - trace!( - "propagation into {:?} failed. - Nuking the entire site from orbit, it's the only way to be sure", - place, - ); - Self::remove_const(&mut self.ecx, place.local); - } - } - } - } - - fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { - trace!("visit_statement: {:?}", statement); - - // We want to evaluate operands before any change to the assigned-to value, - // so we recurse first. - self.super_statement(statement, location); - - match statement.kind { - StatementKind::SetDiscriminant { ref place, .. } => { - match self.ecx.machine.can_const_prop[place.local] { - // Do nothing if the place is indirect. 
- _ if place.is_indirect() => {} - ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), - ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { - if self.use_ecx(location, |this| this.ecx.statement(statement)).is_some() { - trace!("propped discriminant into {:?}", place); - } else { - Self::remove_const(&mut self.ecx, place.local); - } - } - } - } - StatementKind::StorageLive(local) => { - let frame = self.ecx.frame_mut(); - frame.locals[local].make_live_uninit(); - } - StatementKind::StorageDead(local) => { - let frame = self.ecx.frame_mut(); - // We don't actually track liveness, so the local remains live. But forget its value. - frame.locals[local].make_live_uninit(); - } - _ => {} - } - } - - fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { - self.super_terminator(terminator, location); - match &terminator.kind { - TerminatorKind::Assert { expected, ref msg, ref cond, .. } => { - self.check_assertion(*expected, msg, cond, location); - } - TerminatorKind::SwitchInt { ref discr, ref targets } => { - if let Some(ref value) = self.eval_operand(discr, location) - && let Some(value_const) = - self.use_ecx(location, |this| this.ecx.read_scalar(value)) - && let Ok(constant) = value_const.try_to_int() - && let Ok(constant) = constant.to_bits(constant.size()) - { - // We managed to evaluate the discriminant, so we know we only need to visit - // one target. - let target = targets.target_for_value(constant); - self.worklist.push(target); - return; - } - // We failed to evaluate the discriminant, fallback to visiting all successors. - } - // None of these have Operands to const-propagate. - TerminatorKind::Goto { .. } - | TerminatorKind::UnwindResume - | TerminatorKind::UnwindTerminate(_) - | TerminatorKind::Return - | TerminatorKind::Unreachable - | TerminatorKind::Drop { .. } - | TerminatorKind::Yield { .. } - | TerminatorKind::CoroutineDrop - | TerminatorKind::FalseEdge { .. } - | TerminatorKind::FalseUnwind { .. } - | TerminatorKind::Call { .. } - | TerminatorKind::InlineAsm { .. } => {} - } - - self.worklist.extend(terminator.successors()); - } - - fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) { - self.super_basic_block_data(block, data); - - // We remove all Locals which are restricted in propagation to their containing blocks and - // which were modified in the current block. - // Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`. - let mut written_only_inside_own_block_locals = - std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals); - - // This loop can get very hot for some bodies: it check each local in each bb. - // To avoid this quadratic behaviour, we only clear the locals that were modified inside - // the current block. - // The order in which we remove consts does not matter. 
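The take-then-put-back dance around `written_only_inside_own_block_locals` is a general borrow-checker pattern: move a collection out of `self` so it can be drained while `&mut self` methods remain callable. A minimal self-contained sketch with hypothetical names:

```rust
use std::collections::HashSet;

struct Machine { dirty: HashSet<u32> }
struct Ctx { machine: Machine }

impl Ctx {
    // Stand-in for `Self::remove_const(&mut self.ecx, local)`.
    fn remove_const(&mut self, _local: u32) {}

    fn end_of_block(&mut self) {
        // Move the set out of `self` so we can drain it while still calling
        // `&mut self` methods, then store the (now empty) set back for reuse.
        let mut dirty = std::mem::take(&mut self.machine.dirty);
        for local in dirty.drain() {
            self.remove_const(local);
        }
        self.machine.dirty = dirty;
    }
}
```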
- #[allow(rustc::potential_query_instability)] - for local in written_only_inside_own_block_locals.drain() { - debug_assert_eq!( - self.ecx.machine.can_const_prop[local], - ConstPropMode::OnlyInsideOwnBlock - ); - Self::remove_const(&mut self.ecx, local); - } - self.ecx.machine.written_only_inside_own_block_locals = - written_only_inside_own_block_locals; - - if cfg!(debug_assertions) { - for (local, &mode) in self.ecx.machine.can_const_prop.iter_enumerated() { - match mode { - ConstPropMode::FullConstProp => {} - ConstPropMode::NoPropagation | ConstPropMode::OnlyInsideOwnBlock => { - self.ensure_not_propagated(local); - } - } - } - } - } -} diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index eaa36e0cc91..54b13a40e92 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -50,6 +50,9 @@ //! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing. //! Otherwise it drops all the values in scope at the last suspension point. +mod by_move_body; +pub use by_move_body::ByMoveBody; + use crate::abort_unwinding_calls; use crate::deref_separator::deref_finder; use crate::errors; @@ -723,7 +726,7 @@ fn replace_resume_ty_local<'tcx>( /// The async lowering step and the type / lifetime inference / checking are /// still using the `resume` argument for the time being. After this transform, /// the coroutine body doesn't have the `resume` argument. -fn transform_gen_context<'tcx>(_tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { +fn transform_gen_context<'tcx>(body: &mut Body<'tcx>) { // This leaves the local representing the `resume` argument in place, // but turns it into a regular local variable. This is cheaper than // adjusting all local references in the body after removing it. @@ -1231,7 +1234,12 @@ fn create_coroutine_drop_shim<'tcx>( drop_clean: BasicBlock, ) -> Body<'tcx> { let mut body = body.clone(); - body.arg_count = 1; // make sure the resume argument is not included here + // Take the coroutine info out of the body, since the drop shim is + // not a coroutine body itself; it just has its drop built out of it. + let _ = body.coroutine.take(); + // Make sure the resume argument is not included here, since we're + // building a body for `drop_in_place`. + body.arg_count = 1; let source_info = SourceInfo::outermost(body.span); @@ -1607,11 +1615,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { (args.discr_ty(tcx), coroutine_kind.movability() == hir::Movability::Movable) } _ => { - tcx.dcx().span_delayed_bug( - body.span, - format!("unexpected coroutine type {coroutine_ty}"), - ); - return; + tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}")); } }; @@ -1725,7 +1729,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // Remove the context argument within generator bodies. if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) { - transform_gen_context(tcx, body); + transform_gen_context(body); } // The original arguments to the function are no longer arguments, mark them as such. 
@@ -2071,7 +2075,7 @@ fn check_must_not_suspend_def( span: data.source_span, reason: s.as_str().to_string(), }); - tcx.emit_spanned_lint( + tcx.emit_node_span_lint( rustc_session::lint::builtin::MUST_NOT_SUSPEND, hir_id, data.source_span, diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs new file mode 100644 index 00000000000..e40f4520671 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -0,0 +1,159 @@ +//! A MIR pass which duplicates a coroutine's body and removes any derefs which +//! would be present for upvars that are taken by-ref. The result of which will +//! be a coroutine body that takes all of its upvars by-move, and which we stash +//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures. + +use rustc_data_structures::fx::FxIndexSet; +use rustc_hir as hir; +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::{self, dump_mir, MirPass}; +use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt}; +use rustc_target::abi::FieldIdx; + +pub struct ByMoveBody; + +impl<'tcx> MirPass<'tcx> for ByMoveBody { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { + let Some(coroutine_def_id) = body.source.def_id().as_local() else { + return; + }; + let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) = + tcx.coroutine_kind(coroutine_def_id) + else { + return; + }; + let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + if coroutine_ty.references_error() { + return; + } + let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") }; + + let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(); + if coroutine_kind == ty::ClosureKind::FnOnce { + return; + } + + let mut by_ref_fields = FxIndexSet::default(); + let by_move_upvars = Ty::new_tup_from_iter( + tcx, + tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| { + if capture.is_by_ref() { + by_ref_fields.insert(FieldIdx::from_usize(idx)); + } + capture.place.ty() + }), + ); + let by_move_coroutine_ty = Ty::new_coroutine( + tcx, + coroutine_def_id.to_def_id(), + ty::CoroutineArgs::new( + tcx, + ty::CoroutineArgsParts { + parent_args: args.as_coroutine().parent_args(), + kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce), + resume_ty: args.as_coroutine().resume_ty(), + yield_ty: args.as_coroutine().yield_ty(), + return_ty: args.as_coroutine().return_ty(), + witness: args.as_coroutine().witness(), + tupled_upvars_ty: by_move_upvars, + }, + ) + .args, + ); + + let mut by_move_body = body.clone(); + MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body); + dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(())); + by_move_body.source = mir::MirSource { + instance: InstanceDef::CoroutineKindShim { + coroutine_def_id: coroutine_def_id.to_def_id(), + target_kind: ty::ClosureKind::FnOnce, + }, + promoted: None, + }; + body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body); + + // If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body. + // This is actually just a copy of the by-ref body, but with a different self type. + // FIXME(async_closures): We could probably unify this with the by-ref body somehow. 
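Taken together with the early return for `FnOnce` above, which extra bodies this pass attaches depends only on the inferred closure kind; a summary of the logic in this file (no new behavior):

```rust
// Which extra coroutine bodies `ByMoveBody` synthesizes, by closure kind:
//   ty::ClosureKind::FnOnce -> none (the main body already consumes its upvars)
//   ty::ClosureKind::FnMut  -> by-move body only
//   ty::ClosureKind::Fn     -> by-move body, plus the by-mut body built below
```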
+ if coroutine_kind == ty::ClosureKind::Fn { + let by_mut_coroutine_ty = Ty::new_coroutine( + tcx, + coroutine_def_id.to_def_id(), + ty::CoroutineArgs::new( + tcx, + ty::CoroutineArgsParts { + parent_args: args.as_coroutine().parent_args(), + kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut), + resume_ty: args.as_coroutine().resume_ty(), + yield_ty: args.as_coroutine().yield_ty(), + return_ty: args.as_coroutine().return_ty(), + witness: args.as_coroutine().witness(), + tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(), + }, + ) + .args, + ); + let mut by_mut_body = body.clone(); + by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty; + dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(())); + by_mut_body.source = mir::MirSource { + instance: InstanceDef::CoroutineKindShim { + coroutine_def_id: coroutine_def_id.to_def_id(), + target_kind: ty::ClosureKind::FnMut, + }, + promoted: None, + }; + body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body); + } + } +} + +struct MakeByMoveBody<'tcx> { + tcx: TyCtxt<'tcx>, + by_ref_fields: FxIndexSet<FieldIdx>, + by_move_coroutine_ty: Ty<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place( + &mut self, + place: &mut mir::Place<'tcx>, + context: mir::visit::PlaceContext, + location: mir::Location, + ) { + if place.local == ty::CAPTURE_STRUCT_LOCAL + && !place.projection.is_empty() + && let mir::ProjectionElem::Field(idx, ty) = place.projection[0] + && self.by_ref_fields.contains(&idx) + { + let (begin, end) = place.projection[1..].split_first().unwrap(); + // FIXME(async_closures): I'm actually a bit surprised to see that we always + // initially deref the by-ref upvars. If this is not actually true, then we + // will at least get an ICE that explains why this isn't true :^) + assert_eq!(*begin, mir::ProjectionElem::Deref); + // Peel one ref off of the ty. + let peeled_ty = ty.builtin_deref(true).unwrap().ty; + *place = mir::Place { + local: place.local, + projection: self.tcx.mk_place_elems_from_iter( + [mir::ProjectionElem::Field(idx, peeled_ty)] + .into_iter() + .chain(end.iter().copied()), + ), + }; + } + self.super_place(place, context, location); + } + + fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) { + // Replace the type of the self arg. + if local == ty::CAPTURE_STRUCT_LOCAL { + local_decl.ty = self.by_move_coroutine_ty; + } + } +} diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs index 79bed960b95..2c692c95003 100644 --- a/compiler/rustc_mir_transform/src/cost_checker.rs +++ b/compiler/rustc_mir_transform/src/cost_checker.rs @@ -70,7 +70,7 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. 
} => { let fn_ty = self.instantiate_ty(f.const_.ty()); self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() - && tcx.is_intrinsic(def_id) + && tcx.intrinsic(def_id).is_some() { // Don't give intrinsics the extra penalty for calls INSTR_COST diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs index 8c11dea5d4e..9a1d8bae6b4 100644 --- a/compiler/rustc_mir_transform/src/coverage/counters.rs +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -1,4 +1,5 @@ -use rustc_data_structures::fx::FxIndexMap; +use rustc_data_structures::captures::Captures; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::graph::WithNumNodes; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; @@ -38,19 +39,27 @@ impl Debug for BcbCounter { } } +#[derive(Debug)] +pub(super) enum CounterIncrementSite { + Node { bcb: BasicCoverageBlock }, + Edge { from_bcb: BasicCoverageBlock, to_bcb: BasicCoverageBlock }, +} + /// Generates and stores coverage counter and coverage expression information /// associated with nodes/edges in the BCB graph. pub(super) struct CoverageCounters { - next_counter_id: CounterId, + /// List of places where a counter-increment statement should be injected + /// into MIR, each with its corresponding counter ID. + counter_increment_sites: IndexVec<CounterId, CounterIncrementSite>, /// Coverage counters/expressions that are associated with individual BCBs. bcb_counters: IndexVec<BasicCoverageBlock, Option<BcbCounter>>, /// Coverage counters/expressions that are associated with the control-flow /// edge between two BCBs. /// - /// The iteration order of this map can affect the precise contents of MIR, - /// so we use `FxIndexMap` to avoid query stability hazards. - bcb_edge_counters: FxIndexMap<(BasicCoverageBlock, BasicCoverageBlock), BcbCounter>, + /// We currently don't iterate over this map, but if we do in the future, + /// switch it back to `FxIndexMap` to avoid query stability hazards. + bcb_edge_counters: FxHashMap<(BasicCoverageBlock, BasicCoverageBlock), BcbCounter>, /// Tracks which BCBs have a counter associated with some incoming edge. /// Only used by assertions, to verify that BCBs with incoming edge /// counters do not have their own physical counters (expressions are allowed). @@ -71,9 +80,9 @@ impl CoverageCounters { let num_bcbs = basic_coverage_blocks.num_nodes(); let mut this = Self { - next_counter_id: CounterId::START, + counter_increment_sites: IndexVec::new(), bcb_counters: IndexVec::from_elem_n(None, num_bcbs), - bcb_edge_counters: FxIndexMap::default(), + bcb_edge_counters: FxHashMap::default(), bcb_has_incoming_edge_counters: BitSet::new_empty(num_bcbs), expressions: IndexVec::new(), }; @@ -84,8 +93,8 @@ impl CoverageCounters { this } - fn make_counter(&mut self) -> BcbCounter { - let id = self.next_counter(); + fn make_counter(&mut self, site: CounterIncrementSite) -> BcbCounter { + let id = self.counter_increment_sites.push(site); BcbCounter::Counter { id } } @@ -103,15 +112,8 @@ impl CoverageCounters { self.make_expression(lhs, Op::Add, rhs) } - /// Counter IDs start from one and go up. 
- fn next_counter(&mut self) -> CounterId { - let next = self.next_counter_id; - self.next_counter_id = self.next_counter_id + 1; - next - } - pub(super) fn num_counters(&self) -> usize { - self.next_counter_id.as_usize() + self.counter_increment_sites.len() } #[cfg(test)] @@ -171,22 +173,26 @@ impl CoverageCounters { self.bcb_counters[bcb] } - pub(super) fn bcb_node_counters( + /// Returns an iterator over all the nodes/edges in the coverage graph that + /// should have a counter-increment statement injected into MIR, along with + /// each site's corresponding counter ID. + pub(super) fn counter_increment_sites( &self, - ) -> impl Iterator<Item = (BasicCoverageBlock, &BcbCounter)> { - self.bcb_counters - .iter_enumerated() - .filter_map(|(bcb, counter_kind)| Some((bcb, counter_kind.as_ref()?))) + ) -> impl Iterator<Item = (CounterId, &CounterIncrementSite)> { + self.counter_increment_sites.iter_enumerated() } - /// For each edge in the BCB graph that has an associated counter, yields - /// that edge's *from* and *to* nodes, and its counter. - pub(super) fn bcb_edge_counters( + /// Returns an iterator over the subset of BCB nodes that have been associated + /// with a counter *expression*, along with the ID of that expression. + pub(super) fn bcb_nodes_with_coverage_expressions( &self, - ) -> impl Iterator<Item = (BasicCoverageBlock, BasicCoverageBlock, &BcbCounter)> { - self.bcb_edge_counters - .iter() - .map(|(&(from_bcb, to_bcb), counter_kind)| (from_bcb, to_bcb, counter_kind)) + ) -> impl Iterator<Item = (BasicCoverageBlock, ExpressionId)> + Captures<'_> { + self.bcb_counters.iter_enumerated().filter_map(|(bcb, &counter_kind)| match counter_kind { + // Yield the BCB along with its associated expression ID. + Some(BcbCounter::Expression { id }) => Some((bcb, id)), + // This BCB is associated with a counter or nothing, so skip it. + Some(BcbCounter::Counter { .. }) | None => None, + }) } pub(super) fn into_expressions(self) -> IndexVec<ExpressionId, Expression> { @@ -339,7 +345,8 @@ impl<'a> MakeBcbCounters<'a> { // program results in a tight infinite loop, but it should still compile. let one_path_to_target = !self.basic_coverage_blocks.bcb_has_multiple_in_edges(bcb); if one_path_to_target || self.bcb_predecessors(bcb).contains(&bcb) { - let counter_kind = self.coverage_counters.make_counter(); + let counter_kind = + self.coverage_counters.make_counter(CounterIncrementSite::Node { bcb }); if one_path_to_target { debug!("{bcb:?} gets a new counter: {counter_kind:?}"); } else { @@ -401,7 +408,8 @@ impl<'a> MakeBcbCounters<'a> { } // Make a new counter to count this edge. 
- let counter_kind = self.coverage_counters.make_counter(); + let counter_kind = + self.coverage_counters.make_counter(CounterIncrementSite::Edge { from_bcb, to_bcb }); debug!("Edge {from_bcb:?}->{to_bcb:?} gets a new counter: {counter_kind:?}"); self.coverage_counters.set_bcb_edge_counter(from_bcb, to_bcb, counter_kind) } diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs index a11d224e8f1..4c5be0a3f4b 100644 --- a/compiler/rustc_mir_transform/src/coverage/mod.rs +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -7,7 +7,7 @@ mod spans; #[cfg(test)] mod tests; -use self::counters::{BcbCounter, CoverageCounters}; +use self::counters::{CounterIncrementSite, CoverageCounters}; use self::graph::{BasicCoverageBlock, CoverageGraph}; use self::spans::{BcbMapping, BcbMappingKind, CoverageSpans}; @@ -59,170 +59,148 @@ impl<'tcx> MirPass<'tcx> for InstrumentCoverage { _ => {} } - trace!("InstrumentCoverage starting for {def_id:?}"); - Instrumentor::new(tcx, mir_body).inject_counters(); - trace!("InstrumentCoverage done for {def_id:?}"); + instrument_function_for_coverage(tcx, mir_body); } } -struct Instrumentor<'a, 'tcx> { - tcx: TyCtxt<'tcx>, - mir_body: &'a mut mir::Body<'tcx>, - hir_info: ExtractedHirInfo, - basic_coverage_blocks: CoverageGraph, -} +fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir::Body<'tcx>) { + let def_id = mir_body.source.def_id(); + let _span = debug_span!("instrument_function_for_coverage", ?def_id).entered(); -impl<'a, 'tcx> Instrumentor<'a, 'tcx> { - fn new(tcx: TyCtxt<'tcx>, mir_body: &'a mut mir::Body<'tcx>) -> Self { - let hir_info = extract_hir_info(tcx, mir_body.source.def_id().expect_local()); + let hir_info = extract_hir_info(tcx, def_id.expect_local()); + let basic_coverage_blocks = CoverageGraph::from_mir(mir_body); - debug!(?hir_info, "instrumenting {:?}", mir_body.source.def_id()); - - let basic_coverage_blocks = CoverageGraph::from_mir(mir_body); + //////////////////////////////////////////////////// + // Compute coverage spans from the `CoverageGraph`. + let Some(coverage_spans) = + spans::generate_coverage_spans(mir_body, &hir_info, &basic_coverage_blocks) + else { + // No relevant spans were found in MIR, so skip instrumenting this function. + return; + }; - Self { tcx, mir_body, hir_info, basic_coverage_blocks } + //////////////////////////////////////////////////// + // Create an optimized mix of `Counter`s and `Expression`s for the `CoverageGraph`. Ensure + // every coverage span has a `Counter` or `Expression` assigned to its `BasicCoverageBlock` + // and all `Expression` dependencies (operands) are also generated, for any other + // `BasicCoverageBlock`s not already associated with a coverage span. + let bcb_has_coverage_spans = |bcb| coverage_spans.bcb_has_coverage_spans(bcb); + let coverage_counters = + CoverageCounters::make_bcb_counters(&basic_coverage_blocks, bcb_has_coverage_spans); + + let mappings = create_mappings(tcx, &hir_info, &coverage_spans, &coverage_counters); + if mappings.is_empty() { + // No spans could be converted into valid mappings, so skip this function. + debug!("no spans could be converted into valid mappings; skipping"); + return; } - fn inject_counters(&'a mut self) { - //////////////////////////////////////////////////// - // Compute coverage spans from the `CoverageGraph`. 
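// (Editor's note — the shape of the refactor above, shown schematically;
//  the names come from the diff, the skeleton is illustrative.)
// `Instrumentor` existed only to thread `tcx`, the MIR body, and two
// derived values through its methods; passing them as ordinary parameters
// removes the `&'a mut self` borrow gymnastics:
//
//     // before: Instrumentor::new(tcx, mir_body).inject_counters();
//     // after:
//     fn instrument_function_for_coverage<'tcx>(
//         tcx: TyCtxt<'tcx>,
//         mir_body: &mut mir::Body<'tcx>,
//     ) { /* hir_info and the coverage graph become plain locals */ }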
- let Some(coverage_spans) = CoverageSpans::generate_coverage_spans( - self.mir_body, - &self.hir_info, - &self.basic_coverage_blocks, - ) else { - // No relevant spans were found in MIR, so skip instrumenting this function. - return; - }; - - //////////////////////////////////////////////////// - // Create an optimized mix of `Counter`s and `Expression`s for the `CoverageGraph`. Ensure - // every coverage span has a `Counter` or `Expression` assigned to its `BasicCoverageBlock` - // and all `Expression` dependencies (operands) are also generated, for any other - // `BasicCoverageBlock`s not already associated with a coverage span. - let bcb_has_coverage_spans = |bcb| coverage_spans.bcb_has_coverage_spans(bcb); - let coverage_counters = CoverageCounters::make_bcb_counters( - &self.basic_coverage_blocks, - bcb_has_coverage_spans, - ); - - let mappings = self.create_mappings(&coverage_spans, &coverage_counters); - if mappings.is_empty() { - // No spans could be converted into valid mappings, so skip this function. - debug!("no spans could be converted into valid mappings; skipping"); - return; - } - - self.inject_coverage_statements(bcb_has_coverage_spans, &coverage_counters); - - self.mir_body.function_coverage_info = Some(Box::new(FunctionCoverageInfo { - function_source_hash: self.hir_info.function_source_hash, - num_counters: coverage_counters.num_counters(), - expressions: coverage_counters.into_expressions(), - mappings, - })); - } + inject_coverage_statements( + mir_body, + &basic_coverage_blocks, + bcb_has_coverage_spans, + &coverage_counters, + ); - /// For each coverage span extracted from MIR, create a corresponding - /// mapping. - /// - /// Precondition: All BCBs corresponding to those spans have been given - /// coverage counters. - fn create_mappings( - &self, - coverage_spans: &CoverageSpans, - coverage_counters: &CoverageCounters, - ) -> Vec<Mapping> { - let source_map = self.tcx.sess.source_map(); - let body_span = self.hir_info.body_span; - - let source_file = source_map.lookup_source_file(body_span.lo()); - use rustc_session::RemapFileNameExt; - let file_name = - Symbol::intern(&source_file.name.for_codegen(self.tcx.sess).to_string_lossy()); - - let term_for_bcb = |bcb| { - coverage_counters - .bcb_counter(bcb) - .expect("all BCBs with spans were given counters") - .as_term() - }; + mir_body.function_coverage_info = Some(Box::new(FunctionCoverageInfo { + function_source_hash: hir_info.function_source_hash, + num_counters: coverage_counters.num_counters(), + expressions: coverage_counters.into_expressions(), + mappings, + })); +} - coverage_spans - .all_bcb_mappings() - .filter_map(|&BcbMapping { kind: bcb_mapping_kind, span }| { - let kind = match bcb_mapping_kind { - BcbMappingKind::Code(bcb) => MappingKind::Code(term_for_bcb(bcb)), - }; - let code_region = make_code_region(source_map, file_name, span, body_span)?; - Some(Mapping { kind, code_region }) - }) - .collect::<Vec<_>>() - } +/// For each coverage span extracted from MIR, create a corresponding +/// mapping. +/// +/// Precondition: All BCBs corresponding to those spans have been given +/// coverage counters. 
+fn create_mappings<'tcx>( + tcx: TyCtxt<'tcx>, + hir_info: &ExtractedHirInfo, + coverage_spans: &CoverageSpans, + coverage_counters: &CoverageCounters, +) -> Vec<Mapping> { + let source_map = tcx.sess.source_map(); + let body_span = hir_info.body_span; + + let source_file = source_map.lookup_source_file(body_span.lo()); + use rustc_session::RemapFileNameExt; + let file_name = Symbol::intern(&source_file.name.for_codegen(tcx.sess).to_string_lossy()); + + let term_for_bcb = |bcb| { + coverage_counters + .bcb_counter(bcb) + .expect("all BCBs with spans were given counters") + .as_term() + }; - /// For each BCB node or BCB edge that has an associated coverage counter, - /// inject any necessary coverage statements into MIR. - fn inject_coverage_statements( - &mut self, - bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool, - coverage_counters: &CoverageCounters, - ) { - // Process the counters associated with BCB nodes. - for (bcb, counter_kind) in coverage_counters.bcb_node_counters() { - let do_inject = match counter_kind { - // Counter-increment statements always need to be injected. - BcbCounter::Counter { .. } => true, - // The only purpose of expression-used statements is to detect - // when a mapping is unreachable, so we only inject them for - // expressions with one or more mappings. - BcbCounter::Expression { .. } => bcb_has_coverage_spans(bcb), + coverage_spans + .all_bcb_mappings() + .filter_map(|&BcbMapping { kind: bcb_mapping_kind, span }| { + let kind = match bcb_mapping_kind { + BcbMappingKind::Code(bcb) => MappingKind::Code(term_for_bcb(bcb)), }; - if do_inject { - inject_statement( - self.mir_body, - self.make_mir_coverage_kind(counter_kind), - self.basic_coverage_blocks[bcb].leader_bb(), - ); - } - } + let code_region = make_code_region(source_map, file_name, span, body_span)?; + Some(Mapping { kind, code_region }) + }) + .collect::<Vec<_>>() +} - // Process the counters associated with BCB edges. - for (from_bcb, to_bcb, counter_kind) in coverage_counters.bcb_edge_counters() { - let do_inject = match counter_kind { - // Counter-increment statements always need to be injected. - BcbCounter::Counter { .. } => true, - // BCB-edge expressions never have mappings, so they never need - // a corresponding statement. - BcbCounter::Expression { .. } => false, - }; - if !do_inject { - continue; +/// For each BCB node or BCB edge that has an associated coverage counter, +/// inject any necessary coverage statements into MIR. +fn inject_coverage_statements<'tcx>( + mir_body: &mut mir::Body<'tcx>, + basic_coverage_blocks: &CoverageGraph, + bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool, + coverage_counters: &CoverageCounters, +) { + // Inject counter-increment statements into MIR. + for (id, counter_increment_site) in coverage_counters.counter_increment_sites() { + // Determine the block to inject a counter-increment statement into. + // For BCB nodes this is just their first block, but for edges we need + // to create a new block between the two BCBs, and inject into that. + let target_bb = match *counter_increment_site { + CounterIncrementSite::Node { bcb } => basic_coverage_blocks[bcb].leader_bb(), + CounterIncrementSite::Edge { from_bcb, to_bcb } => { + // Create a new block between the last block of `from_bcb` and + // the first block of `to_bcb`. 
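// (Editor's note — shape of the CFG splice performed by
//  `inject_edge_counter_basic_block` below, shown schematically:)
//
//     bbFROM: ... -> [bbNEW]    // was: -> [bbTO]
//     bbNEW:  Coverage(CounterIncrement { id }); goto -> bbTO;
//
// Giving the edge its own block means the increment runs exactly when
// control flows from `from_bcb` to `to_bcb`, and on no other path.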
+ let from_bb = basic_coverage_blocks[from_bcb].last_bb(); + let to_bb = basic_coverage_blocks[to_bcb].leader_bb(); + + let new_bb = inject_edge_counter_basic_block(mir_body, from_bb, to_bb); + debug!( + "Edge {from_bcb:?} (last {from_bb:?}) -> {to_bcb:?} (leader {to_bb:?}) \ + requires a new MIR BasicBlock {new_bb:?} for counter increment {id:?}", + ); + new_bb } + }; - // We need to inject a coverage statement into a new BB between the - // last BB of `from_bcb` and the first BB of `to_bcb`. - let from_bb = self.basic_coverage_blocks[from_bcb].last_bb(); - let to_bb = self.basic_coverage_blocks[to_bcb].leader_bb(); - - let new_bb = inject_edge_counter_basic_block(self.mir_body, from_bb, to_bb); - debug!( - "Edge {from_bcb:?} (last {from_bb:?}) -> {to_bcb:?} (leader {to_bb:?}) \ - requires a new MIR BasicBlock {new_bb:?} for edge counter {counter_kind:?}", - ); - - // Inject a counter into the newly-created BB. - inject_statement(self.mir_body, self.make_mir_coverage_kind(counter_kind), new_bb); - } + inject_statement(mir_body, CoverageKind::CounterIncrement { id }, target_bb); } - fn make_mir_coverage_kind(&self, counter_kind: &BcbCounter) -> CoverageKind { - match *counter_kind { - BcbCounter::Counter { id } => CoverageKind::CounterIncrement { id }, - BcbCounter::Expression { id } => CoverageKind::ExpressionUsed { id }, - } + // For each counter expression that is directly associated with at least one + // span, we inject an "expression-used" statement, so that coverage codegen + // can check whether the injected statement survived MIR optimization. + // (BCB edges can't have spans, so we only need to process BCB nodes here.) + // + // See the code in `rustc_codegen_llvm::coverageinfo::map_data` that deals + // with "expressions seen" and "zero terms". + for (bcb, expression_id) in coverage_counters + .bcb_nodes_with_coverage_expressions() + .filter(|&(bcb, _)| bcb_has_coverage_spans(bcb)) + { + inject_statement( + mir_body, + CoverageKind::ExpressionUsed { id: expression_id }, + basic_coverage_blocks[bcb].leader_bb(), + ); } } +/// Given two basic blocks that have a control-flow edge between them, creates +/// and returns a new block that sits between those blocks. fn inject_edge_counter_basic_block( mir_body: &mut mir::Body<'_>, from_bb: BasicBlock, @@ -329,7 +307,7 @@ fn make_code_region( start_line = source_map.doctest_offset_line(&file.name, start_line); end_line = source_map.doctest_offset_line(&file.name, end_line); - Some(CodeRegion { + check_code_region(CodeRegion { file_name, start_line: start_line as u32, start_col: start_col as u32, @@ -338,6 +316,39 @@ fn make_code_region( }) } +/// If `llvm-cov` sees a code region that is improperly ordered (end < start), +/// it will immediately exit with a fatal error. To prevent that from happening, +/// discard regions that are improperly ordered, or might be interpreted in a +/// way that makes them improperly ordered. +fn check_code_region(code_region: CodeRegion) -> Option<CodeRegion> { + let CodeRegion { file_name: _, start_line, start_col, end_line, end_col } = code_region; + + // Line/column coordinates are supposed to be 1-based. If we ever emit + // coordinates of 0, `llvm-cov` might misinterpret them. + let all_nonzero = [start_line, start_col, end_line, end_col].into_iter().all(|x| x != 0); + // Coverage mappings use the high bit of `end_col` to indicate that a + // region is actually a "gap" region, so make sure it's unset. 
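// (Editor's note — concrete instances of the three predicates that
//  follow, with made-up coordinates:)
//
//     start 3:5, end 3:2   -> is_ordered == false       -> region dropped
//     start 0:1, end 2:4   -> all_nonzero == false      -> region dropped
//     end_col 0x8000_0007  -> high bit set (gap flag)   -> region dropped
//
// Only regions that pass all three checks are handed to llvm-cov.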
+ let end_col_has_high_bit_unset = (end_col & (1 << 31)) == 0; + // If a region is improperly ordered (end < start), `llvm-cov` will exit + // with a fatal error, which is inconvenient for users and hard to debug. + let is_ordered = (start_line, start_col) <= (end_line, end_col); + + if all_nonzero && end_col_has_high_bit_unset && is_ordered { + Some(code_region) + } else { + debug!( + ?code_region, + ?all_nonzero, + ?end_col_has_high_bit_unset, + ?is_ordered, + "Skipping code region that would be misinterpreted or rejected by LLVM" + ); + // If this happens in a debug build, ICE to make it easier to notice. + debug_assert!(false, "Improper code region: {code_region:?}"); + None + } +} + fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { // Only instrument functions, methods, and closures (not constants since they are evaluated // at compile time by Miri). @@ -351,7 +362,18 @@ fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { return false; } + // Don't instrument functions with `#[automatically_derived]` on their + // enclosing impl block, on the assumption that most users won't care about + // coverage for derived impls. + if let Some(impl_of) = tcx.impl_of_method(def_id.to_def_id()) + && tcx.is_automatically_derived(impl_of) + { + trace!("InstrumentCoverage skipped for {def_id:?} (automatically derived)"); + return false; + } + if tcx.codegen_fn_attrs(def_id).flags.contains(CodegenFnAttrFlags::NO_COVERAGE) { + trace!("InstrumentCoverage skipped for {def_id:?} (`#[coverage(off)]`)"); return false; } @@ -363,7 +385,9 @@ fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { struct ExtractedHirInfo { function_source_hash: u64, is_async_fn: bool, - fn_sig_span: Span, + /// The span of the function's signature, extended to the start of `body_span`. + /// Must have the same context and filename as the body span. + fn_sig_span_extended: Option<Span>, body_span: Span, } @@ -376,13 +400,25 @@ fn extract_hir_info<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> ExtractedHir hir::map::associated_body(hir_node).expect("HIR node is a function with body"); let hir_body = tcx.hir().body(fn_body_id); - let is_async_fn = hir_node.fn_sig().is_some_and(|fn_sig| fn_sig.header.is_async()); - let body_span = get_body_span(tcx, hir_body, def_id); + let maybe_fn_sig = hir_node.fn_sig(); + let is_async_fn = maybe_fn_sig.is_some_and(|fn_sig| fn_sig.header.is_async()); + + let mut body_span = hir_body.value.span; + + use rustc_hir::{Closure, Expr, ExprKind, Node}; + // Unexpand a closure's body span back to the context of its declaration. + // This helps with closure bodies that consist of just a single bang-macro, + // and also with closure bodies produced by async desugaring. + if let Node::Expr(&Expr { kind: ExprKind::Closure(&Closure { fn_decl_span, .. }), .. }) = + hir_node + { + body_span = body_span.find_ancestor_in_same_ctxt(fn_decl_span).unwrap_or(body_span); + } // The actual signature span is only used if it has the same context and // filename as the body, and precedes the body. 
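// (Editor's note — what the extended signature span covers, on a
//  made-up function:)
//
//     fn demo(x: u32) -> bool { x > 0 }
//     ^----------------------^            fn_sig_span_extended
//                             ^---------^ body_span
//
// `fn_sig_span.with_hi(body_span.lo())` stretches the signature span up
// to, but not including, the opening brace of the body.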
- let maybe_fn_sig_span = hir_node.fn_sig().map(|fn_sig| fn_sig.span); - let fn_sig_span = maybe_fn_sig_span + let fn_sig_span_extended = maybe_fn_sig + .map(|fn_sig| fn_sig.span) .filter(|&fn_sig_span| { let source_map = tcx.sess.source_map(); let file_idx = |span: Span| source_map.lookup_source_file_idx(span.lo()); @@ -392,39 +428,15 @@ fn extract_hir_info<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> ExtractedHir && file_idx(fn_sig_span) == file_idx(body_span) }) // If so, extend it to the start of the body span. - .map(|fn_sig_span| fn_sig_span.with_hi(body_span.lo())) - // Otherwise, create a dummy signature span at the start of the body. - .unwrap_or_else(|| body_span.shrink_to_lo()); + .map(|fn_sig_span| fn_sig_span.with_hi(body_span.lo())); let function_source_hash = hash_mir_source(tcx, hir_body); - ExtractedHirInfo { function_source_hash, is_async_fn, fn_sig_span, body_span } -} - -fn get_body_span<'tcx>( - tcx: TyCtxt<'tcx>, - hir_body: &rustc_hir::Body<'tcx>, - def_id: LocalDefId, -) -> Span { - let mut body_span = hir_body.value.span; - - if tcx.is_closure_or_coroutine(def_id.to_def_id()) { - // If the current function is a closure, and its "body" span was created - // by macro expansion or compiler desugaring, try to walk backwards to - // the pre-expansion call site or body. - body_span = body_span.source_callsite(); - } - - body_span + ExtractedHirInfo { function_source_hash, is_async_fn, fn_sig_span_extended, body_span } } fn hash_mir_source<'tcx>(tcx: TyCtxt<'tcx>, hir_body: &'tcx rustc_hir::Body<'tcx>) -> u64 { // FIXME(cjgillot) Stop hashing HIR manually here. let owner = hir_body.id().hir_id.owner; - tcx.hir_owner_nodes(owner) - .unwrap() - .opt_hash_including_bodies - .unwrap() - .to_smaller_hash() - .as_u64() + tcx.hir_owner_nodes(owner).opt_hash_including_bodies.unwrap().to_smaller_hash().as_u64() } diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs index 81f6c831206..98fb1d8e1c9 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans.rs @@ -1,9 +1,10 @@ use rustc_data_structures::graph::WithNumNodes; use rustc_index::bit_set::BitSet; use rustc_middle::mir; -use rustc_span::{BytePos, Span, DUMMY_SP}; +use rustc_span::{BytePos, Span}; -use super::graph::{BasicCoverageBlock, CoverageGraph}; +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, START_BCB}; +use crate::coverage::spans::from_mir::SpanFromMir; use crate::coverage::ExtractedHirInfo; mod from_mir; @@ -26,66 +27,92 @@ pub(super) struct CoverageSpans { } impl CoverageSpans { - /// Extracts coverage-relevant spans from MIR, and associates them with - /// their corresponding BCBs. - /// - /// Returns `None` if no coverage-relevant spans could be extracted. - pub(super) fn generate_coverage_spans( - mir_body: &mir::Body<'_>, - hir_info: &ExtractedHirInfo, - basic_coverage_blocks: &CoverageGraph, - ) -> Option<Self> { - let mut mappings = vec![]; - - let coverage_spans = CoverageSpansGenerator::generate_coverage_spans( + pub(super) fn bcb_has_coverage_spans(&self, bcb: BasicCoverageBlock) -> bool { + self.bcb_has_mappings.contains(bcb) + } + + pub(super) fn all_bcb_mappings(&self) -> impl Iterator<Item = &BcbMapping> { + self.mappings.iter() + } +} + +/// Extracts coverage-relevant spans from MIR, and associates them with +/// their corresponding BCBs. +/// +/// Returns `None` if no coverage-relevant spans could be extracted. 
+pub(super) fn generate_coverage_spans( + mir_body: &mir::Body<'_>, + hir_info: &ExtractedHirInfo, + basic_coverage_blocks: &CoverageGraph, +) -> Option<CoverageSpans> { + let mut mappings = vec![]; + + if hir_info.is_async_fn { + // An async function desugars into a function that returns a future, + // with the user code wrapped in a closure. Any spans in the desugared + // outer function will be unhelpful, so just keep the signature span + // and ignore all of the spans in the MIR body. + if let Some(span) = hir_info.fn_sig_span_extended { + mappings.push(BcbMapping { kind: BcbMappingKind::Code(START_BCB), span }); + } + } else { + let sorted_spans = from_mir::mir_to_initial_sorted_coverage_spans( mir_body, hir_info, basic_coverage_blocks, ); - mappings.extend(coverage_spans.into_iter().map(|CoverageSpan { bcb, span, .. }| { + let coverage_spans = SpansRefiner::refine_sorted_spans(sorted_spans); + mappings.extend(coverage_spans.into_iter().map(|RefinedCovspan { bcb, span, .. }| { // Each span produced by the generator represents an ordinary code region. BcbMapping { kind: BcbMappingKind::Code(bcb), span } })); + } - if mappings.is_empty() { - return None; - } + if mappings.is_empty() { + return None; + } - // Identify which BCBs have one or more mappings. - let mut bcb_has_mappings = BitSet::new_empty(basic_coverage_blocks.num_nodes()); - let mut insert = |bcb| { - bcb_has_mappings.insert(bcb); - }; - for &BcbMapping { kind, span: _ } in &mappings { - match kind { - BcbMappingKind::Code(bcb) => insert(bcb), - } + // Identify which BCBs have one or more mappings. + let mut bcb_has_mappings = BitSet::new_empty(basic_coverage_blocks.num_nodes()); + let mut insert = |bcb| { + bcb_has_mappings.insert(bcb); + }; + for &BcbMapping { kind, span: _ } in &mappings { + match kind { + BcbMappingKind::Code(bcb) => insert(bcb), } + } - Some(Self { bcb_has_mappings, mappings }) + Some(CoverageSpans { bcb_has_mappings, mappings }) +} + +#[derive(Debug)] +struct CurrCovspan { + span: Span, + bcb: BasicCoverageBlock, + is_closure: bool, +} + +impl CurrCovspan { + fn new(span: Span, bcb: BasicCoverageBlock, is_closure: bool) -> Self { + Self { span, bcb, is_closure } } - pub(super) fn bcb_has_coverage_spans(&self, bcb: BasicCoverageBlock) -> bool { - self.bcb_has_mappings.contains(bcb) + fn into_prev(self) -> PrevCovspan { + let Self { span, bcb, is_closure } = self; + PrevCovspan { span, bcb, merged_spans: vec![span], is_closure } } - pub(super) fn all_bcb_mappings(&self) -> impl Iterator<Item = &BcbMapping> { - self.mappings.iter() + fn into_refined(self) -> RefinedCovspan { + // This is only called in cases where `curr` is a closure span that has + // been carved out of `prev`. + debug_assert!(self.is_closure); + self.into_prev().into_refined() } } -/// A BCB is deconstructed into one or more `Span`s. Each `Span` maps to a `CoverageSpan` that -/// references the originating BCB and one or more MIR `Statement`s and/or `Terminator`s. -/// Initially, the `Span`s come from the `Statement`s and `Terminator`s, but subsequent -/// transforms can combine adjacent `Span`s and `CoverageSpan` from the same BCB, merging the -/// `merged_spans` vectors, and the `Span`s to cover the extent of the combined `Span`s. 
-/// -/// Note: A span merged into another CoverageSpan may come from a `BasicBlock` that -/// is not part of the `CoverageSpan` bcb if the statement was included because it's `Span` matches -/// or is subsumed by the `Span` associated with this `CoverageSpan`, and it's `BasicBlock` -/// `dominates()` the `BasicBlock`s in this `CoverageSpan`. -#[derive(Debug, Clone)] -struct CoverageSpan { +#[derive(Debug)] +struct PrevCovspan { span: Span, bcb: BasicCoverageBlock, /// List of all the original spans from MIR that have been merged into this @@ -94,135 +121,101 @@ struct CoverageSpan { is_closure: bool, } -impl CoverageSpan { - fn new(span: Span, bcb: BasicCoverageBlock, is_closure: bool) -> Self { - Self { span, bcb, merged_spans: vec![span], is_closure } +impl PrevCovspan { + fn is_mergeable(&self, other: &CurrCovspan) -> bool { + self.bcb == other.bcb && !self.is_closure && !other.is_closure } - pub fn merge_from(&mut self, other: &Self) { + fn merge_from(&mut self, other: &CurrCovspan) { debug_assert!(self.is_mergeable(other)); self.span = self.span.to(other.span); - self.merged_spans.extend_from_slice(&other.merged_spans); + self.merged_spans.push(other.span); } - pub fn cutoff_statements_at(&mut self, cutoff_pos: BytePos) { + fn cutoff_statements_at(mut self, cutoff_pos: BytePos) -> Option<RefinedCovspan> { self.merged_spans.retain(|span| span.hi() <= cutoff_pos); if let Some(max_hi) = self.merged_spans.iter().map(|span| span.hi()).max() { self.span = self.span.with_hi(max_hi); } + + if self.merged_spans.is_empty() { None } else { Some(self.into_refined()) } } - #[inline] - pub fn is_mergeable(&self, other: &Self) -> bool { - self.is_in_same_bcb(other) && !(self.is_closure || other.is_closure) + fn refined_copy(&self) -> RefinedCovspan { + let &Self { span, bcb, merged_spans: _, is_closure } = self; + RefinedCovspan { span, bcb, is_closure } } - #[inline] - pub fn is_in_same_bcb(&self, other: &Self) -> bool { - self.bcb == other.bcb + fn into_refined(self) -> RefinedCovspan { + // Even though we consume self, we can just reuse the copying impl. + self.refined_copy() } } -/// Converts the initial set of `CoverageSpan`s (one per MIR `Statement` or `Terminator`) into a -/// minimal set of `CoverageSpan`s, using the BCB CFG to determine where it is safe and useful to: +#[derive(Debug)] +struct RefinedCovspan { + span: Span, + bcb: BasicCoverageBlock, + is_closure: bool, +} + +impl RefinedCovspan { + fn is_mergeable(&self, other: &Self) -> bool { + self.bcb == other.bcb && !self.is_closure && !other.is_closure + } + + fn merge_from(&mut self, other: &Self) { + debug_assert!(self.is_mergeable(other)); + self.span = self.span.to(other.span); + } +} + +/// Converts the initial set of coverage spans (one per MIR `Statement` or `Terminator`) into a +/// minimal set of coverage spans, using the BCB CFG to determine where it is safe and useful to: /// /// * Remove duplicate source code coverage regions /// * Merge spans that represent continuous (both in source code and control flow), non-branching /// execution /// * Carve out (leave uncovered) any span that will be counted by another MIR (notably, closures) -struct CoverageSpansGenerator<'a> { - /// The BasicCoverageBlock Control Flow Graph (BCB CFG). 
- basic_coverage_blocks: &'a CoverageGraph, - - /// The initial set of `CoverageSpan`s, sorted by `Span` (`lo` and `hi`) and by relative +struct SpansRefiner { + /// The initial set of coverage spans, sorted by `Span` (`lo` and `hi`) and by relative /// dominance between the `BasicCoverageBlock`s of equal `Span`s. - sorted_spans_iter: std::vec::IntoIter<CoverageSpan>, + sorted_spans_iter: std::vec::IntoIter<SpanFromMir>, - /// The current `CoverageSpan` to compare to its `prev`, to possibly merge, discard, force the + /// The current coverage span to compare to its `prev`, to possibly merge, discard, force the /// discard of the `prev` (and or `pending_dups`), or keep both (with `prev` moved to /// `pending_dups`). If `curr` is not discarded or merged, it becomes `prev` for the next /// iteration. - some_curr: Option<CoverageSpan>, - - /// The original `span` for `curr`, in case `curr.span()` is modified. The `curr_original_span` - /// **must not be mutated** (except when advancing to the next `curr`), even if `curr.span()` - /// is mutated. - curr_original_span: Span, + some_curr: Option<CurrCovspan>, - /// The CoverageSpan from a prior iteration; typically assigned from that iteration's `curr`. + /// The coverage span from a prior iteration; typically assigned from that iteration's `curr`. /// If that `curr` was discarded, `prev` retains its value from the previous iteration. - some_prev: Option<CoverageSpan>, - - /// Assigned from `curr_original_span` from the previous iteration. The `prev_original_span` - /// **must not be mutated** (except when advancing to the next `prev`), even if `prev.span()` - /// is mutated. - prev_original_span: Span, - - /// One or more `CoverageSpan`s with the same `Span` but different `BasicCoverageBlock`s, and - /// no `BasicCoverageBlock` in this list dominates another `BasicCoverageBlock` in the list. - /// If a new `curr` span also fits this criteria (compared to an existing list of - /// `pending_dups`), that `curr` `CoverageSpan` moves to `prev` before possibly being added to - /// the `pending_dups` list, on the next iteration. As a result, if `prev` and `pending_dups` - /// have the same `Span`, the criteria for `pending_dups` holds for `prev` as well: a `prev` - /// with a matching `Span` does not dominate any `pending_dup` and no `pending_dup` dominates a - /// `prev` with a matching `Span`) - pending_dups: Vec<CoverageSpan>, - - /// The final `CoverageSpan`s to add to the coverage map. A `Counter` or `Expression` - /// will also be injected into the MIR for each `CoverageSpan`. - refined_spans: Vec<CoverageSpan>, -} + some_prev: Option<PrevCovspan>, -impl<'a> CoverageSpansGenerator<'a> { - /// Generate a minimal set of `CoverageSpan`s, each representing a contiguous code region to be - /// counted. - /// - /// The basic steps are: - /// - /// 1. Extract an initial set of spans from the `Statement`s and `Terminator`s of each - /// `BasicCoverageBlockData`. - /// 2. Sort the spans by span.lo() (starting position). Spans that start at the same position - /// are sorted with longer spans before shorter spans; and equal spans are sorted - /// (deterministically) based on "dominator" relationship (if any). - /// 3. Traverse the spans in sorted order to identify spans that can be dropped (for instance, - /// if another span or spans are already counting the same code region), or should be merged - /// into a broader combined span (because it represents a contiguous, non-branching, and - /// uninterrupted region of source code). 
- /// - /// Closures are exposed in their enclosing functions as `Assign` `Rvalue`s, and since - /// closures have their own MIR, their `Span` in their enclosing function should be left - /// "uncovered". - /// - /// Note the resulting vector of `CoverageSpan`s may not be fully sorted (and does not need - /// to be). - pub(super) fn generate_coverage_spans( - mir_body: &mir::Body<'_>, - hir_info: &ExtractedHirInfo, - basic_coverage_blocks: &'a CoverageGraph, - ) -> Vec<CoverageSpan> { - let sorted_spans = from_mir::mir_to_initial_sorted_coverage_spans( - mir_body, - hir_info, - basic_coverage_blocks, - ); + /// The final coverage spans to add to the coverage map. A `Counter` or `Expression` + /// will also be injected into the MIR for each BCB that has associated spans. + refined_spans: Vec<RefinedCovspan>, +} - let coverage_spans = Self { - basic_coverage_blocks, +impl SpansRefiner { + /// Takes the initial list of (sorted) spans extracted from MIR, and "refines" + /// them by merging compatible adjacent spans, removing redundant spans, + /// and carving holes in spans when they overlap in unwanted ways. + fn refine_sorted_spans(sorted_spans: Vec<SpanFromMir>) -> Vec<RefinedCovspan> { + let sorted_spans_len = sorted_spans.len(); + let this = Self { sorted_spans_iter: sorted_spans.into_iter(), some_curr: None, - curr_original_span: DUMMY_SP, some_prev: None, - prev_original_span: DUMMY_SP, - pending_dups: Vec::new(), - refined_spans: Vec::with_capacity(basic_coverage_blocks.num_nodes() * 2), + refined_spans: Vec::with_capacity(sorted_spans_len), }; - coverage_spans.to_refined_spans() + this.to_refined_spans() } - /// Iterate through the sorted `CoverageSpan`s, and return the refined list of merged and - /// de-duplicated `CoverageSpan`s. - fn to_refined_spans(mut self) -> Vec<CoverageSpan> { + /// Iterate through the sorted coverage spans, and return the refined list of merged and + /// de-duplicated spans. + fn to_refined_spans(mut self) -> Vec<RefinedCovspan> { while self.next_coverage_span() { // For the first span we don't have `prev` set, so most of the // span-processing steps don't make sense yet. @@ -235,16 +228,15 @@ impl<'a> CoverageSpansGenerator<'a> { let prev = self.prev(); let curr = self.curr(); - if curr.is_mergeable(prev) { + if prev.is_mergeable(curr) { debug!(" same bcb (and neither is a closure), merge with prev={prev:?}"); - let prev = self.take_prev(); - self.curr_mut().merge_from(&prev); - // Note that curr.span may now differ from curr_original_span + let curr = self.take_curr(); + self.prev_mut().merge_from(&curr); } else if prev.span.hi() <= curr.span.lo() { debug!( " different bcbs and disjoint spans, so keep curr for next iter, and add prev={prev:?}", ); - let prev = self.take_prev(); + let prev = self.take_prev().into_refined(); self.refined_spans.push(prev); } else if prev.is_closure { // drop any equal or overlapping span (`curr`) and keep `prev` to test again in the @@ -255,26 +247,16 @@ impl<'a> CoverageSpansGenerator<'a> { self.take_curr(); // Discards curr. } else if curr.is_closure { self.carve_out_span_for_closure(); - } else if self.prev_original_span == curr.span { - // `prev` and `curr` have the same span, or would have had the - // same span before `prev` was modified by other spans. - self.update_pending_dups(); } else { self.cutoff_prev_at_overlapping_curr(); } } - // Drain any remaining dups into the output. - for dup in self.pending_dups.drain(..) 
{ - debug!(" ...adding at least one pending dup={:?}", dup); - self.refined_spans.push(dup); - } - // There is usually a final span remaining in `prev` after the loop ends, // so add it to the output as well. if let Some(prev) = self.some_prev.take() { debug!(" AT END, adding last prev={prev:?}"); - self.refined_spans.push(prev); + self.refined_spans.push(prev.into_refined()); } // Do one last merge pass, to simplify the output. @@ -288,7 +270,7 @@ impl<'a> CoverageSpansGenerator<'a> { } }); - // Remove `CoverageSpan`s derived from closures, originally added to ensure the coverage + // Remove spans derived from closures, originally added to ensure the coverage // regions for the current function leave room for the closure's own coverage regions // (injected separately, from the closure's own MIR). self.refined_spans.retain(|covspan| !covspan.is_closure); @@ -296,72 +278,36 @@ impl<'a> CoverageSpansGenerator<'a> { } #[track_caller] - fn curr(&self) -> &CoverageSpan { + fn curr(&self) -> &CurrCovspan { self.some_curr.as_ref().unwrap_or_else(|| bug!("some_curr is None (curr)")) } - #[track_caller] - fn curr_mut(&mut self) -> &mut CoverageSpan { - self.some_curr.as_mut().unwrap_or_else(|| bug!("some_curr is None (curr_mut)")) - } - /// If called, then the next call to `next_coverage_span()` will *not* update `prev` with the /// `curr` coverage span. #[track_caller] - fn take_curr(&mut self) -> CoverageSpan { + fn take_curr(&mut self) -> CurrCovspan { self.some_curr.take().unwrap_or_else(|| bug!("some_curr is None (take_curr)")) } #[track_caller] - fn prev(&self) -> &CoverageSpan { + fn prev(&self) -> &PrevCovspan { self.some_prev.as_ref().unwrap_or_else(|| bug!("some_prev is None (prev)")) } #[track_caller] - fn prev_mut(&mut self) -> &mut CoverageSpan { + fn prev_mut(&mut self) -> &mut PrevCovspan { self.some_prev.as_mut().unwrap_or_else(|| bug!("some_prev is None (prev_mut)")) } #[track_caller] - fn take_prev(&mut self) -> CoverageSpan { + fn take_prev(&mut self) -> PrevCovspan { self.some_prev.take().unwrap_or_else(|| bug!("some_prev is None (take_prev)")) } - /// If there are `pending_dups` but `prev` is not a matching dup (`prev.span` doesn't match the - /// `pending_dups` spans), then one of the following two things happened during the previous - /// iteration: - /// * the previous `curr` span (which is now `prev`) was not a duplicate of the pending_dups - /// (in which case there should be at least two spans in `pending_dups`); or - /// * the `span` of `prev` was modified by `curr_mut().merge_from(prev)` (in which case - /// `pending_dups` could have as few as one span) - /// In either case, no more spans will match the span of `pending_dups`, so - /// add the `pending_dups` if they don't overlap `curr`, and clear the list. - fn maybe_flush_pending_dups(&mut self) { - let Some(last_dup) = self.pending_dups.last() else { return }; - if last_dup.span == self.prev().span { - return; - } - - debug!( - " SAME spans, but pending_dups are NOT THE SAME, so BCBs matched on \ - previous iteration, or prev started a new disjoint span" - ); - if last_dup.span.hi() <= self.curr().span.lo() { - for dup in self.pending_dups.drain(..) { - debug!(" ...adding at least one pending={:?}", dup); - self.refined_spans.push(dup); - } - } else { - self.pending_dups.clear(); - } - assert!(self.pending_dups.is_empty()); - } - - /// Advance `prev` to `curr` (if any), and `curr` to the next `CoverageSpan` in sorted order. + /// Advance `prev` to `curr` (if any), and `curr` to the next coverage span in sorted order. 
fn next_coverage_span(&mut self) -> bool { if let Some(curr) = self.some_curr.take() { - self.some_prev = Some(curr); - self.prev_original_span = self.curr_original_span; + self.some_prev = Some(curr.into_prev()); } while let Some(curr) = self.sorted_spans_iter.next() { debug!("FOR curr={:?}", curr); @@ -376,11 +322,7 @@ impl<'a> CoverageSpansGenerator<'a> { closure?); prev={prev:?}", ); } else { - // Save a copy of the original span for `curr` in case the `CoverageSpan` is changed - // by `self.curr_mut().merge_from(prev)`. - self.curr_original_span = curr.span; - self.some_curr.replace(curr); - self.maybe_flush_pending_dups(); + self.some_curr = Some(CurrCovspan::new(curr.span, curr.bcb, curr.is_closure)); return true; } } @@ -402,107 +344,20 @@ impl<'a> CoverageSpansGenerator<'a> { let has_post_closure_span = prev.span.hi() > right_cutoff; if has_pre_closure_span { - let mut pre_closure = self.prev().clone(); + let mut pre_closure = self.prev().refined_copy(); pre_closure.span = pre_closure.span.with_hi(left_cutoff); debug!(" prev overlaps a closure. Adding span for pre_closure={:?}", pre_closure); - - for mut dup in self.pending_dups.iter().cloned() { - dup.span = dup.span.with_hi(left_cutoff); - debug!(" ...and at least one pre_closure dup={:?}", dup); - self.refined_spans.push(dup); - } - self.refined_spans.push(pre_closure); } if has_post_closure_span { - // Mutate `prev.span()` to start after the closure (and discard curr). - // (**NEVER** update `prev_original_span` because it affects the assumptions - // about how the `CoverageSpan`s are ordered.) + // Mutate `prev.span` to start after the closure (and discard curr). self.prev_mut().span = self.prev().span.with_lo(right_cutoff); debug!(" Mutated prev.span to start after the closure. prev={:?}", self.prev()); - for dup in &mut self.pending_dups { - debug!(" ...and at least one overlapping dup={:?}", dup); - dup.span = dup.span.with_lo(right_cutoff); - } - - let closure_covspan = self.take_curr(); // Prevent this curr from becoming prev. + // Prevent this curr from becoming prev. + let closure_covspan = self.take_curr().into_refined(); self.refined_spans.push(closure_covspan); // since self.prev() was already updated - } else { - self.pending_dups.clear(); - } - } - - /// Called if `curr.span` equals `prev_original_span` (and potentially equal to all - /// `pending_dups` spans, if any). Keep in mind, `prev.span()` may have been changed. - /// If prev.span() was merged into other spans (with matching BCB, for instance), - /// `prev.span.hi()` will be greater than (further right of) `prev_original_span.hi()`. - /// If prev.span() was split off to the right of a closure, prev.span().lo() will be - /// greater than prev_original_span.lo(). The actual span of `prev_original_span` is - /// not as important as knowing that `prev()` **used to have the same span** as `curr()`, - /// which means their sort order is still meaningful for determining the dominator - /// relationship. - /// - /// When two `CoverageSpan`s have the same `Span`, dominated spans can be discarded; but if - /// neither `CoverageSpan` dominates the other, both (or possibly more than two) are held, - /// until their disposition is determined. In this latter case, the `prev` dup is moved into - /// `pending_dups` so the new `curr` dup can be moved to `prev` for the next iteration. 
- fn update_pending_dups(&mut self) { - let prev_bcb = self.prev().bcb; - let curr_bcb = self.curr().bcb; - - // Equal coverage spans are ordered by dominators before dominated (if any), so it should be - // impossible for `curr` to dominate any previous `CoverageSpan`. - debug_assert!(!self.basic_coverage_blocks.dominates(curr_bcb, prev_bcb)); - - let initial_pending_count = self.pending_dups.len(); - if initial_pending_count > 0 { - self.pending_dups - .retain(|dup| !self.basic_coverage_blocks.dominates(dup.bcb, curr_bcb)); - - let n_discarded = initial_pending_count - self.pending_dups.len(); - if n_discarded > 0 { - debug!( - " discarded {n_discarded} of {initial_pending_count} pending_dups that dominated curr", - ); - } - } - - if self.basic_coverage_blocks.dominates(prev_bcb, curr_bcb) { - debug!( - " different bcbs but SAME spans, and prev dominates curr. Discard prev={:?}", - self.prev() - ); - self.cutoff_prev_at_overlapping_curr(); - // If one span dominates the other, associate the span with the code from the dominated - // block only (`curr`), and discard the overlapping portion of the `prev` span. (Note - // that if `prev.span` is wider than `prev_original_span`, a `CoverageSpan` will still - // be created for `prev`s block, for the non-overlapping portion, left of `curr.span`.) - // - // For example: - // match somenum { - // x if x < 1 => { ... } - // }... - // - // The span for the first `x` is referenced by both the pattern block (every time it is - // evaluated) and the arm code (only when matched). The counter will be applied only to - // the dominated block. This allows coverage to track and highlight things like the - // assignment of `x` above, if the branch is matched, making `x` available to the arm - // code; and to track and highlight the question mark `?` "try" operator at the end of - // a function call returning a `Result`, so the `?` is covered when the function returns - // an `Err`, and not counted as covered if the function always returns `Ok`. - } else { - // Save `prev` in `pending_dups`. (`curr` will become `prev` in the next iteration.) - // If the `curr` CoverageSpan is later discarded, `pending_dups` can be discarded as - // well; but if `curr` is added to refined_spans, the `pending_dups` will also be added. - debug!( - " different bcbs but SAME spans, and neither dominates, so keep curr for \ - next iter, and, pending upcoming spans (unless overlapping) add prev={:?}", - self.prev() - ); - let prev = self.take_prev(); - self.pending_dups.push(prev); } } @@ -519,19 +374,13 @@ impl<'a> CoverageSpansGenerator<'a> { if it has statements that end before curr; prev={:?}", self.prev() ); - if self.pending_dups.is_empty() { - let curr_span = self.curr().span; - self.prev_mut().cutoff_statements_at(curr_span.lo()); - if self.prev().merged_spans.is_empty() { - debug!(" ... no non-overlapping statements to add"); - } else { - debug!(" ... 
adding modified prev={:?}", self.prev()); - let prev = self.take_prev(); - self.refined_spans.push(prev); - } + + let curr_span = self.curr().span; + if let Some(prev) = self.take_prev().cutoff_statements_at(curr_span.lo()) { + debug!("after cutoff, adding {prev:?}"); + self.refined_spans.push(prev); } else { - // with `pending_dups`, `prev` cannot have any statements that don't overlap - self.pending_dups.clear(); + debug!("prev was eliminated by cutoff"); } } } diff --git a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs index 1b6dfccd574..b91ab811918 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs @@ -9,33 +9,34 @@ use rustc_span::{ExpnKind, MacroKind, Span, Symbol}; use crate::coverage::graph::{ BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph, START_BCB, }; -use crate::coverage::spans::CoverageSpan; use crate::coverage::ExtractedHirInfo; +/// Traverses the MIR body to produce an initial collection of coverage-relevant +/// spans, each associated with a node in the coverage graph (BCB) and possibly +/// other metadata. +/// +/// The returned spans are sorted in a specific order that is expected by the +/// subsequent span-refinement step. pub(super) fn mir_to_initial_sorted_coverage_spans( mir_body: &mir::Body<'_>, hir_info: &ExtractedHirInfo, basic_coverage_blocks: &CoverageGraph, -) -> Vec<CoverageSpan> { - let &ExtractedHirInfo { is_async_fn, fn_sig_span, body_span, .. } = hir_info; - - let mut initial_spans = vec![SpanFromMir::for_fn_sig(fn_sig_span)]; - - if is_async_fn { - // An async function desugars into a function that returns a future, - // with the user code wrapped in a closure. Any spans in the desugared - // outer function will be unhelpful, so just keep the signature span - // and ignore all of the spans in the MIR body. - } else { - for (bcb, bcb_data) in basic_coverage_blocks.iter_enumerated() { - initial_spans.extend(bcb_to_initial_coverage_spans(mir_body, body_span, bcb, bcb_data)); - } +) -> Vec<SpanFromMir> { + let &ExtractedHirInfo { body_span, .. } = hir_info; - // If no spans were extracted from the body, discard the signature span. - // FIXME: This preserves existing behavior; consider getting rid of it. - if initial_spans.len() == 1 { - initial_spans.clear(); - } + let mut initial_spans = vec![]; + + for (bcb, bcb_data) in basic_coverage_blocks.iter_enumerated() { + initial_spans.extend(bcb_to_initial_coverage_spans(mir_body, body_span, bcb, bcb_data)); + } + + // Only add the signature span if we found at least one span in the body. + if !initial_spans.is_empty() { + // If there is no usable signature span, add a fake one (before refinement) + // to avoid an ugly gap between the body start and the first real span. + // FIXME: Find a more principled way to solve this problem. + let fn_sig_span = hir_info.fn_sig_span_extended.unwrap_or_else(|| body_span.shrink_to_lo()); + initial_spans.push(SpanFromMir::for_fn_sig(fn_sig_span)); } initial_spans.sort_by(|a, b| basic_coverage_blocks.cmp_in_dominator_order(a.bcb, b.bcb)); @@ -51,15 +52,20 @@ pub(super) fn mir_to_initial_sorted_coverage_spans( // - Span A extends further left, or // - Both have the same start and span A extends further right .then_with(|| Ord::cmp(&a.span.hi(), &b.span.hi()).reverse()) - // If both spans are equal, sort the BCBs in dominator order, - // so that dominating BCBs come before other BCBs they dominate. 
- .then_with(|| basic_coverage_blocks.cmp_in_dominator_order(a.bcb, b.bcb)) - // If two spans are otherwise identical, put closure spans first, - // as this seems to be what the refinement step expects. + // If two spans have the same lo & hi, put closure spans first, + // as they take precedence over non-closure spans. .then_with(|| Ord::cmp(&a.is_closure, &b.is_closure).reverse()) + // After deduplication, we want to keep only the most-dominated BCB. + .then_with(|| basic_coverage_blocks.cmp_in_dominator_order(a.bcb, b.bcb).reverse()) }); - initial_spans.into_iter().map(SpanFromMir::into_coverage_span).collect::<Vec<_>>() + // Among covspans with the same span, keep only one. Closure spans take + // precedence, otherwise keep the one with the most-dominated BCB. + // (Ideally we should try to preserve _all_ non-dominating BCBs, but that + // requires a lot more complexity in the span refiner, for little benefit.) + initial_spans.dedup_by(|b, a| a.span.source_equal(b.span)); + + initial_spans } /// Macros that expand into branches (e.g. `assert!`, `trace!`) tend to generate @@ -117,10 +123,10 @@ fn split_visible_macro_spans(initial_spans: &mut Vec<SpanFromMir>) { initial_spans.extend(extra_spans); } -// Generate a set of `CoverageSpan`s from the filtered set of `Statement`s and `Terminator`s of -// the `BasicBlock`(s) in the given `BasicCoverageBlockData`. One `CoverageSpan` is generated +// Generate a set of coverage spans from the filtered set of `Statement`s and `Terminator`s of +// the `BasicBlock`(s) in the given `BasicCoverageBlockData`. One coverage span is generated // for each `Statement` and `Terminator`. (Note that subsequent stages of coverage analysis will -// merge some `CoverageSpan`s, at which point a `CoverageSpan` may represent multiple +// merge some coverage spans, at which point a coverage span may represent multiple // `Statement`s and/or `Terminator`s.) fn bcb_to_initial_coverage_spans<'a, 'tcx>( mir_body: &'a mir::Body<'tcx>, @@ -131,18 +137,23 @@ fn bcb_to_initial_coverage_spans<'a, 'tcx>( bcb_data.basic_blocks.iter().flat_map(move |&bb| { let data = &mir_body[bb]; + let unexpand = move |expn_span| { + unexpand_into_body_span_with_visible_macro(expn_span, body_span) + // Discard any spans that fill the entire body, because they tend + // to represent compiler-inserted code, e.g. implicitly returning `()`. 
+ .filter(|(span, _)| !span.source_equal(body_span)) + }; + let statement_spans = data.statements.iter().filter_map(move |statement| { let expn_span = filtered_statement_span(statement)?; - let (span, visible_macro) = - unexpand_into_body_span_with_visible_macro(expn_span, body_span)?; + let (span, visible_macro) = unexpand(expn_span)?; - Some(SpanFromMir::new(span, visible_macro, bcb, is_closure_or_coroutine(statement))) + Some(SpanFromMir::new(span, visible_macro, bcb, is_closure_like(statement))) }); let terminator_span = Some(data.terminator()).into_iter().filter_map(move |terminator| { let expn_span = filtered_terminator_span(terminator)?; - let (span, visible_macro) = - unexpand_into_body_span_with_visible_macro(expn_span, body_span)?; + let (span, visible_macro) = unexpand(expn_span)?; Some(SpanFromMir::new(span, visible_macro, bcb, false)) }); @@ -151,10 +162,12 @@ fn bcb_to_initial_coverage_spans<'a, 'tcx>( }) } -fn is_closure_or_coroutine(statement: &Statement<'_>) -> bool { +fn is_closure_like(statement: &Statement<'_>) -> bool { match statement.kind { StatementKind::Assign(box (_, Rvalue::Aggregate(box ref agg_kind, _))) => match agg_kind { - AggregateKind::Closure(_, _) | AggregateKind::Coroutine(_, _) => true, + AggregateKind::Closure(_, _) + | AggregateKind::Coroutine(_, _) + | AggregateKind::CoroutineClosure(..) => true, _ => false, }, _ => false, @@ -312,7 +325,7 @@ fn unexpand_into_body_span_with_prev( } #[derive(Debug)] -struct SpanFromMir { +pub(super) struct SpanFromMir { /// A span that has been extracted from MIR and then "un-expanded" back to /// within the current function's `body_span`. After various intermediate /// processing steps, this span is emitted as part of the final coverage @@ -320,10 +333,10 @@ struct SpanFromMir { /// /// With the exception of `fn_sig_span`, this should always be contained /// within `body_span`. - span: Span, + pub(super) span: Span, visible_macro: Option<Symbol>, - bcb: BasicCoverageBlock, - is_closure: bool, + pub(super) bcb: BasicCoverageBlock, + pub(super) is_closure: bool, } impl SpanFromMir { @@ -339,9 +352,4 @@ impl SpanFromMir { ) -> Self { Self { span, visible_macro, bcb, is_closure } } - - fn into_coverage_span(self) -> CoverageSpan { - let Self { span, visible_macro: _, bcb, is_closure } = self; - CoverageSpan::new(span, bcb, is_closure) - } } diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs index ad12bce9b02..9dc7a50eca9 100644 --- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -2,7 +2,9 @@ //! //! Currently, this pass only propagates scalar values. -use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable}; +use rustc_const_eval::interpret::{ + ImmTy, Immediate, InterpCx, OpTy, PlaceTy, PointerArithmetic, Projectable, +}; use rustc_data_structures::fx::FxHashMap; use rustc_hir::def::DefKind; use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar}; @@ -19,7 +21,32 @@ use rustc_span::def_id::DefId; use rustc_span::DUMMY_SP; use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT}; -use crate::const_prop::throw_machine_stop_str; +/// Macro for machine-specific `InterpError` without allocation. +/// (These will never be shown to the user, but they help diagnose ICEs.) +pub(crate) macro throw_machine_stop_str($($tt:tt)*) {{ + // We make a new local type for it. 
The type itself does not carry any information, + // but its vtable (for the `MachineStopType` trait) does. + #[derive(Debug)] + struct Zst; + // Printing this type shows the desired string. + impl std::fmt::Display for Zst { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, $($tt)*) + } + } + + impl rustc_middle::mir::interpret::MachineStopType for Zst { + fn diagnostic_message(&self) -> rustc_errors::DiagnosticMessage { + self.to_string().into() + } + + fn add_args( + self: Box<Self>, + _: &mut dyn FnMut(rustc_errors::DiagnosticArgName, rustc_errors::DiagnosticArgValue), + ) {} + } + throw_machine_stop!(Zst) +}} // These constants are somewhat random guesses and have not been optimized. // If `tcx.sess.mir_opt_level() >= 4`, we ignore the limits (this can become very expensive). @@ -696,6 +723,7 @@ fn try_write_constant<'tcx>( | ty::Bound(..) | ty::Placeholder(..) | ty::Closure(..) + | ty::CoroutineClosure(..) | ty::Coroutine(..) | ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"), @@ -935,12 +963,50 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm } fn binary_ptr_op( - _ecx: &InterpCx<'mir, 'tcx, Self>, - _bin_op: BinOp, - _left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, - _right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, + ecx: &InterpCx<'mir, 'tcx, Self>, + bin_op: BinOp, + left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, + right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, ) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> { - throw_machine_stop_str!("can't do pointer arithmetic"); + use rustc_middle::mir::BinOp::*; + Ok(match bin_op { + Eq | Ne | Lt | Le | Gt | Ge => { + // Types can differ, e.g. fn ptrs with different `for`. + assert_eq!(left.layout.abi, right.layout.abi); + let size = ecx.pointer_size(); + // Just compare the bits. ScalarPairs are compared lexicographically. + // We thus always compare pairs and simply fill scalars up with 0. + // If the pointer has provenance, `to_bits` will return `Err` and we bail out. + let left = match **left { + Immediate::Scalar(l) => (l.to_bits(size)?, 0), + Immediate::ScalarPair(l1, l2) => (l1.to_bits(size)?, l2.to_bits(size)?), + Immediate::Uninit => panic!("we should never see uninit data here"), + }; + let right = match **right { + Immediate::Scalar(r) => (r.to_bits(size)?, 0), + Immediate::ScalarPair(r1, r2) => (r1.to_bits(size)?, r2.to_bits(size)?), + Immediate::Uninit => panic!("we should never see uninit data here"), + }; + let res = match bin_op { + Eq => left == right, + Ne => left != right, + Lt => left < right, + Le => left <= right, + Gt => left > right, + Ge => left >= right, + _ => bug!(), + }; + (ImmTy::from_bool(res, *ecx.tcx), false) + } + + // Some more operations are possible with atomics. + // The return value always has the provenance of the *left* operand. 
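// (Editor's gloss on the arm below: the comparison arm above succeeds
//  only when `to_bits` does, i.e. when neither operand carries
//  provenance, and its result is a plain bool that needs no provenance.
//  Add/Sub/Bit* would have to yield a pointer value carrying the left
//  operand's provenance, which this dummy machine cannot represent, so
//  it stops interpretation rather than guess.)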
+ Add | Sub | BitOr | BitAnd | BitXor => { + throw_machine_stop_str!("pointer arithmetic is not handled") + } + + _ => span_bug!(ecx.cur_span(), "Invalid operator on pointers: {:?}", bin_op), + }) } fn expose_ptr( diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs index a6750911394..ca63f5550ae 100644 --- a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs +++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs @@ -205,8 +205,8 @@ pub fn deduced_param_attrs<'tcx>( |(arg_index, local_decl)| DeducedParamAttrs { read_only: !deduce_read_only.mutable_args.contains(arg_index) // We must normalize here to reveal opaques and normalize - // their substs, otherwise we'll see exponential blow-up in - // compile times: #113372 + // their generic parameters, otherwise we'll see exponential + // blow-up in compile times: #113372 && tcx .normalize_erasing_regions(param_env, local_decl.ty) .is_freeze(tcx, param_env), diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs index 0ac4ab61d40..2c8201b1903 100644 --- a/compiler/rustc_mir_transform/src/dest_prop.rs +++ b/compiler/rustc_mir_transform/src/dest_prop.rs @@ -398,7 +398,8 @@ impl<'alloc> Candidates<'alloc> { let candidates = entry.get_mut(); Self::vec_filter_candidates(p, candidates, f, at); if candidates.len() == 0 { - entry.remove(); + // FIXME(#120456) - is `swap_remove` correct? + entry.swap_remove(); } } diff --git a/compiler/rustc_mir_transform/src/elaborate_drops.rs b/compiler/rustc_mir_transform/src/elaborate_drops.rs index 8575f552f0a..03d952abad1 100644 --- a/compiler/rustc_mir_transform/src/elaborate_drops.rs +++ b/compiler/rustc_mir_transform/src/elaborate_drops.rs @@ -380,7 +380,7 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { LookupResult::Parent(None) => {} LookupResult::Parent(Some(_)) => { if !replace { - self.tcx.dcx().span_delayed_bug( + self.tcx.dcx().span_bug( terminator.source_info.span, format!("drop of untracked value {bb:?}"), ); diff --git a/compiler/rustc_mir_transform/src/errors.rs b/compiler/rustc_mir_transform/src/errors.rs index 2ee660ddc9b..af2d7d4946f 100644 --- a/compiler/rustc_mir_transform/src/errors.rs +++ b/compiler/rustc_mir_transform/src/errors.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use rustc_errors::{ - Applicability, DecorateLint, DiagCtxt, DiagnosticArgValue, DiagnosticBuilder, + codes::*, Applicability, DecorateLint, DiagCtxt, DiagnosticArgValue, DiagnosticBuilder, DiagnosticMessage, EmissionGuarantee, IntoDiagnostic, Level, }; use rustc_macros::{Diagnostic, LintDiagnostic, Subdiagnostic}; @@ -33,7 +33,7 @@ pub(crate) enum ConstMutate { } #[derive(Diagnostic)] -#[diag(mir_transform_unaligned_packed_ref, code = "E0793")] +#[diag(mir_transform_unaligned_packed_ref, code = E0793)] #[note] #[note(mir_transform_note_ub)] #[help] @@ -66,7 +66,7 @@ impl<'a, G: EmissionGuarantee> IntoDiagnostic<'a, G> for RequiresUnsafe { #[track_caller] fn into_diagnostic(self, dcx: &'a DiagCtxt, level: Level) -> DiagnosticBuilder<'a, G> { let mut diag = DiagnosticBuilder::new(dcx, level, fluent::mir_transform_requires_unsafe); - diag.code("E0133".to_string()); + diag.code(E0133); diag.span(self.span); diag.span_label(self.span, self.details.label()); let desc = dcx.eagerly_translate_to_string(self.details.label(), [].into_iter()); @@ -87,6 +87,9 @@ pub(crate) struct RequiresUnsafeDetail { } impl RequiresUnsafeDetail { + // FIXME: make this translatable + 
#[allow(rustc::diagnostic_outside_of_impl)] + #[allow(rustc::untranslatable_diagnostic)] fn add_subdiagnostics<G: EmissionGuarantee>(&self, diag: &mut DiagnosticBuilder<'_, G>) { use UnsafetyViolationDetails::*; match self.violation { @@ -125,7 +128,7 @@ impl RequiresUnsafeDetail { diag.arg( "missing_target_features", DiagnosticArgValue::StrListSepByAnd( - missing.iter().map(|feature| Cow::from(feature.as_str())).collect(), + missing.iter().map(|feature| Cow::from(feature.to_string())).collect(), ), ); diag.arg("missing_target_features_count", missing.len()); @@ -136,7 +139,7 @@ impl RequiresUnsafeDetail { DiagnosticArgValue::StrListSepByAnd( build_enabled .iter() - .map(|feature| Cow::from(feature.as_str())) + .map(|feature| Cow::from(feature.to_string())) .collect(), ), ); @@ -201,45 +204,39 @@ impl<'a> DecorateLint<'a, ()> for UnsafeOpInUnsafeFn { } } -pub(crate) enum AssertLint<P> { - ArithmeticOverflow(Span, AssertKind<P>), - UnconditionalPanic(Span, AssertKind<P>), +pub(crate) struct AssertLint<P> { + pub span: Span, + pub assert_kind: AssertKind<P>, + pub lint_kind: AssertLintKind, +} + +pub(crate) enum AssertLintKind { + ArithmeticOverflow, + UnconditionalPanic, } impl<'a, P: std::fmt::Debug> DecorateLint<'a, ()> for AssertLint<P> { fn decorate_lint<'b>(self, diag: &'b mut DiagnosticBuilder<'a, ()>) { - let span = self.span(); - let assert_kind = self.panic(); - let message = assert_kind.diagnostic_message(); - assert_kind.add_args(&mut |name, value| { + let message = self.assert_kind.diagnostic_message(); + self.assert_kind.add_args(&mut |name, value| { diag.arg(name, value); }); - diag.span_label(span, message); + diag.span_label(self.span, message); } fn msg(&self) -> DiagnosticMessage { - match self { - AssertLint::ArithmeticOverflow(..) => fluent::mir_transform_arithmetic_overflow, - AssertLint::UnconditionalPanic(..) => fluent::mir_transform_operation_will_panic, + match self.lint_kind { + AssertLintKind::ArithmeticOverflow => fluent::mir_transform_arithmetic_overflow, + AssertLintKind::UnconditionalPanic => fluent::mir_transform_operation_will_panic, } } } -impl<P> AssertLint<P> { +impl AssertLintKind { pub fn lint(&self) -> &'static Lint { match self { - AssertLint::ArithmeticOverflow(..) => lint::builtin::ARITHMETIC_OVERFLOW, - AssertLint::UnconditionalPanic(..) 
=> lint::builtin::UNCONDITIONAL_PANIC, - } - } - pub fn span(&self) -> Span { - match self { - AssertLint::ArithmeticOverflow(sp, _) | AssertLint::UnconditionalPanic(sp, _) => *sp, - } - } - pub fn panic(self) -> AssertKind<P> { - match self { - AssertLint::ArithmeticOverflow(_, p) | AssertLint::UnconditionalPanic(_, p) => p, + AssertLintKind::ArithmeticOverflow => lint::builtin::ARITHMETIC_OVERFLOW, + AssertLintKind::UnconditionalPanic => lint::builtin::UNCONDITIONAL_PANIC, } } } @@ -276,7 +273,7 @@ impl<'a> DecorateLint<'a, ()> for MustNotSupend<'_, '_> { fn decorate_lint<'b>(self, diag: &'b mut rustc_errors::DiagnosticBuilder<'a, ()>) { diag.span_label(self.yield_sp, fluent::_subdiag::label); if let Some(reason) = self.reason { - diag.subdiagnostic(reason); + diag.subdiagnostic(diag.dcx, reason); } diag.span_help(self.src_sp, fluent::_subdiag::help); diag.arg("pre", self.pre); diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs index 26fcfad8287..663abbece85 100644 --- a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs +++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs @@ -26,7 +26,6 @@ fn abi_can_unwind(abi: Abi) -> bool { PtxKernel | Msp430Interrupt | X86Interrupt - | AmdGpuKernel | EfiApi | AvrInterrupt | AvrNonBlockingInterrupt @@ -58,7 +57,9 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { let body_abi = match body_ty.kind() { ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), ty::Closure(..) => Abi::RustCall, + ty::CoroutineClosure(..) => Abi::RustCall, ty::Coroutine(..) => Abi::Rust, + ty::Error(_) => return false, _ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty), }; let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi); @@ -112,7 +113,7 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { let span = terminator.source_info.span; let foreign = fn_def_id.is_some(); - tcx.emit_spanned_lint( + tcx.emit_node_span_lint( FFI_UNWIND_CALLS, lint_root, span, diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs index 61d99f1f018..e935dc7f5eb 100644 --- a/compiler/rustc_mir_transform/src/function_item_references.rs +++ b/compiler/rustc_mir_transform/src/function_item_references.rs @@ -63,7 +63,7 @@ impl<'tcx> Visitor<'tcx> for FunctionItemRefChecker<'_, 'tcx> { impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { /// Emits a lint for function reference arguments bound by `fmt::Pointer` in calls to the - /// function defined by `def_id` with the substitutions `args_ref`. + /// function defined by `def_id` with the generic parameters `args_ref`. 
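+    ///
+    /// For example (a sketch, not the exact diagnostic): given `fn foo() {}`,
+    /// `println!("{:p}", &foo)` formats the address of a zero-sized function
+    /// item rather than the function's code address, so the lint suggests
+    /// casting to a function pointer first, e.g. `foo as fn()`.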
fn check_bound_args( &self, def_id: DefId, @@ -83,11 +83,11 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { for inner_ty in arg_def.walk().filter_map(|arg| arg.as_type()) { // If the inner type matches the type bound by `Pointer` if inner_ty == bound_ty { - // Do a substitution using the parameters from the callsite - let subst_ty = + // Do an instantiation using the parameters from the callsite + let instantiated_ty = EarlyBinder::bind(inner_ty).instantiate(self.tcx, args_ref); if let Some((fn_id, fn_args)) = - FunctionItemRefChecker::is_fn_ref(subst_ty) + FunctionItemRefChecker::is_fn_ref(instantiated_ty) { let mut span = self.nth_arg_span(args, arg_num); if span.from_expansion() { @@ -185,7 +185,7 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { ret, ); - self.tcx.emit_spanned_lint( + self.tcx.emit_node_span_lint( FUNCTION_ITEM_REFERENCES, lint_root, span, diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs index 390ec3e1a36..a080e2423d4 100644 --- a/compiler/rustc_mir_transform/src/gvn.rs +++ b/compiler/rustc_mir_transform/src/gvn.rs @@ -93,7 +93,6 @@ use rustc_index::IndexVec; use rustc_middle::mir::interpret::GlobalAlloc; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; -use rustc_middle::ty::adjustment::PointerCoercion; use rustc_middle::ty::layout::LayoutOf; use rustc_middle::ty::{self, Ty, TyCtxt, TypeAndMut}; use rustc_span::def_id::DefId; @@ -400,7 +399,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { }; for (field_index, op) in fields.into_iter().enumerate() { let field_dest = self.ecx.project_field(&variant_dest, field_index).ok()?; - self.ecx.copy_op(op, &field_dest, /*allow_transmute*/ false).ok()?; + self.ecx.copy_op(op, &field_dest).ok()?; } self.ecx.write_discriminant(variant.unwrap_or(FIRST_VARIANT), &dest).ok()?; self.ecx @@ -489,6 +488,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { NullOp::OffsetOf(fields) => { layout.offset_of_subfield(&self.ecx, fields.iter()).bytes() } + NullOp::DebugAssertions => return None, }; let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); let imm = ImmTy::try_from_uint(val, usize_layout)?; @@ -551,6 +551,33 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } value.offset(Size::ZERO, to, &self.ecx).ok()? } + CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize) => { + let src = self.evaluated[value].as_ref()?; + let to = self.ecx.layout_of(to).ok()?; + let dest = self.ecx.allocate(to, MemoryKind::Stack).ok()?; + self.ecx.unsize_into(src, to, &dest.clone().into()).ok()?; + self.ecx + .alloc_mark_immutable(dest.ptr().provenance.unwrap().alloc_id()) + .ok()?; + dest.into() + } + CastKind::FnPtrToPtr | CastKind::PtrToPtr => { + let src = self.evaluated[value].as_ref()?; + let src = self.ecx.read_immediate(src).ok()?; + let to = self.ecx.layout_of(to).ok()?; + let ret = self.ecx.ptr_to_ptr(&src, to).ok()?; + ret.into() + } + CastKind::PointerCoercion( + ty::adjustment::PointerCoercion::MutToConstPointer + | ty::adjustment::PointerCoercion::ArrayToPointer + | ty::adjustment::PointerCoercion::UnsafeFnPointer, + ) => { + let src = self.evaluated[value].as_ref()?; + let src = self.ecx.read_immediate(src).ok()?; + let to = self.ecx.layout_of(to).ok()?; + ImmTy::from_immediate(*src, to).into() + } _ => return None, }, }; @@ -777,18 +804,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { // Operations. 
Rvalue::Len(ref mut place) => return self.simplify_len(place, location), - Rvalue::Cast(kind, ref mut value, to) => { - let from = value.ty(self.local_decls, self.tcx); - let value = self.simplify_operand(value, location)?; - if let CastKind::PointerCoercion( - PointerCoercion::ReifyFnPointer | PointerCoercion::ClosureFnPointer(_), - ) = kind - { - // Each reification of a generic fn may get a different pointer. - // Do not try to merge them. - return self.new_opaque(); - } - Value::Cast { kind, value, from, to } + Rvalue::Cast(ref mut kind, ref mut value, to) => { + return self.simplify_cast(kind, value, to, location); } Rvalue::BinaryOp(op, box (ref mut lhs, ref mut rhs)) => { let ty = lhs.ty(self.local_decls, self.tcx); @@ -840,10 +857,10 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { fn simplify_discriminant(&mut self, place: VnIndex) -> Option<VnIndex> { if let Value::Aggregate(enum_ty, variant, _) = *self.get(place) - && let AggregateTy::Def(enum_did, enum_substs) = enum_ty + && let AggregateTy::Def(enum_did, enum_args) = enum_ty && let DefKind::Enum = self.tcx.def_kind(enum_did) { - let enum_ty = self.tcx.type_of(enum_did).instantiate(self.tcx, enum_substs); + let enum_ty = self.tcx.type_of(enum_did).instantiate(self.tcx, enum_args); let discr = self.ecx.discriminant_for_variant(enum_ty, variant).ok()?; return Some(self.insert_scalar(discr.to_scalar(), discr.layout.ty)); } @@ -861,9 +878,10 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { let tcx = self.tcx; if fields.is_empty() { let is_zst = match *kind { - AggregateKind::Array(..) | AggregateKind::Tuple | AggregateKind::Closure(..) => { - true - } + AggregateKind::Array(..) + | AggregateKind::Tuple + | AggregateKind::Closure(..) + | AggregateKind::CoroutineClosure(..) => true, // Only enums can be non-ZST. AggregateKind::Adt(did, ..) => tcx.def_kind(did) != DefKind::Enum, // Coroutines are never ZST, as they at least contain the implicit states. @@ -885,11 +903,11 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { assert!(!fields.is_empty()); (AggregateTy::Tuple, FIRST_VARIANT) } - AggregateKind::Closure(did, substs) | AggregateKind::Coroutine(did, substs) => { - (AggregateTy::Def(did, substs), FIRST_VARIANT) - } - AggregateKind::Adt(did, variant_index, substs, _, None) => { - (AggregateTy::Def(did, substs), variant_index) + AggregateKind::Closure(did, args) + | AggregateKind::CoroutineClosure(did, args) + | AggregateKind::Coroutine(did, args) => (AggregateTy::Def(did, args), FIRST_VARIANT), + AggregateKind::Adt(did, variant_index, args, _, None) => { + (AggregateTy::Def(did, args), variant_index) } // Do not track unions. AggregateKind::Adt(_, _, _, _, Some(_)) => return None, @@ -1031,6 +1049,50 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } } + fn simplify_cast( + &mut self, + kind: &mut CastKind, + operand: &mut Operand<'tcx>, + to: Ty<'tcx>, + location: Location, + ) -> Option<VnIndex> { + use rustc_middle::ty::adjustment::PointerCoercion::*; + use CastKind::*; + + let mut from = operand.ty(self.local_decls, self.tcx); + let mut value = self.simplify_operand(operand, location)?; + if from == to { + return Some(value); + } + + if let CastKind::PointerCoercion(ReifyFnPointer | ClosureFnPointer(_)) = kind { + // Each reification of a generic fn may get a different pointer. + // Do not try to merge them. 
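+            // (E.g. a generic `fn id<T>(x: T) -> T` can be codegenned
+            // separately in several codegen units, so two `id::<u8> as
+            // fn(u8) -> u8` reifications are not guaranteed to compare equal.)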
+ return self.new_opaque(); + } + + if let PtrToPtr | PointerCoercion(MutToConstPointer) = kind + && let Value::Cast { kind: inner_kind, value: inner_value, from: inner_from, to: _ } = + *self.get(value) + && let PtrToPtr | PointerCoercion(MutToConstPointer) = inner_kind + { + from = inner_from; + value = inner_value; + *kind = PtrToPtr; + if inner_from == to { + return Some(inner_value); + } + if let Some(const_) = self.try_as_constant(value) { + *operand = Operand::Constant(Box::new(const_)); + } else if let Some(local) = self.try_as_local(value, location) { + *operand = Operand::Copy(local.into()); + self.reused_locals.insert(local); + } + } + + Some(self.insert(Value::Cast { kind: *kind, value, from, to })) + } + fn simplify_len(&mut self, place: &mut Place<'tcx>, location: Location) -> Option<VnIndex> { // Trivial case: we are fetching a statically known length. let place_ty = place.ty(self.local_decls, self.tcx).ty; @@ -1119,8 +1181,7 @@ fn op_to_prop_const<'tcx>( } // Everything failed: create a new allocation to hold the data. - let alloc_id = - ecx.intern_with_temp_alloc(op.layout, |ecx, dest| ecx.copy_op(op, dest, false)).ok()?; + let alloc_id = ecx.intern_with_temp_alloc(op.layout, |ecx, dest| ecx.copy_op(op, dest)).ok()?; let value = ConstValue::Indirect { alloc_id, offset: Size::ZERO }; // Check that we do not leak a pointer. @@ -1228,8 +1289,8 @@ impl<'tcx> MutVisitor<'tcx> for StorageRemover<'tcx> { fn visit_operand(&mut self, operand: &mut Operand<'tcx>, _: Location) { if let Operand::Move(place) = *operand - && let Some(local) = place.as_local() - && self.reused_locals.contains(local) + && !place.is_indirect_first_projection() + && self.reused_locals.contains(place.local) { *operand = Operand::Copy(place); } diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 67668a216de..956d855ab81 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -2,6 +2,7 @@ use crate::deref_separator::deref_finder; use rustc_attr::InlineAttr; use rustc_const_eval::transform::validate::validate_types; +use rustc_hir::def::DefKind; use rustc_hir::def_id::DefId; use rustc_index::bit_set::BitSet; use rustc_index::Idx; @@ -317,6 +318,8 @@ impl<'tcx> Inliner<'tcx> { | InstanceDef::ReifyShim(_) | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) @@ -382,6 +385,17 @@ impl<'tcx> Inliner<'tcx> { } let fn_sig = self.tcx.fn_sig(def_id).instantiate(self.tcx, args); + + // Additionally, check that the body that we're inlining actually agrees + // with the ABI of the trait that the item comes from. 
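+        // (A hypothetical instance of the mismatch being guarded against:
+        // the trait-side signature is `extern "rust-call"` while the
+        // resolved item was written as plain `extern "Rust"`; inlining such
+        // a body would silently paper over the calling-convention change.)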
+ if let InstanceDef::Item(instance_def_id) = callee.def + && self.tcx.def_kind(instance_def_id) == DefKind::AssocFn + && let instance_fn_sig = self.tcx.fn_sig(instance_def_id).skip_binder() + && instance_fn_sig.abi() != fn_sig.abi() + { + return None; + } + let source_info = SourceInfo { span: fn_span, ..terminator.source_info }; return Some(CallSite { callee, fn_sig, block: bb, source_info }); @@ -1025,21 +1039,16 @@ fn try_instance_mir<'tcx>( tcx: TyCtxt<'tcx>, instance: InstanceDef<'tcx>, ) -> Result<&'tcx Body<'tcx>, &'static str> { - match instance { - ty::InstanceDef::DropGlue(_, Some(ty)) => match ty.kind() { - ty::Adt(def, args) => { - let fields = def.all_fields(); - for field in fields { - let field_ty = field.ty(tcx, args); - if field_ty.has_param() && field_ty.has_projections() { - return Err("cannot build drop shim for polymorphic type"); - } - } - - Ok(tcx.instance_mir(instance)) + if let ty::InstanceDef::DropGlue(_, Some(ty)) = instance + && let ty::Adt(def, args) = ty.kind() + { + let fields = def.all_fields(); + for field in fields { + let field_ty = field.ty(tcx, args); + if field_ty.has_param() && field_ty.has_projections() { + return Err("cannot build drop shim for polymorphic type"); } - _ => Ok(tcx.instance_mir(instance)), - }, - _ => Ok(tcx.instance_mir(instance)), + } } + Ok(tcx.instance_mir(instance)) } diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index d30e0bad813..f2b6dcac586 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -80,20 +80,22 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( } // These have no own callable MIR. InstanceDef::Intrinsic(_) | InstanceDef::Virtual(..) => continue, - // These have MIR and if that MIR is inlined, substituted and then inlining is run + // These have MIR and if that MIR is inlined, instantiated and then inlining is run // again, a function item can end up getting inlined. Thus we'll be able to cause // a cycle that way InstanceDef::VTableShim(_) | InstanceDef::ReifyShim(_) | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::ThreadLocalShim { .. } | InstanceDef::CloneShim(..) => {} // This shim does not call any other functions, thus there can be no recursion. InstanceDef::FnPtrAddrShim(..) => continue, InstanceDef::DropGlue(..) => { - // FIXME: A not fully substituted drop shim can cause ICEs if one attempts to + // FIXME: A not fully instantiated drop shim can cause ICEs if one attempts to // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this // needs some more analysis. 
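+                    // Conservatively skip drop glue that still has generic
+                    // parameters (see the FIXME above); fully monomorphic
+                    // glue is fine to recurse into.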
if callee.has_param() { diff --git a/compiler/rustc_mir_transform/src/instsimplify.rs b/compiler/rustc_mir_transform/src/instsimplify.rs index a28db0defc9..73102a5f026 100644 --- a/compiler/rustc_mir_transform/src/instsimplify.rs +++ b/compiler/rustc_mir_transform/src/instsimplify.rs @@ -2,10 +2,12 @@ use crate::simplify::simplify_duplicate_switch_targets; use rustc_middle::mir::*; +use rustc_middle::ty::layout; use rustc_middle::ty::layout::ValidityRequirement; use rustc_middle::ty::{self, GenericArgsRef, ParamEnv, Ty, TyCtxt}; use rustc_span::symbol::Symbol; use rustc_target::abi::FieldIdx; +use rustc_target::spec::abi::Abi; pub struct InstSimplify; @@ -27,17 +29,15 @@ impl<'tcx> MirPass<'tcx> for InstSimplify { ctx.simplify_bool_cmp(&statement.source_info, rvalue); ctx.simplify_ref_deref(&statement.source_info, rvalue); ctx.simplify_len(&statement.source_info, rvalue); - ctx.simplify_cast(&statement.source_info, rvalue); + ctx.simplify_cast(rvalue); } _ => {} } } ctx.simplify_primitive_clone(block.terminator.as_mut().unwrap(), &mut block.statements); - ctx.simplify_intrinsic_assert( - block.terminator.as_mut().unwrap(), - &mut block.statements, - ); + ctx.simplify_intrinsic_assert(block.terminator.as_mut().unwrap()); + ctx.simplify_nounwind_call(block.terminator.as_mut().unwrap()); simplify_duplicate_switch_targets(block.terminator.as_mut().unwrap()); } } @@ -140,7 +140,7 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { } } - fn simplify_cast(&self, _source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + fn simplify_cast(&self, rvalue: &mut Rvalue<'tcx>) { if let Rvalue::Cast(kind, operand, cast_ty) = rvalue { let operand_ty = operand.ty(self.local_decls, self.tcx); if operand_ty == *cast_ty { @@ -208,7 +208,7 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { // Only bother looking more if it's easy to know what we're calling let Some((fn_def_id, fn_args)) = func.const_fn_def() else { return }; - // Clone needs one subst, so we can cheaply rule out other stuff + // Clone needs one arg, so we can cheaply rule out other stuff if fn_args.len() != 1 { return; } @@ -252,11 +252,29 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { terminator.kind = TerminatorKind::Goto { target: destination_block }; } - fn simplify_intrinsic_assert( - &self, - terminator: &mut Terminator<'tcx>, - _statements: &mut Vec<Statement<'tcx>>, - ) { + fn simplify_nounwind_call(&self, terminator: &mut Terminator<'tcx>) { + let TerminatorKind::Call { func, unwind, .. } = &mut terminator.kind else { + return; + }; + + let Some((def_id, _)) = func.const_fn_def() else { + return; + }; + + let body_ty = self.tcx.type_of(def_id).skip_binder(); + let body_abi = match body_ty.kind() { + ty::FnDef(..) => body_ty.fn_sig(self.tcx).abi(), + ty::Closure(..) => Abi::RustCall, + ty::Coroutine(..) => Abi::Rust, + _ => bug!("unexpected body ty: {:?}", body_ty), + }; + + if !layout::fn_can_unwind(self.tcx, Some(def_id), body_abi) { + *unwind = UnwindAction::Unreachable; + } + } + + fn simplify_intrinsic_assert(&self, terminator: &mut Terminator<'tcx>) { let TerminatorKind::Call { func, target, .. 
} = &mut terminator.kind else { return; }; @@ -271,9 +289,9 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { if args.is_empty() { return; } - let ty = args.type_at(0); - let known_is_valid = intrinsic_assert_panics(self.tcx, self.param_env, ty, intrinsic_name); + let known_is_valid = + intrinsic_assert_panics(self.tcx, self.param_env, args[0], intrinsic_name); match known_is_valid { // We don't know the layout or it's not validity assertion at all, don't touch it None => {} @@ -292,10 +310,11 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { fn intrinsic_assert_panics<'tcx>( tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>, - ty: Ty<'tcx>, + arg: ty::GenericArg<'tcx>, intrinsic_name: Symbol, ) -> Option<bool> { let requirement = ValidityRequirement::from_intrinsic(intrinsic_name)?; + let ty = arg.expect_ty(); Some(!tcx.check_validity_requirement((requirement, param_env.and(ty))).ok()?) } @@ -304,9 +323,8 @@ fn resolve_rust_intrinsic<'tcx>( func_ty: Ty<'tcx>, ) -> Option<(Symbol, GenericArgsRef<'tcx>)> { if let ty::FnDef(def_id, args) = *func_ty.kind() { - if tcx.is_intrinsic(def_id) { - return Some((tcx.item_name(def_id), args)); - } + let name = tcx.intrinsic(def_id)?; + return Some((name, args)); } None } diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs index dcab124505e..ad8f21ffbda 100644 --- a/compiler/rustc_mir_transform/src/jump_threading.rs +++ b/compiler/rustc_mir_transform/src/jump_threading.rs @@ -36,16 +36,21 @@ //! cost by `MAX_COST`. use rustc_arena::DroplessArena; +use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable}; use rustc_data_structures::fx::FxHashSet; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; +use rustc_middle::mir::interpret::Scalar; use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; -use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt}; +use rustc_middle::ty::layout::LayoutOf; +use rustc_middle::ty::{self, ScalarInt, TyCtxt}; use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem}; +use rustc_span::DUMMY_SP; use rustc_target::abi::{TagEncoding, Variants}; use crate::cost_checker::CostChecker; +use crate::dataflow_const_prop::DummyMachine; pub struct JumpThreading; @@ -55,7 +60,7 @@ const MAX_PLACES: usize = 100; impl<'tcx> MirPass<'tcx> for JumpThreading { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 4 + sess.mir_opt_level() >= 2 } #[instrument(skip_all level = "debug")] @@ -63,6 +68,12 @@ impl<'tcx> MirPass<'tcx> for JumpThreading { let def_id = body.source.def_id(); debug!(?def_id); + // Optimizing coroutines creates query cycles. 
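+        // (Computing a coroutine's layout requires its MIR, and this pass
+        // asks layout questions while transforming that very MIR, so running
+        // it on coroutine bodies would cycle back into itself.)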
+ if tcx.is_coroutine(def_id) { + trace!("Skipped for coroutine {:?}", def_id); + return; + } + let param_env = tcx.param_env_reveal_all_normalized(def_id); let map = Map::new(tcx, body, Some(MAX_PLACES)); let loop_headers = loop_headers(body); @@ -71,6 +82,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading { let mut finder = TOFinder { tcx, param_env, + ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), body, arena: &arena, map: &map, @@ -88,7 +100,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading { debug!(?discr, ?bb); let discr_ty = discr.ty(body, tcx).ty; - let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue }; + let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue }; let Some(discr) = finder.map.find(discr.as_ref()) else { continue }; debug!(?discr); @@ -142,6 +154,7 @@ struct ThreadingOpportunity { struct TOFinder<'tcx, 'a> { tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>, + ecx: InterpCx<'tcx, 'tcx, DummyMachine>, body: &'a Body<'tcx>, map: &'a Map, loop_headers: &'a BitSet<BasicBlock>, @@ -329,11 +342,11 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { } #[instrument(level = "trace", skip(self))] - fn process_operand( + fn process_immediate( &mut self, bb: BasicBlock, lhs: PlaceIndex, - rhs: &Operand<'tcx>, + rhs: ImmTy<'tcx>, state: &mut State<ConditionSet<'a>>, ) -> Option<!> { let register_opportunity = |c: Condition| { @@ -341,13 +354,70 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target }) }; + let conditions = state.try_get_idx(lhs, self.map)?; + if let Immediate::Scalar(Scalar::Int(int)) = *rhs { + conditions.iter_matches(int).for_each(register_opportunity); + } + + None + } + + /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. + #[instrument(level = "trace", skip(self))] + fn process_constant( + &mut self, + bb: BasicBlock, + lhs: PlaceIndex, + constant: OpTy<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) { + self.map.for_each_projection_value( + lhs, + constant, + &mut |elem, op| match elem { + TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(), + TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(), + TrackElem::Discriminant => { + let variant = self.ecx.read_discriminant(op).ok()?; + let discr_value = + self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?; + Some(discr_value.into()) + } + TrackElem::DerefLen => { + let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into(); + let len_usize = op.len(&self.ecx).ok()?; + let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); + Some(ImmTy::from_uint(len_usize, layout).into()) + } + }, + &mut |place, op| { + if let Some(conditions) = state.try_get_idx(place, self.map) + && let Ok(imm) = self.ecx.read_immediate_raw(op) + && let Some(imm) = imm.right() + && let Immediate::Scalar(Scalar::Int(int)) = *imm + { + conditions.iter_matches(int).for_each(|c: Condition| { + self.opportunities + .push(ThreadingOpportunity { chain: vec![bb], target: c.target }) + }) + } + }, + ); + } + + #[instrument(level = "trace", skip(self))] + fn process_operand( + &mut self, + bb: BasicBlock, + lhs: PlaceIndex, + rhs: &Operand<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) -> Option<!> { match rhs { // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. 
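+            // A hypothetical MIR shape this catches:
+            //     _2 = const 3_u8;
+            //     ...
+            //     switchInt(_2) -> [3: bb_taken, otherwise: bb_other];
+            // Observing `_2 = const 3` fulfils the condition `_2 == 3`, so
+            // the chain can be threaded straight to `bb_taken`.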
Operand::Constant(constant) => { - let conditions = state.try_get_idx(lhs, self.map)?; - let constant = - constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?; - conditions.iter_matches(constant).for_each(register_opportunity); + let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?; + self.process_constant(bb, lhs, constant, state); } // Transfer the conditions on the copied rhs. Operand::Move(rhs) | Operand::Copy(rhs) => { @@ -360,6 +430,84 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { } #[instrument(level = "trace", skip(self))] + fn process_assign( + &mut self, + bb: BasicBlock, + lhs_place: &Place<'tcx>, + rhs: &Rvalue<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) -> Option<!> { + let lhs = self.map.find(lhs_place.as_ref())?; + match rhs { + Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?, + // Transfer the conditions on the copy rhs. + Rvalue::CopyForDeref(rhs) => { + self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)? + } + Rvalue::Discriminant(rhs) => { + let rhs = self.map.find_discr(rhs.as_ref())?; + state.insert_place_idx(rhs, lhs, self.map); + } + // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. + Rvalue::Aggregate(box ref kind, ref operands) => { + let agg_ty = lhs_place.ty(self.body, self.tcx).ty; + let lhs = match kind { + // Do not support unions. + AggregateKind::Adt(.., Some(_)) => return None, + AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => { + if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant) + && let Ok(discr_value) = + self.ecx.discriminant_for_variant(agg_ty, *variant_index) + { + self.process_immediate(bb, discr_target, discr_value, state); + } + self.map.apply(lhs, TrackElem::Variant(*variant_index))? + } + _ => lhs, + }; + for (field_index, operand) in operands.iter_enumerated() { + if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) { + self.process_operand(bb, field, operand, state); + } + } + } + // Transfer the conditions on the copy rhs, after inversing polarity. + Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => { + let conditions = state.try_get_idx(lhs, self.map)?; + let place = self.map.find(place.as_ref())?; + let conds = conditions.map(self.arena, Condition::inv); + state.insert_value_idx(place, conds, self.map); + } + // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`. + // Create a condition on `rhs ?= B`. + Rvalue::BinaryOp( + op, + box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value)) + | box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)), + ) => { + let conditions = state.try_get_idx(lhs, self.map)?; + let place = self.map.find(place.as_ref())?; + let equals = match op { + BinOp::Eq => ScalarInt::TRUE, + BinOp::Ne => ScalarInt::FALSE, + _ => return None, + }; + let value = value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?; + let conds = conditions.map(self.arena, |c| Condition { + value, + polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne }, + ..c + }); + state.insert_value_idx(place, conds, self.map); + } + + _ => {} + } + + None + } + + #[instrument(level = "trace", skip(self))] fn process_statement( &mut self, bb: BasicBlock, @@ -374,18 +522,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // Below, `lhs` is the return value of `mutated_statement`, // the place to which `conditions` apply. 
- let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| { - let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?; - let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?; - let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?; - Some(Operand::const_from_scalar( - self.tcx, - discr.ty, - scalar.into(), - rustc_span::DUMMY_SP, - )) - }; - match &stmt.kind { // If we expect `discriminant(place) ?= A`, // we have an opportunity if `variant_index ?= A`. @@ -395,7 +531,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant // of a niche encoding. If we cannot ensure that we write to the discriminant, do // nothing. - let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?; + let enum_layout = self.ecx.layout_of(enum_ty).ok()?; let writes_discriminant = match enum_layout.variants { Variants::Single { index } => { assert_eq!(index, *variant_index); @@ -408,8 +544,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { } => *variant_index != untagged_variant, }; if writes_discriminant { - let discr = discriminant_for_variant(enum_ty, *variant_index)?; - self.process_operand(bb, discr_target, &discr, state)?; + let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?; + self.process_immediate(bb, discr_target, discr, state)?; } } // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`. @@ -420,89 +556,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity); } StatementKind::Assign(box (lhs_place, rhs)) => { - if let Some(lhs) = self.map.find(lhs_place.as_ref()) { - match rhs { - Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?, - // Transfer the conditions on the copy rhs. - Rvalue::CopyForDeref(rhs) => { - self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)? - } - Rvalue::Discriminant(rhs) => { - let rhs = self.map.find_discr(rhs.as_ref())?; - state.insert_place_idx(rhs, lhs, self.map); - } - // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. - Rvalue::Aggregate(box ref kind, ref operands) => { - let agg_ty = lhs_place.ty(self.body, self.tcx).ty; - let lhs = match kind { - // Do not support unions. - AggregateKind::Adt(.., Some(_)) => return None, - AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => { - if let Some(discr_target) = - self.map.apply(lhs, TrackElem::Discriminant) - && let Some(discr_value) = - discriminant_for_variant(agg_ty, *variant_index) - { - self.process_operand(bb, discr_target, &discr_value, state); - } - self.map.apply(lhs, TrackElem::Variant(*variant_index))? - } - _ => lhs, - }; - for (field_index, operand) in operands.iter_enumerated() { - if let Some(field) = - self.map.apply(lhs, TrackElem::Field(field_index)) - { - self.process_operand(bb, field, operand, state); - } - } - } - // Transfer the conditions on the copy rhs, after inversing polarity. - Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => { - let conditions = state.try_get_idx(lhs, self.map)?; - let place = self.map.find(place.as_ref())?; - let conds = conditions.map(self.arena, Condition::inv); - state.insert_value_idx(place, conds, self.map); - } - // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`. - // Create a condition on `rhs ?= B`. 
- Rvalue::BinaryOp( - op, - box ( - Operand::Move(place) | Operand::Copy(place), - Operand::Constant(value), - ) - | box ( - Operand::Constant(value), - Operand::Move(place) | Operand::Copy(place), - ), - ) => { - let conditions = state.try_get_idx(lhs, self.map)?; - let place = self.map.find(place.as_ref())?; - let equals = match op { - BinOp::Eq => ScalarInt::TRUE, - BinOp::Ne => ScalarInt::FALSE, - _ => return None, - }; - let value = value - .const_ - .normalize(self.tcx, self.param_env) - .try_to_scalar_int()?; - let conds = conditions.map(self.arena, |c| Condition { - value, - polarity: if c.matches(equals) { - Polarity::Eq - } else { - Polarity::Ne - }, - ..c - }); - state.insert_value_idx(place, conds, self.map); - } - - _ => {} - } - } + self.process_assign(bb, lhs_place, rhs, state)?; } _ => {} } @@ -518,11 +572,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { cost: &CostChecker<'_, 'tcx>, depth: usize, ) { - let register_opportunity = |c: Condition| { - debug!(?bb, ?c.target, "register"); - self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target }) - }; - let term = self.body.basic_blocks[bb].terminator(); let place_to_flood = match term.kind { // We come from a target, so those are not possible. @@ -544,16 +593,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // Flood the overwritten place, and progress through. TerminatorKind::Drop { place: destination, .. } | TerminatorKind::Call { destination, .. } => Some(destination), - // Treat as an `assume(cond == expected)`. - TerminatorKind::Assert { ref cond, expected, .. } => { - if let Some(place) = cond.place() - && let Some(conditions) = state.try_get(place.as_ref(), self.map) - { - let expected = if expected { ScalarInt::TRUE } else { ScalarInt::FALSE }; - conditions.iter_matches(expected).for_each(register_opportunity); - } - None - } + // Ignore, as this can be a no-op at codegen time. + TerminatorKind::Assert { .. } => None, }; // We can recurse through this terminator. @@ -577,7 +618,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { let discr = discr.place()?; let discr_ty = discr.ty(self.body, self.tcx).ty; - let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?; + let discr_layout = self.ecx.layout_of(discr_ty).ok()?; let conditions = state.try_get(discr.as_ref(), self.map)?; if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) { diff --git a/compiler/rustc_mir_transform/src/known_panics_lint.rs b/compiler/rustc_mir_transform/src/known_panics_lint.rs new file mode 100644 index 00000000000..a9cd688c315 --- /dev/null +++ b/compiler/rustc_mir_transform/src/known_panics_lint.rs @@ -0,0 +1,975 @@ +//! A lint that checks for known panics like +//! overflows, division by zero, +//! out-of-bound access etc. +//! Uses const propagation to determine the +//! values of operands during checks. 
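+//!
+//! For instance (a sketch; the exact diagnostic wording is not reproduced
+//! here), both of the following are flagged by this lint:
+//! ```rust,ignore (example)
+//! let x: u8 = 255 + 1; // this arithmetic operation will overflow
+//! let y = [0; 4][7]; // this operation will panic at runtime
+//! ```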
+ +use std::fmt::Debug; + +use rustc_const_eval::interpret::{ + format_interp_error, ImmTy, InterpCx, InterpResult, Projectable, Scalar, +}; +use rustc_data_structures::fx::FxHashSet; +use rustc_hir::def::DefKind; +use rustc_hir::HirId; +use rustc_index::{bit_set::BitSet, Idx, IndexVec}; +use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout}; +use rustc_middle::ty::{self, ConstInt, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisitableExt}; +use rustc_span::Span; +use rustc_target::abi::{Abi, FieldIdx, HasDataLayout, Size, TargetDataLayout, VariantIdx}; + +use crate::dataflow_const_prop::DummyMachine; +use crate::errors::{AssertLint, AssertLintKind}; +use crate::MirLint; + +pub struct KnownPanicsLint; + +impl<'tcx> MirLint<'tcx> for KnownPanicsLint { + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + if body.tainted_by_errors.is_some() { + return; + } + + let def_id = body.source.def_id().expect_local(); + let def_kind = tcx.def_kind(def_id); + let is_fn_like = def_kind.is_fn_like(); + let is_assoc_const = def_kind == DefKind::AssocConst; + + // Only run const prop on functions, methods, closures and associated constants + if !is_fn_like && !is_assoc_const { + // skip anon_const/statics/consts because they'll be evaluated by miri anyway + trace!("KnownPanicsLint skipped for {:?}", def_id); + return; + } + + // FIXME(welseywiser) const prop doesn't work on coroutines because of query cycles + // computing their layout. + if tcx.is_coroutine(def_id.to_def_id()) { + trace!("KnownPanicsLint skipped for coroutine {:?}", def_id); + return; + } + + trace!("KnownPanicsLint starting for {:?}", def_id); + + let mut linter = ConstPropagator::new(body, tcx); + linter.visit_body(body); + + trace!("KnownPanicsLint done for {:?}", def_id); + } +} + +/// Visits MIR nodes, performs const propagation +/// and runs lint checks as it goes +struct ConstPropagator<'mir, 'tcx> { + ecx: InterpCx<'mir, 'tcx, DummyMachine>, + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + worklist: Vec<BasicBlock>, + visited_blocks: BitSet<BasicBlock>, + locals: IndexVec<Local, Value<'tcx>>, + body: &'mir Body<'tcx>, + written_only_inside_own_block_locals: FxHashSet<Local>, + can_const_prop: IndexVec<Local, ConstPropMode>, +} + +#[derive(Debug, Clone)] +enum Value<'tcx> { + Immediate(ImmTy<'tcx>), + Aggregate { variant: VariantIdx, fields: IndexVec<FieldIdx, Value<'tcx>> }, + Uninit, +} + +impl<'tcx> From<ImmTy<'tcx>> for Value<'tcx> { + fn from(v: ImmTy<'tcx>) -> Self { + Self::Immediate(v) + } +} + +impl<'tcx> Value<'tcx> { + fn project( + &self, + proj: &[PlaceElem<'tcx>], + prop: &ConstPropagator<'_, 'tcx>, + ) -> Option<&Value<'tcx>> { + let mut this = self; + for proj in proj { + this = match (*proj, this) { + (PlaceElem::Field(idx, _), Value::Aggregate { fields, .. }) => { + fields.get(idx).unwrap_or(&Value::Uninit) + } + (PlaceElem::Index(idx), Value::Aggregate { fields, .. }) => { + let idx = prop.get_const(idx.into())?.immediate()?; + let idx = prop.ecx.read_target_usize(idx).ok()?; + fields.get(FieldIdx::from_u32(idx.try_into().ok()?)).unwrap_or(&Value::Uninit) + } + ( + PlaceElem::ConstantIndex { offset, min_length: _, from_end: false }, + Value::Aggregate { fields, .. 
}, + ) => fields + .get(FieldIdx::from_u32(offset.try_into().ok()?)) + .unwrap_or(&Value::Uninit), + _ => return None, + }; + } + Some(this) + } + + fn project_mut(&mut self, proj: &[PlaceElem<'_>]) -> Option<&mut Value<'tcx>> { + let mut this = self; + for proj in proj { + this = match (proj, this) { + (PlaceElem::Field(idx, _), Value::Aggregate { fields, .. }) => { + fields.ensure_contains_elem(*idx, || Value::Uninit) + } + (PlaceElem::Field(..), val @ Value::Uninit) => { + *val = Value::Aggregate { + variant: VariantIdx::new(0), + fields: Default::default(), + }; + val.project_mut(&[*proj])? + } + _ => return None, + }; + } + Some(this) + } + + fn immediate(&self) -> Option<&ImmTy<'tcx>> { + match self { + Value::Immediate(op) => Some(op), + _ => None, + } + } +} + +impl<'tcx> LayoutOfHelpers<'tcx> for ConstPropagator<'_, 'tcx> { + type LayoutOfResult = Result<TyAndLayout<'tcx>, LayoutError<'tcx>>; + + #[inline] + fn handle_layout_err(&self, err: LayoutError<'tcx>, _: Span, _: Ty<'tcx>) -> LayoutError<'tcx> { + err + } +} + +impl HasDataLayout for ConstPropagator<'_, '_> { + #[inline] + fn data_layout(&self) -> &TargetDataLayout { + &self.tcx.data_layout + } +} + +impl<'tcx> ty::layout::HasTyCtxt<'tcx> for ConstPropagator<'_, 'tcx> { + #[inline] + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } +} + +impl<'tcx> ty::layout::HasParamEnv<'tcx> for ConstPropagator<'_, 'tcx> { + #[inline] + fn param_env(&self) -> ty::ParamEnv<'tcx> { + self.param_env + } +} + +impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { + fn new(body: &'mir Body<'tcx>, tcx: TyCtxt<'tcx>) -> ConstPropagator<'mir, 'tcx> { + let def_id = body.source.def_id(); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + + let can_const_prop = CanConstProp::check(tcx, param_env, body); + let ecx = InterpCx::new(tcx, tcx.def_span(def_id), param_env, DummyMachine); + + ConstPropagator { + ecx, + tcx, + param_env, + worklist: vec![START_BLOCK], + visited_blocks: BitSet::new_empty(body.basic_blocks.len()), + locals: IndexVec::from_elem_n(Value::Uninit, body.local_decls.len()), + body, + can_const_prop, + written_only_inside_own_block_locals: Default::default(), + } + } + + fn local_decls(&self) -> &'mir LocalDecls<'tcx> { + &self.body.local_decls + } + + fn get_const(&self, place: Place<'tcx>) -> Option<&Value<'tcx>> { + self.locals[place.local].project(&place.projection, self) + } + + /// Remove `local` from the pool of `Locals`. Allows writing to them, + /// but not reading from them anymore. + fn remove_const(&mut self, local: Local) { + self.locals[local] = Value::Uninit; + self.written_only_inside_own_block_locals.remove(&local); + } + + fn access_mut(&mut self, place: &Place<'_>) -> Option<&mut Value<'tcx>> { + match self.can_const_prop[place.local] { + ConstPropMode::NoPropagation => return None, + ConstPropMode::OnlyInsideOwnBlock => { + self.written_only_inside_own_block_locals.insert(place.local); + } + ConstPropMode::FullConstProp => {} + } + self.locals[place.local].project_mut(place.projection) + } + + fn lint_root(&self, source_info: SourceInfo) -> Option<HirId> { + source_info.scope.lint_root(&self.body.source_scopes) + } + + fn use_ecx<F, T>(&mut self, f: F) -> Option<T> + where + F: FnOnce(&mut Self) -> InterpResult<'tcx, T>, + { + match f(self) { + Ok(val) => Some(val), + Err(error) => { + trace!("InterpCx operation failed: {:?}", error); + // Some errors shouldn't come up because creating them causes + // an allocation, which we should avoid. 
When that happens, + // dedicated error variants should be introduced instead. + assert!( + !error.kind().formatted_string(), + "known panics lint encountered formatting error: {}", + format_interp_error(self.ecx.tcx.dcx(), error), + ); + None + } + } + } + + /// Returns the value, if any, of evaluating `c`. + fn eval_constant(&mut self, c: &ConstOperand<'tcx>) -> Option<ImmTy<'tcx>> { + // FIXME we need to revisit this for #67176 + if c.has_param() { + return None; + } + + // Normalization needed b/c known panics lint runs in + // `mir_drops_elaborated_and_const_checked`, which happens before + // optimized MIR. Only after optimizing the MIR can we guarantee + // that the `RevealAll` pass has happened and that the body's consts + // are normalized, so any call to resolve before that needs to be + // manually normalized. + let val = self.tcx.try_normalize_erasing_regions(self.param_env, c.const_).ok()?; + + self.use_ecx(|this| this.ecx.eval_mir_constant(&val, Some(c.span), None))? + .as_mplace_or_imm() + .right() + } + + /// Returns the value, if any, of evaluating `place`. + #[instrument(level = "trace", skip(self), ret)] + fn eval_place(&mut self, place: Place<'tcx>) -> Option<ImmTy<'tcx>> { + match self.get_const(place)? { + Value::Immediate(imm) => Some(imm.clone()), + Value::Aggregate { .. } => None, + Value::Uninit => None, + } + } + + /// Returns the value, if any, of evaluating `op`. Calls upon `eval_constant` + /// or `eval_place`, depending on the variant of `Operand` used. + fn eval_operand(&mut self, op: &Operand<'tcx>) -> Option<ImmTy<'tcx>> { + match *op { + Operand::Constant(ref c) => self.eval_constant(c), + Operand::Move(place) | Operand::Copy(place) => self.eval_place(place), + } + } + + fn report_assert_as_lint( + &self, + location: Location, + lint_kind: AssertLintKind, + assert_kind: AssertKind<impl Debug>, + ) { + let source_info = self.body.source_info(location); + if let Some(lint_root) = self.lint_root(*source_info) { + let span = source_info.span; + self.tcx.emit_node_span_lint( + lint_kind.lint(), + lint_root, + span, + AssertLint { span, assert_kind, lint_kind }, + ); + } + } + + fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>, location: Location) -> Option<()> { + let arg = self.eval_operand(arg)?; + if let (val, true) = self.use_ecx(|this| { + let val = this.ecx.read_immediate(&arg)?; + let (_res, overflow) = this.ecx.overflowing_unary_op(op, &val)?; + Ok((val, overflow)) + })? { + // `AssertKind` only has an `OverflowNeg` variant, so make sure that is + // appropriate to use. + assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow"); + self.report_assert_as_lint( + location, + AssertLintKind::ArithmeticOverflow, + AssertKind::OverflowNeg(val.to_const_int()), + ); + return None; + } + + Some(()) + } + + fn check_binary_op( + &mut self, + op: BinOp, + left: &Operand<'tcx>, + right: &Operand<'tcx>, + location: Location, + ) -> Option<()> { + let r = + self.eval_operand(right).and_then(|r| self.use_ecx(|this| this.ecx.read_immediate(&r))); + let l = + self.eval_operand(left).and_then(|l| self.use_ecx(|this| this.ecx.read_immediate(&l))); + // Check for exceeding shifts *even if* we cannot evaluate the LHS. + if matches!(op, BinOp::Shr | BinOp::Shl) { + let r = r.clone()?; + // We need the type of the LHS. We cannot use `place_layout` as that is the type + // of the result, which for checked binops is not the same! 
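+            // (Concretely: a `CheckedBinaryOp` on `u32` operands yields a
+            // `(u32, bool)` pair, so the destination's layout cannot tell us
+            // the operand width the shift check needs; we therefore look at
+            // the LHS operand's own type below.)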
+ let left_ty = left.ty(self.local_decls(), self.tcx); + let left_size = self.ecx.layout_of(left_ty).ok()?.size; + let right_size = r.layout.size; + let r_bits = r.to_scalar().to_bits(right_size).ok(); + if r_bits.is_some_and(|b| b >= left_size.bits() as u128) { + debug!("check_binary_op: reporting assert for {:?}", location); + let panic = AssertKind::Overflow( + op, + match l { + Some(l) => l.to_const_int(), + // Invent a dummy value, the diagnostic ignores it anyway + None => ConstInt::new( + ScalarInt::try_from_uint(1_u8, left_size).unwrap(), + left_ty.is_signed(), + left_ty.is_ptr_sized_integral(), + ), + }, + r.to_const_int(), + ); + self.report_assert_as_lint(location, AssertLintKind::ArithmeticOverflow, panic); + return None; + } + } + + if let (Some(l), Some(r)) = (l, r) { + // The remaining operators are handled through `overflowing_binary_op`. + if self.use_ecx(|this| { + let (_res, overflow) = this.ecx.overflowing_binary_op(op, &l, &r)?; + Ok(overflow) + })? { + self.report_assert_as_lint( + location, + AssertLintKind::ArithmeticOverflow, + AssertKind::Overflow(op, l.to_const_int(), r.to_const_int()), + ); + return None; + } + } + Some(()) + } + + fn check_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) -> Option<()> { + // Perform any special handling for specific Rvalue types. + // Generally, checks here fall into one of two categories: + // 1. Additional checking to provide useful lints to the user + // - In this case, we will do some validation and then fall through to the + // end of the function which evals the assignment. + // 2. Working around bugs in other parts of the compiler + // - In this case, we'll return `None` from this function to stop evaluation. + match rvalue { + // Additional checking: give lints to the user if an overflow would occur. + // We do this here and not in the `Assert` terminator as that terminator is + // only sometimes emitted (overflow checks can be disabled), but we want to always + // lint. + Rvalue::UnaryOp(op, arg) => { + trace!("checking UnaryOp(op = {:?}, arg = {:?})", op, arg); + self.check_unary_op(*op, arg, location)?; + } + Rvalue::BinaryOp(op, box (left, right)) => { + trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right); + self.check_binary_op(*op, left, right, location)?; + } + Rvalue::CheckedBinaryOp(op, box (left, right)) => { + trace!( + "checking CheckedBinaryOp(op = {:?}, left = {:?}, right = {:?})", + op, + left, + right + ); + self.check_binary_op(*op, left, right, location)?; + } + + // Do not try creating references (#67862) + Rvalue::AddressOf(_, place) | Rvalue::Ref(_, _, place) => { + trace!("skipping AddressOf | Ref for {:?}", place); + + // This may be creating mutable references or immutable references to cells. + // If that happens, the pointed to value could be mutated via that reference. + // Since we aren't tracking references, the const propagator loses track of what + // value the local has right now. + // Thus, all locals that have their reference taken + // must not take part in propagation. + self.remove_const(place.local); + + return None; + } + Rvalue::ThreadLocalRef(def_id) => { + trace!("skipping ThreadLocalRef({:?})", def_id); + + return None; + } + + // There's no other checking to do at this time. + Rvalue::Aggregate(..) + | Rvalue::Use(..) + | Rvalue::CopyForDeref(..) + | Rvalue::Repeat(..) + | Rvalue::Len(..) + | Rvalue::Cast(..) + | Rvalue::ShallowInitBox(..) + | Rvalue::Discriminant(..) + | Rvalue::NullaryOp(..) 
=> {} + } + + // FIXME we need to revisit this for #67176 + if rvalue.has_param() { + return None; + } + if !rvalue.ty(self.local_decls(), self.tcx).is_sized(self.tcx, self.param_env) { + // the interpreter doesn't support unsized locals (only unsized arguments), + // but rustc does (in a kinda broken way), so we have to skip them here + return None; + } + + Some(()) + } + + fn check_assertion( + &mut self, + expected: bool, + msg: &AssertKind<Operand<'tcx>>, + cond: &Operand<'tcx>, + location: Location, + ) -> Option<!> { + let value = &self.eval_operand(cond)?; + trace!("assertion on {:?} should be {:?}", value, expected); + + let expected = Scalar::from_bool(expected); + let value_const = self.use_ecx(|this| this.ecx.read_scalar(value))?; + + if expected != value_const { + // Poison all places this operand references so that further code + // doesn't use the invalid value + if let Some(place) = cond.place() { + self.remove_const(place.local); + } + + enum DbgVal<T> { + Val(T), + Underscore, + } + impl<T: std::fmt::Debug> std::fmt::Debug for DbgVal<T> { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Val(val) => val.fmt(fmt), + Self::Underscore => fmt.write_str("_"), + } + } + } + let mut eval_to_int = |op| { + // This can be `None` if the lhs wasn't const propagated and we just + // triggered the assert on the value of the rhs. + self.eval_operand(op) + .and_then(|op| self.ecx.read_immediate(&op).ok()) + .map_or(DbgVal::Underscore, |op| DbgVal::Val(op.to_const_int())) + }; + let msg = match msg { + AssertKind::DivisionByZero(op) => AssertKind::DivisionByZero(eval_to_int(op)), + AssertKind::RemainderByZero(op) => AssertKind::RemainderByZero(eval_to_int(op)), + AssertKind::Overflow(bin_op @ (BinOp::Div | BinOp::Rem), op1, op2) => { + // Division overflow is *UB* in the MIR, and different than the + // other overflow checks. + AssertKind::Overflow(*bin_op, eval_to_int(op1), eval_to_int(op2)) + } + AssertKind::BoundsCheck { ref len, ref index } => { + let len = eval_to_int(len); + let index = eval_to_int(index); + AssertKind::BoundsCheck { len, index } + } + // Remaining overflow errors are already covered by checks on the binary operators. + AssertKind::Overflow(..) | AssertKind::OverflowNeg(_) => return None, + // Need proper const propagator for these. 
+ _ => return None, + }; + self.report_assert_as_lint(location, AssertLintKind::UnconditionalPanic, msg); + } + + None + } + + fn ensure_not_propagated(&self, local: Local) { + if cfg!(debug_assertions) { + let val = self.get_const(local.into()); + assert!( + matches!(val, Some(Value::Uninit)) + || self + .layout_of(self.local_decls()[local].ty) + .map_or(true, |layout| layout.is_zst()), + "failed to remove values for `{local:?}`, value={val:?}", + ) + } + } + + #[instrument(level = "trace", skip(self), ret)] + fn eval_rvalue(&mut self, rvalue: &Rvalue<'tcx>, dest: &Place<'tcx>) -> Option<()> { + if !dest.projection.is_empty() { + return None; + } + use rustc_middle::mir::Rvalue::*; + let layout = self.ecx.layout_of(dest.ty(self.body, self.tcx).ty).ok()?; + trace!(?layout); + + let val: Value<'_> = match *rvalue { + ThreadLocalRef(_) => return None, + + Use(ref operand) => self.eval_operand(operand)?.into(), + + CopyForDeref(place) => self.eval_place(place)?.into(), + + BinaryOp(bin_op, box (ref left, ref right)) => { + let left = self.eval_operand(left)?; + let left = self.use_ecx(|this| this.ecx.read_immediate(&left))?; + + let right = self.eval_operand(right)?; + let right = self.use_ecx(|this| this.ecx.read_immediate(&right))?; + + let val = + self.use_ecx(|this| this.ecx.wrapping_binary_op(bin_op, &left, &right))?; + val.into() + } + + CheckedBinaryOp(bin_op, box (ref left, ref right)) => { + let left = self.eval_operand(left)?; + let left = self.use_ecx(|this| this.ecx.read_immediate(&left))?; + + let right = self.eval_operand(right)?; + let right = self.use_ecx(|this| this.ecx.read_immediate(&right))?; + + let (val, overflowed) = + self.use_ecx(|this| this.ecx.overflowing_binary_op(bin_op, &left, &right))?; + let overflowed = ImmTy::from_bool(overflowed, self.tcx); + Value::Aggregate { + variant: VariantIdx::new(0), + fields: [Value::from(val), overflowed.into()].into_iter().collect(), + } + } + + UnaryOp(un_op, ref operand) => { + let operand = self.eval_operand(operand)?; + let val = self.use_ecx(|this| this.ecx.read_immediate(&operand))?; + + let val = self.use_ecx(|this| this.ecx.wrapping_unary_op(un_op, &val))?; + val.into() + } + + Aggregate(ref kind, ref fields) => Value::Aggregate { + fields: fields + .iter() + .map(|field| self.eval_operand(field).map_or(Value::Uninit, Value::Immediate)) + .collect(), + variant: match **kind { + AggregateKind::Adt(_, variant, _, _, _) => variant, + AggregateKind::Array(_) + | AggregateKind::Tuple + | AggregateKind::Closure(_, _) + | AggregateKind::Coroutine(_, _) + | AggregateKind::CoroutineClosure(_, _) => VariantIdx::new(0), + }, + }, + + Repeat(ref op, n) => { + trace!(?op, ?n); + return None; + } + + Len(place) => { + let len = match self.get_const(place)? { + Value::Immediate(src) => src.len(&self.ecx).ok()?, + Value::Aggregate { fields, .. } => fields.len() as u64, + Value::Uninit => match place.ty(self.local_decls(), self.tcx).ty.kind() { + ty::Array(_, n) => n.try_eval_target_usize(self.tcx, self.param_env)?, + _ => return None, + }, + }; + ImmTy::from_scalar(Scalar::from_target_usize(len, self), layout).into() + } + + Ref(..) | AddressOf(..) 
=> return None, + + NullaryOp(ref null_op, ty) => { + let op_layout = self.use_ecx(|this| this.ecx.layout_of(ty))?; + let val = match null_op { + NullOp::SizeOf => op_layout.size.bytes(), + NullOp::AlignOf => op_layout.align.abi.bytes(), + NullOp::OffsetOf(fields) => { + op_layout.offset_of_subfield(self, fields.iter()).bytes() + } + NullOp::DebugAssertions => return None, + }; + ImmTy::from_scalar(Scalar::from_target_usize(val, self), layout).into() + } + + ShallowInitBox(..) => return None, + + Cast(ref kind, ref value, to) => match kind { + CastKind::IntToInt | CastKind::IntToFloat => { + let value = self.eval_operand(value)?; + let value = self.ecx.read_immediate(&value).ok()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.int_to_int_or_float(&value, to).ok()?; + res.into() + } + CastKind::FloatToFloat | CastKind::FloatToInt => { + let value = self.eval_operand(value)?; + let value = self.ecx.read_immediate(&value).ok()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.float_to_float_or_int(&value, to).ok()?; + res.into() + } + CastKind::Transmute => { + let value = self.eval_operand(value)?; + let to = self.ecx.layout_of(to).ok()?; + // `offset` for immediates only supports scalar/scalar-pair ABIs, + // so bail out if the target is not one. + match (value.layout.abi, to.abi) { + (Abi::Scalar(..), Abi::Scalar(..)) => {} + (Abi::ScalarPair(..), Abi::ScalarPair(..)) => {} + _ => return None, + } + + value.offset(Size::ZERO, to, &self.ecx).ok()?.into() + } + _ => return None, + }, + + Discriminant(place) => { + let variant = match self.get_const(place)? { + Value::Immediate(op) => { + let op = op.clone(); + self.use_ecx(|this| this.ecx.read_discriminant(&op))? + } + Value::Aggregate { variant, .. } => *variant, + Value::Uninit => return None, + }; + let imm = self.use_ecx(|this| { + this.ecx.discriminant_for_variant( + place.ty(this.local_decls(), this.tcx).ty, + variant, + ) + })?; + imm.into() + } + }; + trace!(?val); + + *self.access_mut(dest)? = val; + + Some(()) + } +} + +impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { + fn visit_body(&mut self, body: &Body<'tcx>) { + while let Some(bb) = self.worklist.pop() { + if !self.visited_blocks.insert(bb) { + continue; + } + + let data = &body.basic_blocks[bb]; + self.visit_basic_block_data(bb, data); + } + } + + fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { + self.super_operand(operand, location); + } + + fn visit_constant(&mut self, constant: &ConstOperand<'tcx>, location: Location) { + trace!("visit_constant: {:?}", constant); + self.super_constant(constant, location); + self.eval_constant(constant); + } + + fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) { + self.super_assign(place, rvalue, location); + + let Some(()) = self.check_rvalue(rvalue, location) else { return }; + + match self.can_const_prop[place.local] { + // Do nothing if the place is indirect. + _ if place.is_indirect() => {} + ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), + ConstPropMode::OnlyInsideOwnBlock | ConstPropMode::FullConstProp => { + if self.eval_rvalue(rvalue, place).is_none() { + // Const prop failed, so erase the destination, ensuring that whatever happens + // from here on, does not know about the previous value. 
+ // This is important in case we have
+ // ```rust
+ // let mut x = 42;
+ // x = SOME_MUTABLE_STATIC;
+ // // x must now be uninit
+ // ```
+ // FIXME: we overzealously erase the entire local, because that's easier to
+ // implement.
+ trace!(
+ "propagation into {:?} failed.
+ Nuking the entire site from orbit, it's the only way to be sure",
+ place,
+ );
+ self.remove_const(place.local);
+ }
+ }
+ }
+ }
+
+ fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+ trace!("visit_statement: {:?}", statement);
+
+ // We want to evaluate operands before any change to the assigned-to value,
+ // so we recurse first.
+ self.super_statement(statement, location);
+
+ match statement.kind {
+ StatementKind::SetDiscriminant { ref place, variant_index } => {
+ match self.can_const_prop[place.local] {
+ // Do nothing if the place is indirect.
+ _ if place.is_indirect() => {}
+ ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local),
+ ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => {
+ match self.access_mut(place) {
+ Some(Value::Aggregate { variant, .. }) => *variant = variant_index,
+ _ => self.remove_const(place.local),
+ }
+ }
+ }
+ }
+ StatementKind::StorageLive(local) => {
+ self.remove_const(local);
+ }
+ StatementKind::StorageDead(local) => {
+ self.remove_const(local);
+ }
+ _ => {}
+ }
+ }
+
+ fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
+ self.super_terminator(terminator, location);
+ match &terminator.kind {
+ TerminatorKind::Assert { expected, ref msg, ref cond, .. } => {
+ self.check_assertion(*expected, msg, cond, location);
+ }
+ TerminatorKind::SwitchInt { ref discr, ref targets } => {
+ if let Some(ref value) = self.eval_operand(discr)
+ && let Some(value_const) = self.use_ecx(|this| this.ecx.read_scalar(value))
+ && let Ok(constant) = value_const.try_to_int()
+ && let Ok(constant) = constant.to_bits(constant.size())
+ {
+ // We managed to evaluate the discriminant, so we know we only need to visit
+ // one target.
+ let target = targets.target_for_value(constant);
+ self.worklist.push(target);
+ return;
+ }
+ // We failed to evaluate the discriminant, so fall back to visiting all successors.
+ }
+ // None of these have Operands to const-propagate.
+ TerminatorKind::Goto { .. }
+ | TerminatorKind::UnwindResume
+ | TerminatorKind::UnwindTerminate(_)
+ | TerminatorKind::Return
+ | TerminatorKind::Unreachable
+ | TerminatorKind::Drop { .. }
+ | TerminatorKind::Yield { .. }
+ | TerminatorKind::CoroutineDrop
+ | TerminatorKind::FalseEdge { .. }
+ | TerminatorKind::FalseUnwind { .. }
+ | TerminatorKind::Call { .. }
+ | TerminatorKind::InlineAsm { .. } => {}
+ }
+
+ self.worklist.extend(terminator.successors());
+ }
+
+ fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) {
+ self.super_basic_block_data(block, data);
+
+ // We remove all Locals which are restricted in propagation to their containing blocks and
+ // which were modified in the current block.
+ // Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`.
+ let mut written_only_inside_own_block_locals =
+ std::mem::take(&mut self.written_only_inside_own_block_locals);
+
+ // This loop can get very hot for some bodies: it checks each local in each bb.
+ // To avoid this quadratic behaviour, we only clear the locals that were modified inside
+ // the current block.
+ // The order in which we remove consts does not matter.
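As a rough sketch of the situation this block-local clearing exists for (an illustration with invented names, not code from this patch): a local that is written in more than one block only has a known value inside each writing block, so any per-block knowledge has to be dropped at the block boundary.

```rust,ignore (illustrative example, not from this patch)
let mut x;
if cond {
    x = 1;
    f(x); // within this block, `x` may be replaced by `1`
} else {
    x = 2;
    f(x); // within this block, `x` may be replaced by `2`
}
f(x); // after the join there is no single known value, so the
      // block-local entry recorded for `x` above must be gone
```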
+ #[allow(rustc::potential_query_instability)]
+ for local in written_only_inside_own_block_locals.drain() {
+ debug_assert_eq!(self.can_const_prop[local], ConstPropMode::OnlyInsideOwnBlock);
+ self.remove_const(local);
+ }
+ self.written_only_inside_own_block_locals = written_only_inside_own_block_locals;
+
+ if cfg!(debug_assertions) {
+ for (local, &mode) in self.can_const_prop.iter_enumerated() {
+ match mode {
+ ConstPropMode::FullConstProp => {}
+ ConstPropMode::NoPropagation | ConstPropMode::OnlyInsideOwnBlock => {
+ self.ensure_not_propagated(local);
+ }
+ }
+ }
+ }
+ }
+}
+
+/// The maximum number of bytes that we'll allocate for a local or the return value.
+/// Needed for #66397, because otherwise we eval into large places and that can cause OOM or just
+/// severely regress performance.
+const MAX_ALLOC_LIMIT: u64 = 1024;
+
+/// The mode that `ConstProp` is allowed to run in for a given `Local`.
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub enum ConstPropMode {
+ /// The `Local` can be propagated into and reads of this `Local` can also be propagated.
+ FullConstProp,
+ /// The `Local` can only be propagated into and from its own block.
+ OnlyInsideOwnBlock,
+ /// The `Local` cannot be part of propagation at all. Any statement
+ /// referencing it either for reading or writing will not get propagated.
+ NoPropagation,
+}
+
+/// A visitor that determines the locals in a MIR body
+/// that can be const propagated.
+pub struct CanConstProp {
+ can_const_prop: IndexVec<Local, ConstPropMode>,
+ // False at the beginning. Once set, no more assignments are allowed to that local.
+ found_assignment: BitSet<Local>,
+}
+
+impl CanConstProp {
+ /// Returns the propagation mode for each local in `body`.
+ pub fn check<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ param_env: ParamEnv<'tcx>,
+ body: &Body<'tcx>,
+ ) -> IndexVec<Local, ConstPropMode> {
+ let mut cpv = CanConstProp {
+ can_const_prop: IndexVec::from_elem(ConstPropMode::FullConstProp, &body.local_decls),
+ found_assignment: BitSet::new_empty(body.local_decls.len()),
+ };
+ for (local, val) in cpv.can_const_prop.iter_enumerated_mut() {
+ let ty = body.local_decls[local].ty;
+ match tcx.layout_of(param_env.and(ty)) {
+ Ok(layout) if layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) => {}
+ // Either the layout fails to compute, in which case we can't use this local
+ // anyway, or the local is too large, in which case we don't want to.
+ _ => {
+ *val = ConstPropMode::NoPropagation;
+ continue;
+ }
+ }
+ }
+ // Consider that arguments are assigned on entry.
+ for arg in body.args_iter() {
+ cpv.found_assignment.insert(arg);
+ }
+ cpv.visit_body(body);
+ cpv.can_const_prop
+ }
+}
+
+impl<'tcx> Visitor<'tcx> for CanConstProp {
+ fn visit_place(&mut self, place: &Place<'tcx>, mut context: PlaceContext, loc: Location) {
+ use rustc_middle::mir::visit::PlaceContext::*;
+
+ // Dereferencing just reads the address of `place.local`.
+ if place.projection.first() == Some(&PlaceElem::Deref) { + context = NonMutatingUse(NonMutatingUseContext::Copy); + } + + self.visit_local(place.local, context, loc); + self.visit_projection(place.as_ref(), context, loc); + } + + fn visit_local(&mut self, local: Local, context: PlaceContext, _: Location) { + use rustc_middle::mir::visit::PlaceContext::*; + match context { + // These are just stores, where the storing is not propagatable, but there may be later + // mutations of the same local via `Store` + | MutatingUse(MutatingUseContext::Call) + | MutatingUse(MutatingUseContext::AsmOutput) + | MutatingUse(MutatingUseContext::Deinit) + // Actual store that can possibly even propagate a value + | MutatingUse(MutatingUseContext::Store) + | MutatingUse(MutatingUseContext::SetDiscriminant) => { + if !self.found_assignment.insert(local) { + match &mut self.can_const_prop[local] { + // If the local can only get propagated in its own block, then we don't have + // to worry about multiple assignments, as we'll nuke the const state at the + // end of the block anyway, and inside the block we overwrite previous + // states as applicable. + ConstPropMode::OnlyInsideOwnBlock => {} + ConstPropMode::NoPropagation => {} + other @ ConstPropMode::FullConstProp => { + trace!( + "local {:?} can't be propagated because of multiple assignments. Previous state: {:?}", + local, other, + ); + *other = ConstPropMode::OnlyInsideOwnBlock; + } + } + } + } + // Reading constants is allowed an arbitrary number of times + NonMutatingUse(NonMutatingUseContext::Copy) + | NonMutatingUse(NonMutatingUseContext::Move) + | NonMutatingUse(NonMutatingUseContext::Inspect) + | NonMutatingUse(NonMutatingUseContext::PlaceMention) + | NonUse(_) => {} + + // These could be propagated with a smarter analysis or just some careful thinking about + // whether they'd be fine right now. + MutatingUse(MutatingUseContext::Yield) + | MutatingUse(MutatingUseContext::Drop) + | MutatingUse(MutatingUseContext::Retag) + // These can't ever be propagated under any scheme, as we can't reason about indirect + // mutation. 
+ | NonMutatingUse(NonMutatingUseContext::SharedBorrow) + | NonMutatingUse(NonMutatingUseContext::FakeBorrow) + | NonMutatingUse(NonMutatingUseContext::AddressOf) + | MutatingUse(MutatingUseContext::Borrow) + | MutatingUse(MutatingUseContext::AddressOf) => { + trace!("local {:?} can't be propagated because it's used: {:?}", local, context); + self.can_const_prop[local] = ConstPropMode::NoPropagation; + } + MutatingUse(MutatingUseContext::Projection) + | NonMutatingUse(NonMutatingUseContext::Projection) => bug!("visit_place should not pass {context:?} for {local:?}"), + } + } +} diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 19bfed4333c..945c3c662a6 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -1,22 +1,20 @@ -#![deny(rustc::untranslatable_diagnostic)] -#![deny(rustc::diagnostic_outside_of_impl)] #![feature(assert_matches)] #![feature(box_patterns)] +#![feature(const_type_name)] #![feature(cow_is_borrowed)] #![feature(decl_macro)] #![feature(impl_trait_in_assoc_type)] +#![feature(inline_const)] #![feature(is_sorted)] #![feature(let_chains)] #![feature(map_try_insert)] -#![feature(min_specialization)] +#![cfg_attr(bootstrap, feature(min_specialization))] #![feature(never_type)] #![feature(option_get_or_insert_default)] #![feature(round_char_boundary)] -#![feature(trusted_step)] #![feature(try_blocks)] #![feature(yeet_expr)] #![feature(if_let_guard)] -#![recursion_limit = "256"] #[macro_use] extern crate tracing; @@ -61,9 +59,6 @@ mod remove_place_mention; mod add_subtyping_projections; pub mod cleanup_post_borrowck; mod const_debuginfo; -mod const_goto; -mod const_prop; -mod const_prop_lint; mod copy_prop; mod coroutine; mod cost_checker; @@ -87,6 +82,7 @@ mod gvn; pub mod inline; mod instsimplify; mod jump_threading; +mod known_panics_lint; mod large_enums; mod lint; mod lower_intrinsics; @@ -105,7 +101,6 @@ mod remove_unneeded_drops; mod remove_zsts; mod required_consts; mod reveal_all; -mod separate_const_switch; mod shim; mod ssa; // This pass is public to allow external drivers to perform MIR cleanup @@ -165,8 +160,7 @@ fn remap_mir_for_const_eval_select<'tcx>( fn_span, .. } if let ty::FnDef(def_id, _) = *const_.ty().kind() - && tcx.item_name(def_id) == sym::const_eval_select - && tcx.is_intrinsic(def_id) => + && matches!(tcx.intrinsic(def_id), Some(sym::const_eval_select)) => { let [tupled_args, called_in_const, called_at_rt]: [_; 3] = std::mem::take(args).try_into().unwrap(); @@ -269,6 +263,7 @@ fn mir_const_qualif(tcx: TyCtxt<'_>, def: LocalDefId) -> ConstQualifs { let body = &tcx.mir_const(def).borrow(); if body.return_ty().references_error() { + // It's possible to reach here without an error being emitted (#121103). tcx.dcx().span_delayed_bug(body.span, "mir_const_qualif: MIR had errors"); return Default::default(); } @@ -307,6 +302,10 @@ fn mir_const(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> { &Lint(check_packed_ref::CheckPackedRef), &Lint(check_const_item_mutation::CheckConstItemMutation), &Lint(function_item_references::FunctionItemReferences), + // If this is an async closure's output coroutine, generate + // by-move and by-mut bodies if needed. We do this first so + // they can be optimized in lockstep with their parent bodies. + &coroutine::ByMoveBody, // What we need to do constant evaluation. 
&simplify::SimplifyCfg::Initial,
&rustc_peek::SanityCheck, // Just a lint
@@ -436,7 +435,7 @@ fn mir_drops_elaborated_and_const_checked(tcx: TyCtxt<'_>, def: LocalDefId) -> &
//
// We manually filter the predicates, skipping anything that's not
// "global". We are in a potentially generic context
- // (e.g. we are evaluating a function without substituting generic
+ // (e.g. we are evaluating a function without instantiating generic
// parameters), so this filtering serves two purposes:
//
// 1. We skip evaluating any predicates that we would
@@ -534,7 +533,7 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&elaborate_box_derefs::ElaborateBoxDerefs,
&coroutine::StateTransform,
&add_retag::AddRetag,
- &Lint(const_prop_lint::ConstPropLint),
+ &Lint(known_panics_lint::KnownPanicsLint),
];
pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial)));
}
@@ -576,10 +575,10 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&inline::Inline,
// Code from other crates may have storage markers, so this needs to happen after inlining.
&remove_storage_markers::RemoveStorageMarkers,
- // Inlining and substitution may introduce ZST and useless drops.
+ // Inlining and instantiation may introduce ZSTs and useless drops.
&remove_zsts::RemoveZsts,
&remove_unneeded_drops::RemoveUnneededDrops,
- // Type substitution may create uninhabited enums.
+ // Type instantiation may create uninhabited enums.
&uninhabited_enum_branching::UninhabitedEnumBranching,
&unreachable_prop::UnreachablePropagation,
&o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching),
@@ -588,7 +587,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// Has to run after `slice::len` lowering
&normalize_array_len::NormalizeArrayLen,
- &const_goto::ConstGoto,
&ref_prop::ReferencePropagation,
&sroa::ScalarReplacementOfAggregates,
&match_branches::MatchBranchSimplification,
@@ -599,10 +597,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&dead_store_elimination::DeadStoreElimination::Initial,
&gvn::GVN,
&simplify::SimplifyLocals::AfterGVN,
- // Perform `SeparateConstSwitch` after SSA-based analyses, as cloning blocks may
- // destroy the SSA property. It should still happen before const-propagation, so the
- // latter pass will leverage the created opportunities.
- &separate_const_switch::SeparateConstSwitch,
&dataflow_const_prop::DataflowConstProp,
&const_debuginfo::ConstDebugInfo,
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
@@ -657,7 +651,6 @@ fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> {
debug!("about to call mir_drops_elaborated...");
let body = tcx.mir_drops_elaborated_and_const_checked(did).steal();
let mut body = remap_mir_for_const_eval_select(tcx, body, hir::Constness::NotConst);
- debug!("body: {:#?}", body);

if body.tainted_by_errors.is_some() {
return body;
@@ -678,7 +671,7 @@ fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> {
}

/// Fetch all the promoteds of an item and prepare their MIR bodies to be ready for
-/// constant evaluation once all substitutions become known.
+/// constant evaluation once all generic parameters become known.
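For context, a minimal example of a promoted (illustrative only, not taken from this patch): the borrow of a constant expression below is split out into its own MIR body and evaluated at compile time.

```rust,ignore (illustrative example)
fn f() -> &'static i32 {
    // `42` is "promoted": it is lifted into a separate MIR body,
    // const-evaluated, and the returned reference points into the
    // resulting static memory, which is why `'static` is allowed here.
    &42
}
```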
fn promoted_mir(tcx: TyCtxt<'_>, def: LocalDefId) -> &IndexVec<Promoted, Body<'_>> { if tcx.is_constructor(def.to_def_id()) { return tcx.arena.alloc(IndexVec::new()); diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs index 897375e0e16..a0af902c4e1 100644 --- a/compiler/rustc_mir_transform/src/lower_intrinsics.rs +++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs @@ -14,13 +14,23 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { if let TerminatorKind::Call { func, args, destination, target, .. } = &mut terminator.kind && let ty::FnDef(def_id, generic_args) = *func.ty(local_decls, tcx).kind() - && tcx.is_intrinsic(def_id) + && let Some(intrinsic_name) = tcx.intrinsic(def_id) { - let intrinsic_name = tcx.item_name(def_id); match intrinsic_name { sym::unreachable => { terminator.kind = TerminatorKind::Unreachable; } + sym::debug_assertions => { + let target = target.unwrap(); + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::NullaryOp(NullOp::DebugAssertions, tcx.types.bool), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } sym::forget => { if let Some(target) = *target { block.statements.push(Statement { diff --git a/compiler/rustc_mir_transform/src/nrvo.rs b/compiler/rustc_mir_transform/src/nrvo.rs index ff309bd10ec..c3a92911bbf 100644 --- a/compiler/rustc_mir_transform/src/nrvo.rs +++ b/compiler/rustc_mir_transform/src/nrvo.rs @@ -1,7 +1,7 @@ //! See the docs for [`RenameReturnPlace`]. use rustc_hir::Mutability; -use rustc_index::bit_set::HybridBitSet; +use rustc_index::bit_set::BitSet; use rustc_middle::mir::visit::{MutVisitor, NonUseContext, PlaceContext, Visitor}; use rustc_middle::mir::{self, BasicBlock, Local, Location}; use rustc_middle::ty::TyCtxt; @@ -123,7 +123,7 @@ fn find_local_assigned_to_return_place( body: &mut mir::Body<'_>, ) -> Option<Local> { let mut block = start; - let mut seen = HybridBitSet::new_empty(body.basic_blocks.len()); + let mut seen = BitSet::new_empty(body.basic_blocks.len()); // Iterate as long as `block` has exactly one predecessor that we have not yet visited. while seen.insert(block) { diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs index f4c572aec12..77478cc741d 100644 --- a/compiler/rustc_mir_transform/src/pass_manager.rs +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -7,8 +7,12 @@ use crate::{lint::lint_body, validate, MirPass}; /// Just like `MirPass`, except it cannot mutate `Body`. pub trait MirLint<'tcx> { fn name(&self) -> &'static str { - let name = std::any::type_name::<Self>(); - if let Some((_, tail)) = name.rsplit_once(':') { tail } else { name } + // FIXME Simplify the implementation once more `str` methods get const-stable. 
+ // See copypaste in `MirPass` + const { + let name = std::any::type_name::<Self>(); + rustc_middle::util::common::c_name(name) + } } fn is_enabled(&self, _sess: &Session) -> bool { @@ -177,6 +181,15 @@ fn run_passes_inner<'tcx>( body.pass_count = 1; } + + if let Some(coroutine) = body.coroutine.as_mut() { + if let Some(by_move_body) = coroutine.by_move_body.as_mut() { + run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each); + } + if let Some(by_mut_body) = coroutine.by_mut_body.as_mut() { + run_passes_inner(tcx, by_mut_body, passes, phase_change, validate_each); + } + } } pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: String) { diff --git a/compiler/rustc_mir_transform/src/promote_consts.rs b/compiler/rustc_mir_transform/src/promote_consts.rs index c00093ea27e..577b8f2080f 100644 --- a/compiler/rustc_mir_transform/src/promote_consts.rs +++ b/compiler/rustc_mir_transform/src/promote_consts.rs @@ -446,6 +446,7 @@ impl<'tcx> Validator<'_, 'tcx> { NullOp::SizeOf => {} NullOp::AlignOf => {} NullOp::OffsetOf(_) => {} + NullOp::DebugAssertions => {} }, Rvalue::ShallowInitBox(_, _) => return Err(Unpromotable), diff --git a/compiler/rustc_mir_transform/src/ref_prop.rs b/compiler/rustc_mir_transform/src/ref_prop.rs index 05a3ac3cc75..d5642be5513 100644 --- a/compiler/rustc_mir_transform/src/ref_prop.rs +++ b/compiler/rustc_mir_transform/src/ref_prop.rs @@ -44,7 +44,7 @@ use crate::ssa::{SsaLocals, StorageLiveLocals}; /// /// # Liveness /// -/// When performing a substitution, we must take care not to introduce uses of dangling locals. +/// When performing an instantiation, we must take care not to introduce uses of dangling locals. /// To ensure this, we walk the body with the `MaybeStorageDead` dataflow analysis: /// - if we want to replace `*x` by reborrow `*y` and `y` may be dead, we allow replacement and /// mark storage statements on `y` for removal; @@ -55,7 +55,7 @@ use crate::ssa::{SsaLocals, StorageLiveLocals}; /// /// For `&mut` borrows, we also need to preserve the uniqueness property: /// we must avoid creating a state where we interleave uses of `*_1` and `_2`. -/// To do it, we only perform full substitution of mutable borrows: +/// To do it, we only perform full instantiation of mutable borrows: /// we replace either all or none of the occurrences of `*_1`. /// /// Some care has to be taken when `_1` is copied in other locals. @@ -63,10 +63,10 @@ use crate::ssa::{SsaLocals, StorageLiveLocals}; /// _3 = *_1; /// _4 = _1 /// _5 = *_4 -/// In such cases, fully substituting `_1` means fully substituting all of the copies. +/// In such cases, fully instantiating `_1` means fully instantiating all of the copies. /// /// For immutable borrows, we do not need to preserve such uniqueness property, -/// so we perform all the possible substitutions without removing the `_1 = &_2` statement. +/// so we perform all the possible instantiations without removing the `_1 = &_2` statement. 
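A short MIR-style sketch of that all-or-nothing rule for mutable borrows (a hypothetical basic block, not taken from the patch): rewriting only some of the uses of `*_1` would interleave accesses through the borrow with direct accesses to `_2`, violating the uniqueness of `&mut`.

```rust,ignore (illustrative example)
_1 = &mut _2;
_3 = *_1;    // candidate rewrite: `_3 = _2`
*_1 = _4;    // candidate rewrite: `_2 = _4`
_5 = *_1;    // candidate rewrite: `_5 = _2`
// Either every use of `*_1` is rewritten to use `_2` directly
// (after which `_1 = &mut _2` can go away), or none of them is.
```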
pub struct ReferencePropagation; impl<'tcx> MirPass<'tcx> for ReferencePropagation { diff --git a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs index 095119e2e3f..fb52bfa468a 100644 --- a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs +++ b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs @@ -15,7 +15,8 @@ impl<'tcx> MirPass<'tcx> for RemoveNoopLandingPads { } fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - debug!("remove_noop_landing_pads({:?})", body); + let def_id = body.source.def_id(); + debug!(?def_id); self.remove_nop_landing_pads(body) } } @@ -81,8 +82,6 @@ impl RemoveNoopLandingPads { } fn remove_nop_landing_pads(&self, body: &mut Body<'_>) { - debug!("body: {:#?}", body); - // Skip the pass if there are no blocks with a resume terminator. let has_resume = body .basic_blocks diff --git a/compiler/rustc_mir_transform/src/remove_storage_markers.rs b/compiler/rustc_mir_transform/src/remove_storage_markers.rs index 795f5232ee3..f68e592db15 100644 --- a/compiler/rustc_mir_transform/src/remove_storage_markers.rs +++ b/compiler/rustc_mir_transform/src/remove_storage_markers.rs @@ -7,14 +7,10 @@ pub struct RemoveStorageMarkers; impl<'tcx> MirPass<'tcx> for RemoveStorageMarkers { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() > 0 + sess.mir_opt_level() > 0 && !sess.emit_lifetime_markers() } - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - if tcx.sess.emit_lifetime_markers() { - return; - } - + fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { trace!("Running RemoveStorageMarkers on {:?}", body.source); for data in body.basic_blocks.as_mut_preserves_cfg() { data.statements.retain(|statement| match statement.kind { diff --git a/compiler/rustc_mir_transform/src/remove_zsts.rs b/compiler/rustc_mir_transform/src/remove_zsts.rs index 34d57a45301..9a94cae3382 100644 --- a/compiler/rustc_mir_transform/src/remove_zsts.rs +++ b/compiler/rustc_mir_transform/src/remove_zsts.rs @@ -46,6 +46,7 @@ fn maybe_zst(ty: Ty<'_>) -> bool { ty::Adt(..) | ty::Array(..) | ty::Closure(..) + | ty::CoroutineClosure(..) | ty::Tuple(..) | ty::Alias(ty::Opaque, ..) => true, // definitely ZST diff --git a/compiler/rustc_mir_transform/src/separate_const_switch.rs b/compiler/rustc_mir_transform/src/separate_const_switch.rs deleted file mode 100644 index 7120ef72142..00000000000 --- a/compiler/rustc_mir_transform/src/separate_const_switch.rs +++ /dev/null @@ -1,343 +0,0 @@ -//! A pass that duplicates switch-terminated blocks -//! into a new copy for each predecessor, provided -//! the predecessor sets the value being switched -//! over to a constant. -//! -//! The purpose of this pass is to help constant -//! propagation passes to simplify the switch terminator -//! of the copied blocks into gotos when some predecessors -//! statically determine the output of switches. -//! -//! ```text -//! x = 12 --- ---> something -//! \ / 12 -//! --> switch x -//! / \ otherwise -//! x = y --- ---> something else -//! ``` -//! becomes -//! ```text -//! x = 12 ---> switch x ------> something -//! \ / 12 -//! X -//! / \ otherwise -//! x = y ---> switch x ------> something else -//! ``` -//! so it can hopefully later be turned by another pass into -//! ```text -//! x = 12 --------------------> something -//! / 12 -//! / -//! / otherwise -//! x = y ---- switch x ------> something else -//! ``` -//! -//! 
This optimization is meant to cover simple cases -//! like `?` desugaring. For now, it thus focuses on -//! simplicity rather than completeness (it notably -//! sometimes duplicates abusively). - -use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; -use smallvec::SmallVec; - -pub struct SeparateConstSwitch; - -impl<'tcx> MirPass<'tcx> for SeparateConstSwitch { - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - // This pass participates in some as-of-yet untested unsoundness found - // in https://github.com/rust-lang/rust/issues/112460 - sess.mir_opt_level() >= 2 && sess.opts.unstable_opts.unsound_mir_opts - } - - fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // If execution did something, applying a simplification layer - // helps later passes optimize the copy away. - if separate_const_switch(body) > 0 { - super::simplify::simplify_cfg(body); - } - } -} - -/// Returns the amount of blocks that were duplicated -pub fn separate_const_switch(body: &mut Body<'_>) -> usize { - let mut new_blocks: SmallVec<[(BasicBlock, BasicBlock); 6]> = SmallVec::new(); - let predecessors = body.basic_blocks.predecessors(); - 'block_iter: for (block_id, block) in body.basic_blocks.iter_enumerated() { - if let TerminatorKind::SwitchInt { - discr: Operand::Copy(switch_place) | Operand::Move(switch_place), - .. - } = block.terminator().kind - { - // If the block is on an unwind path, do not - // apply the optimization as unwind paths - // rely on a unique parent invariant - if block.is_cleanup { - continue 'block_iter; - } - - // If the block has fewer than 2 predecessors, ignore it - // we could maybe chain blocks that have exactly one - // predecessor, but for now we ignore that - if predecessors[block_id].len() < 2 { - continue 'block_iter; - } - - // First, let's find a non-const place - // that determines the result of the switch - if let Some(switch_place) = find_determining_place(switch_place, block) { - // We now have an input place for which it would - // be interesting if predecessors assigned it from a const - - let mut predecessors_left = predecessors[block_id].len(); - 'predec_iter: for predecessor_id in predecessors[block_id].iter().copied() { - let predecessor = &body.basic_blocks[predecessor_id]; - - // First we make sure the predecessor jumps - // in a reasonable way - match &predecessor.terminator().kind { - // The following terminators are - // unconditionally valid - TerminatorKind::Goto { .. } | TerminatorKind::SwitchInt { .. } => {} - - TerminatorKind::FalseEdge { real_target, .. } => { - if *real_target != block_id { - continue 'predec_iter; - } - } - - // The following terminators are not allowed - TerminatorKind::UnwindResume - | TerminatorKind::Drop { .. } - | TerminatorKind::Call { .. } - | TerminatorKind::Assert { .. } - | TerminatorKind::FalseUnwind { .. } - | TerminatorKind::Yield { .. } - | TerminatorKind::UnwindTerminate(_) - | TerminatorKind::Return - | TerminatorKind::Unreachable - | TerminatorKind::InlineAsm { .. 
} - | TerminatorKind::CoroutineDrop => { - continue 'predec_iter; - } - } - - if is_likely_const(switch_place, predecessor) { - new_blocks.push((predecessor_id, block_id)); - predecessors_left -= 1; - if predecessors_left < 2 { - // If the original block only has one predecessor left, - // we have nothing left to do - break 'predec_iter; - } - } - } - } - } - } - - // Once the analysis is done, perform the duplication - let body_span = body.span; - let copied_blocks = new_blocks.len(); - let blocks = body.basic_blocks_mut(); - for (pred_id, target_id) in new_blocks { - let new_block = blocks[target_id].clone(); - let new_block_id = blocks.push(new_block); - let terminator = blocks[pred_id].terminator_mut(); - - match terminator.kind { - TerminatorKind::Goto { ref mut target } => { - *target = new_block_id; - } - - TerminatorKind::FalseEdge { ref mut real_target, .. } => { - if *real_target == target_id { - *real_target = new_block_id; - } - } - - TerminatorKind::SwitchInt { ref mut targets, .. } => { - targets.all_targets_mut().iter_mut().for_each(|x| { - if *x == target_id { - *x = new_block_id; - } - }); - } - - TerminatorKind::UnwindResume - | TerminatorKind::UnwindTerminate(_) - | TerminatorKind::Return - | TerminatorKind::Unreachable - | TerminatorKind::CoroutineDrop - | TerminatorKind::Assert { .. } - | TerminatorKind::FalseUnwind { .. } - | TerminatorKind::Drop { .. } - | TerminatorKind::Call { .. } - | TerminatorKind::InlineAsm { .. } - | TerminatorKind::Yield { .. } => { - span_bug!( - body_span, - "basic block terminator had unexpected kind {:?}", - &terminator.kind - ) - } - } - } - - copied_blocks -} - -/// This function describes a rough heuristic guessing -/// whether a place is last set with a const within the block. -/// Notably, it will be overly pessimistic in cases that are already -/// not handled by `separate_const_switch`. -fn is_likely_const<'tcx>(mut tracked_place: Place<'tcx>, block: &BasicBlockData<'tcx>) -> bool { - for statement in block.statements.iter().rev() { - match &statement.kind { - StatementKind::Assign(assign) => { - if assign.0 == tracked_place { - match assign.1 { - // These rvalues are definitely constant - Rvalue::Use(Operand::Constant(_)) - | Rvalue::Ref(_, _, _) - | Rvalue::AddressOf(_, _) - | Rvalue::Cast(_, Operand::Constant(_), _) - | Rvalue::NullaryOp(_, _) - | Rvalue::ShallowInitBox(_, _) - | Rvalue::UnaryOp(_, Operand::Constant(_)) => return true, - - // These rvalues make things ambiguous - Rvalue::Repeat(_, _) - | Rvalue::ThreadLocalRef(_) - | Rvalue::Len(_) - | Rvalue::BinaryOp(_, _) - | Rvalue::CheckedBinaryOp(_, _) - | Rvalue::Aggregate(_, _) => return false, - - // These rvalues move the place to track - Rvalue::Cast(_, Operand::Copy(place) | Operand::Move(place), _) - | Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) - | Rvalue::CopyForDeref(place) - | Rvalue::UnaryOp(_, Operand::Copy(place) | Operand::Move(place)) - | Rvalue::Discriminant(place) => tracked_place = place, - } - } - } - - // If the discriminant is set, it is always set - // as a constant, so the job is done. - // As we are **ignoring projections**, if the place - // we are tracking sees its discriminant be set, - // that means we had to be tracking the discriminant - // specifically (as it is impossible to switch over - // an enum directly, and if we were switching over - // its content, we would have had to at least cast it to - // some variant first) - StatementKind::SetDiscriminant { place, .. 
} => { - if **place == tracked_place { - return true; - } - } - - // These statements have no influence on the place - // we are interested in - StatementKind::FakeRead(_) - | StatementKind::Deinit(_) - | StatementKind::StorageLive(_) - | StatementKind::Retag(_, _) - | StatementKind::AscribeUserType(_, _) - | StatementKind::PlaceMention(..) - | StatementKind::Coverage(_) - | StatementKind::StorageDead(_) - | StatementKind::Intrinsic(_) - | StatementKind::ConstEvalCounter - | StatementKind::Nop => {} - } - } - - // If no good reason for the place to be const is found, - // give up. We could maybe go up predecessors, but in - // most cases giving up now should be sufficient. - false -} - -/// Finds a unique place that entirely determines the value -/// of `switch_place`, if it exists. This is only a heuristic. -/// Ideally we would like to track multiple determining places -/// for some edge cases, but one is enough for a lot of situations. -fn find_determining_place<'tcx>( - mut switch_place: Place<'tcx>, - block: &BasicBlockData<'tcx>, -) -> Option<Place<'tcx>> { - for statement in block.statements.iter().rev() { - match &statement.kind { - StatementKind::Assign(op) => { - if op.0 != switch_place { - continue; - } - - match op.1 { - // The following rvalues move the place - // that may be const in the predecessor - Rvalue::Use(Operand::Move(new) | Operand::Copy(new)) - | Rvalue::UnaryOp(_, Operand::Copy(new) | Operand::Move(new)) - | Rvalue::CopyForDeref(new) - | Rvalue::Cast(_, Operand::Move(new) | Operand::Copy(new), _) - | Rvalue::Repeat(Operand::Move(new) | Operand::Copy(new), _) - | Rvalue::Discriminant(new) - => switch_place = new, - - // The following rvalues might still make the block - // be valid but for now we reject them - Rvalue::Len(_) - | Rvalue::Ref(_, _, _) - | Rvalue::BinaryOp(_, _) - | Rvalue::CheckedBinaryOp(_, _) - | Rvalue::Aggregate(_, _) - - // The following rvalues definitely mean we cannot - // or should not apply this optimization - | Rvalue::Use(Operand::Constant(_)) - | Rvalue::Repeat(Operand::Constant(_), _) - | Rvalue::ThreadLocalRef(_) - | Rvalue::AddressOf(_, _) - | Rvalue::NullaryOp(_, _) - | Rvalue::ShallowInitBox(_, _) - | Rvalue::UnaryOp(_, Operand::Constant(_)) - | Rvalue::Cast(_, Operand::Constant(_), _) => return None, - } - } - - // These statements have no influence on the place - // we are interested in - StatementKind::FakeRead(_) - | StatementKind::Deinit(_) - | StatementKind::StorageLive(_) - | StatementKind::StorageDead(_) - | StatementKind::Retag(_, _) - | StatementKind::AscribeUserType(_, _) - | StatementKind::PlaceMention(..) - | StatementKind::Coverage(_) - | StatementKind::Intrinsic(_) - | StatementKind::ConstEvalCounter - | StatementKind::Nop => {} - - // If the discriminant is set, it is always set - // as a constant, so the job is already done. - // As we are **ignoring projections**, if the place - // we are tracking sees its discriminant be set, - // that means we had to be tracking the discriminant - // specifically (as it is impossible to switch over - // an enum directly, and if we were switching over - // its content, we would have had to at least cast it to - // some variant first) - StatementKind::SetDiscriminant { place, .. 
} => { - if **place == switch_place { - return None; - } - } - } - } - - Some(switch_place) -} diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 89414ce940e..75613a2c555 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -3,8 +3,8 @@ use rustc_hir::def_id::DefId; use rustc_hir::lang_items::LangItem; use rustc_middle::mir::*; use rustc_middle::query::Providers; -use rustc_middle::ty::GenericArgs; use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt}; +use rustc_middle::ty::{GenericArgs, CAPTURE_STRUCT_LOCAL}; use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT}; use rustc_index::{Idx, IndexVec}; @@ -66,11 +66,76 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut)) } + ty::InstanceDef::ConstructCoroutineInClosureShim { + coroutine_closure_def_id, + target_kind, + } => match target_kind { + ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"), + ty::ClosureKind::FnMut => { + // No need to optimize the body, it has already been optimized + // since we steal it from the `AsyncFn::call` body and just fix + // the return type. + return build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id); + } + ty::ClosureKind::FnOnce => { + build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id) + } + }, + + ty::InstanceDef::CoroutineKindShim { coroutine_def_id, target_kind } => match target_kind { + ty::ClosureKind::Fn => unreachable!(), + ty::ClosureKind::FnMut => { + return tcx + .optimized_mir(coroutine_def_id) + .coroutine_by_mut_body() + .unwrap() + .clone(); + } + ty::ClosureKind::FnOnce => { + return tcx + .optimized_mir(coroutine_def_id) + .coroutine_by_move_body() + .unwrap() + .clone(); + } + }, + ty::InstanceDef::DropGlue(def_id, ty) => { // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end // of this function. Is this intentional? if let Some(ty::Coroutine(coroutine_def_id, args)) = ty.map(Ty::kind) { - let body = tcx.optimized_mir(*coroutine_def_id).coroutine_drop().unwrap(); + let coroutine_body = tcx.optimized_mir(*coroutine_def_id); + + let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind() + else { + bug!() + }; + + // If this is a regular coroutine, grab its drop shim. If this is a coroutine + // that comes from a coroutine-closure, and the kind ty differs from the "maximum" + // kind that it supports, then grab the appropriate drop shim. This ensures that + // the future returned by `<[coroutine-closure] as AsyncFnOnce>::call_once` will + // drop the coroutine-closure's upvars. 
+ let body = if id_args.as_coroutine().kind_ty() == args.as_coroutine().kind_ty() { + coroutine_body.coroutine_drop().unwrap() + } else { + match args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() { + ty::ClosureKind::Fn => { + unreachable!() + } + ty::ClosureKind::FnMut => coroutine_body + .coroutine_by_mut_body() + .unwrap() + .coroutine_drop() + .unwrap(), + ty::ClosureKind::FnOnce => coroutine_body + .coroutine_by_move_body() + .unwrap() + .coroutine_drop() + .unwrap(), + } + }; + let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args); debug!("make_shim({:?}) = {:?}", instance, body); @@ -382,16 +447,13 @@ fn build_thread_local_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'t fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Body<'tcx> { debug!("build_clone_shim(def_id={:?})", def_id); - let param_env = tcx.param_env_reveal_all_normalized(def_id); - let mut builder = CloneShimBuilder::new(tcx, def_id, self_ty); - let is_copy = self_ty.is_copy_modulo_regions(tcx, param_env); let dest = Place::return_place(); let src = tcx.mk_place_deref(Place::from(Local::new(1 + 0))); match self_ty.kind() { - _ if is_copy => builder.copy_shim(), + ty::FnDef(..) | ty::FnPtr(_) => builder.copy_shim(), ty::Closure(_, args) => builder.tuple_like_shim(dest, src, args.as_closure().upvar_tys()), ty::Tuple(..) => builder.tuple_like_shim(dest, src, self_ty.tuple_fields()), ty::Coroutine(coroutine_def_id, args) => { @@ -415,7 +477,7 @@ struct CloneShimBuilder<'tcx> { impl<'tcx> CloneShimBuilder<'tcx> { fn new(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Self { - // we must subst the self_ty because it's + // we must instantiate the self_ty because it's // otherwise going to be TySelf and we can't index // or access fields of a Place of type TySelf. let sig = tcx.fn_sig(def_id).instantiate(tcx, &[self_ty.into()]); @@ -654,8 +716,8 @@ fn build_call_shim<'tcx>( call_kind: CallKind<'tcx>, ) -> Body<'tcx> { // `FnPtrShim` contains the fn pointer type that a call shim is being built for - this is used - // to substitute into the signature of the shim. It is not necessary for users of this - // MIR body to perform further substitutions (see `InstanceDef::has_polymorphic_mir_body`). + // to instantiate into the signature of the shim. It is not necessary for users of this + // MIR body to perform further instantiations (see `InstanceDef::has_polymorphic_mir_body`). 
let (sig_args, untuple_args) = if let ty::InstanceDef::FnPtrShim(_, ty) = instance { let sig = tcx.instantiate_bound_regions_with_erased(ty.fn_sig(tcx)); @@ -981,3 +1043,114 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t let source = MirSource::from_instance(ty::InstanceDef::FnPtrAddrShim(def_id, self_ty)); new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span) } + +fn build_construct_coroutine_by_move_shim<'tcx>( + tcx: TyCtxt<'tcx>, + coroutine_closure_def_id: DefId, +) -> Body<'tcx> { + let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); + let ty::CoroutineClosure(_, args) = *self_ty.kind() else { + bug!(); + }; + + let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { + tcx.mk_fn_sig( + [self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()), + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.as_coroutine_closure().parent_args(), + tcx.coroutine_for_closure(coroutine_closure_def_id), + ty::ClosureKind::FnOnce, + tcx.lifetimes.re_erased, + args.as_coroutine_closure().tupled_upvars_ty(), + args.as_coroutine_closure().coroutine_captures_by_ref_ty(), + ), + sig.c_variadic, + sig.unsafety, + sig.abi, + ) + }); + let sig = tcx.liberate_late_bound_regions(coroutine_closure_def_id, poly_sig); + let ty::Coroutine(coroutine_def_id, coroutine_args) = *sig.output().kind() else { + bug!(); + }; + + let span = tcx.def_span(coroutine_closure_def_id); + let locals = local_decls_for_sig(&sig, span); + + let mut fields = vec![]; + for idx in 1..sig.inputs().len() { + fields.push(Operand::Move(Local::from_usize(idx + 1).into())); + } + for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() { + fields.push(Operand::Move(tcx.mk_place_field( + Local::from_usize(1).into(), + FieldIdx::from_usize(idx), + ty, + ))); + } + + let source_info = SourceInfo::outermost(span); + let rvalue = Rvalue::Aggregate( + Box::new(AggregateKind::Coroutine(coroutine_def_id, coroutine_args)), + IndexVec::from_raw(fields), + ); + let stmt = Statement { + source_info, + kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))), + }; + let statements = vec![stmt]; + let start_block = BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }; + + let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim { + coroutine_closure_def_id, + target_kind: ty::ClosureKind::FnOnce, + }); + + let body = + new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span); + dump_mir(tcx, false, "coroutine_closure_by_move", &0, &body, |_, _| Ok(())); + + body +} + +fn build_construct_coroutine_by_mut_shim<'tcx>( + tcx: TyCtxt<'tcx>, + coroutine_closure_def_id: DefId, +) -> Body<'tcx> { + let mut body = tcx.optimized_mir(coroutine_closure_def_id).clone(); + let coroutine_closure_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); + let ty::CoroutineClosure(_, args) = *coroutine_closure_ty.kind() else { + bug!(); + }; + let args = args.as_coroutine_closure(); + + body.local_decls[RETURN_PLACE].ty = + tcx.instantiate_bound_regions_with_erased(args.coroutine_closure_sig().map_bound(|sig| { + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(coroutine_closure_def_id), + ty::ClosureKind::FnMut, + tcx.lifetimes.re_erased, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + ) + 
}));
+ body.local_decls[CAPTURE_STRUCT_LOCAL].ty =
+ Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_closure_ty);
+
+ body.source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
+ coroutine_closure_def_id,
+ target_kind: ty::ClosureKind::FnMut,
+ });
+
+ body.pass_count = 0;
+ dump_mir(tcx, false, "coroutine_closure_by_mut", &0, &body, |_, _| Ok(()));
+
+ body
+}
diff --git a/compiler/rustc_mir_transform/src/ssa.rs b/compiler/rustc_mir_transform/src/ssa.rs
index 1ed3b14e755..e4fdbd6ae69 100644
--- a/compiler/rustc_mir_transform/src/ssa.rs
+++ b/compiler/rustc_mir_transform/src/ssa.rs
@@ -170,7 +170,7 @@ impl SsaLocals {
/// _c => _a
/// _d => _a // transitively through _c
///
- /// Exception: we do not see through the return place, as it cannot be substituted.
+ /// Exception: we do not see through the return place, as it cannot be instantiated.
pub fn copy_classes(&self) -> &IndexSlice<Local, Local> {
&self.copy_classes
}
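To make the `copy_classes` mapping above concrete, here is a small self-contained sketch of the idea, with plain `usize` indices standing in for `Local` (an illustration of the concept, not the compiler's implementation):

```rust
/// Resolve each local to the representative of its copy class,
/// compressing chains like `_d -> _c -> _a` down to `_d -> _a`.
fn resolve_copy_classes(copy_classes: &mut [usize]) {
    for i in 0..copy_classes.len() {
        let mut head = copy_classes[i];
        // Walk the chain of copies until we hit a self-mapped local.
        while copy_classes[head] != head {
            head = copy_classes[head];
        }
        copy_classes[i] = head;
    }
}

fn main() {
    // 0 = `_a` (its own head), 1 = `_c` copies `_a`, 2 = `_d` copies `_c`.
    let mut classes = vec![0, 0, 1];
    resolve_copy_classes(&mut classes);
    // `_c => _a` and, transitively, `_d => _a`.
    assert_eq!(classes, vec![0, 0, 0]);
}
```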
