Diffstat (limited to 'compiler/rustc_mir_transform/src')
30 files changed, 1406 insertions, 1351 deletions
diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs index dfc7a9891f9..ba70a4453d6 100644 --- a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs +++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs @@ -39,7 +39,9 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls { let body_abi = match body_ty.kind() { ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), ty::Closure(..) => Abi::RustCall, + ty::CoroutineClosure(..) => Abi::RustCall, ty::Coroutine(..) => Abi::Rust, + ty::Error(_) => return, _ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty), }; let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi); diff --git a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs index 3195cd3622d..1f615c9d8d1 100644 --- a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs +++ b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs @@ -100,7 +100,7 @@ impl<'tcx> Visitor<'tcx> for ConstMutationChecker<'_, 'tcx> { && let Some((lint_root, span, item)) = self.should_lint_const_item_usage(lhs, def_id, loc) { - self.tcx.emit_spanned_lint( + self.tcx.emit_node_span_lint( CONST_ITEM_MUTATION, lint_root, span, @@ -145,7 +145,7 @@ impl<'tcx> Visitor<'tcx> for ConstMutationChecker<'_, 'tcx> { if let Some((lint_root, span, item)) = self.should_lint_const_item_usage(place, def_id, lint_loc) { - self.tcx.emit_spanned_lint( + self.tcx.emit_node_span_lint( CONST_ITEM_MUTATION, lint_root, span, diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs index 582c2c0c6b6..fbb62695383 100644 --- a/compiler/rustc_mir_transform/src/check_unsafety.rs +++ b/compiler/rustc_mir_transform/src/check_unsafety.rs @@ -128,7 +128,9 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { ), } } - &AggregateKind::Closure(def_id, _) | &AggregateKind::Coroutine(def_id, _) => { + &AggregateKind::Closure(def_id, _) + | &AggregateKind::CoroutineClosure(def_id, _) + | &AggregateKind::Coroutine(def_id, _) => { let def_id = def_id.expect_local(); let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } = self.tcx.mir_unsafety_check_result(def_id); @@ -527,7 +529,7 @@ fn report_unused_unsafe(tcx: TyCtxt<'_>, kind: UnusedUnsafe, id: HirId) { } else { None }; - tcx.emit_spanned_lint(UNUSED_UNSAFE, id, span, errors::UnusedUnsafe { span, nested_parent }); + tcx.emit_node_span_lint(UNUSED_UNSAFE, id, span, errors::UnusedUnsafe { span, nested_parent }); } pub fn check_unsafety(tcx: TyCtxt<'_>, def_id: LocalDefId) { @@ -577,7 +579,7 @@ pub fn check_unsafety(tcx: TyCtxt<'_>, def_id: LocalDefId) { }); } UnsafetyViolationKind::UnsafeFn => { - tcx.emit_spanned_lint( + tcx.emit_node_span_lint( UNSAFE_OP_IN_UNSAFE_FN, lint_root, source_info.span, diff --git a/compiler/rustc_mir_transform/src/const_goto.rs b/compiler/rustc_mir_transform/src/const_goto.rs deleted file mode 100644 index cb5b66b314d..00000000000 --- a/compiler/rustc_mir_transform/src/const_goto.rs +++ /dev/null @@ -1,128 +0,0 @@ -//! This pass optimizes the following sequence -//! ```rust,ignore (example) -//! bb2: { -//! _2 = const true; -//! goto -> bb3; -//! } -//! -//! bb3: { -//! switchInt(_2) -> [false: bb4, otherwise: bb5]; -//! } -//! ``` -//! into -//! ```rust,ignore (example) -//! bb2: { -//! _2 = const true; -//! goto -> bb5; -//! } -//! 
``` - -use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; -use rustc_middle::{mir::visit::Visitor, ty::ParamEnv}; - -use super::simplify::{simplify_cfg, simplify_locals}; - -pub struct ConstGoto; - -impl<'tcx> MirPass<'tcx> for ConstGoto { - fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - // This pass participates in some as-of-yet untested unsoundness found - // in https://github.com/rust-lang/rust/issues/112460 - sess.mir_opt_level() >= 2 && sess.opts.unstable_opts.unsound_mir_opts - } - - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - trace!("Running ConstGoto on {:?}", body.source); - let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); - let mut opt_finder = - ConstGotoOptimizationFinder { tcx, body, optimizations: vec![], param_env }; - opt_finder.visit_body(body); - let should_simplify = !opt_finder.optimizations.is_empty(); - for opt in opt_finder.optimizations { - let block = &mut body.basic_blocks_mut()[opt.bb_with_goto]; - block.statements.extend(opt.stmts_move_up); - let terminator = block.terminator_mut(); - let new_goto = TerminatorKind::Goto { target: opt.target_to_use_in_goto }; - debug!("SUCCESS: replacing `{:?}` with `{:?}`", terminator.kind, new_goto); - terminator.kind = new_goto; - } - - // if we applied optimizations, we potentially have some cfg to cleanup to - // make it easier for further passes - if should_simplify { - simplify_cfg(body); - simplify_locals(body, tcx); - } - } -} - -impl<'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'_, 'tcx> { - fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) { - if data.is_cleanup { - // Because of the restrictions around control flow in cleanup blocks, we don't perform - // this optimization at all in such blocks. - return; - } - self.super_basic_block_data(block, data); - } - - fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { - let _: Option<_> = try { - let target = terminator.kind.as_goto()?; - // We only apply this optimization if the last statement is a const assignment - let last_statement = self.body.basic_blocks[location.block].statements.last()?; - - if let (place, Rvalue::Use(Operand::Constant(_const))) = - last_statement.kind.as_assign()? - { - // We found a constant being assigned to `place`. - // Now check that the target of this Goto switches on this place. - let target_bb = &self.body.basic_blocks[target]; - - // The `StorageDead(..)` statement does not affect the functionality of mir. - // We can move this part of the statement up to the predecessor. - let mut stmts_move_up = Vec::new(); - for stmt in &target_bb.statements { - if let StatementKind::StorageDead(..) = stmt.kind { - stmts_move_up.push(stmt.clone()) - } else { - None?; - } - } - - let target_bb_terminator = target_bb.terminator(); - let (discr, targets) = target_bb_terminator.kind.as_switch()?; - if discr.place() == Some(*place) { - let switch_ty = place.ty(self.body.local_decls(), self.tcx).ty; - debug_assert_eq!(switch_ty, _const.ty()); - // We now know that the Switch matches on the const place, and it is statementless - // Now find which value in the Switch matches the const value. 
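For intuition about the pattern the (now removed) `ConstGoto` pass rewrote, here is a hypothetical source-level shape that can lower to a constant assignment followed by a `goto` into a `switchInt`. The exact MIR depends on the compiler version, so treat this as illustrative only:

```rust
// Illustrative source: `cond` is assigned a constant in one block and is
// immediately switched on in the next block, so the intermediate
// `switchInt` can be bypassed by jumping straight to the known target.
fn f(x: u32) -> u32 {
    let cond = match x {
        0 => true,  // `_2 = const true; goto -> bb3;`
        _ => false, // `_2 = const false; goto -> bb3;`
    };
    if cond { 1 } else { 2 } // `bb3: switchInt(_2) -> ...`
}

fn main() {
    assert_eq!(f(0), 1);
    assert_eq!(f(7), 2);
}
```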
- let const_value = _const.const_.try_eval_bits(self.tcx, self.param_env)?; - let target_to_use_in_goto = targets.target_for_value(const_value); - self.optimizations.push(OptimizationToApply { - bb_with_goto: location.block, - target_to_use_in_goto, - stmts_move_up, - }); - } - } - Some(()) - }; - - self.super_terminator(terminator, location); - } -} - -struct OptimizationToApply<'tcx> { - bb_with_goto: BasicBlock, - target_to_use_in_goto: BasicBlock, - stmts_move_up: Vec<Statement<'tcx>>, -} - -pub struct ConstGotoOptimizationFinder<'a, 'tcx> { - tcx: TyCtxt<'tcx>, - body: &'a Body<'tcx>, - param_env: ParamEnv<'tcx>, - optimizations: Vec<OptimizationToApply<'tcx>>, -} diff --git a/compiler/rustc_mir_transform/src/const_prop.rs b/compiler/rustc_mir_transform/src/const_prop.rs index c5824c30770..eba62aae60f 100644 --- a/compiler/rustc_mir_transform/src/const_prop.rs +++ b/compiler/rustc_mir_transform/src/const_prop.rs @@ -1,21 +1,12 @@ //! Propagates constants for early reporting of statically known //! assertion failures -use rustc_const_eval::interpret::{ - self, compile_time_machine, AllocId, ConstAllocation, FnArg, Frame, ImmTy, InterpCx, - InterpResult, OpTy, PlaceTy, Pointer, -}; -use rustc_data_structures::fx::FxHashSet; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; -use rustc_middle::query::TyCtxtAt; -use rustc_middle::ty::layout::TyAndLayout; -use rustc_middle::ty::{self, ParamEnv, TyCtxt}; -use rustc_span::def_id::DefId; +use rustc_middle::ty::{ParamEnv, TyCtxt}; use rustc_target::abi::Size; -use rustc_target::spec::abi::Abi as CallAbi; /// The maximum number of bytes that we'll allocate space for a local or the return value. /// Needed for #66397, because otherwise we eval into large places and that can cause OOM or just @@ -43,168 +34,12 @@ pub(crate) macro throw_machine_stop_str($($tt:tt)*) {{ fn add_args( self: Box<Self>, - _: &mut dyn FnMut(std::borrow::Cow<'static, str>, rustc_errors::DiagnosticArgValue<'static>), + _: &mut dyn FnMut(rustc_errors::DiagnosticArgName, rustc_errors::DiagnosticArgValue), ) {} } throw_machine_stop!(Zst) }} -pub(crate) struct ConstPropMachine<'mir, 'tcx> { - /// The virtual call stack. - stack: Vec<Frame<'mir, 'tcx>>, - pub written_only_inside_own_block_locals: FxHashSet<Local>, - pub can_const_prop: IndexVec<Local, ConstPropMode>, -} - -impl ConstPropMachine<'_, '_> { - pub fn new(can_const_prop: IndexVec<Local, ConstPropMode>) -> Self { - Self { - stack: Vec::new(), - written_only_inside_own_block_locals: Default::default(), - can_const_prop, - } - } -} - -impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> { - compile_time_machine!(<'mir, 'tcx>); - - const PANIC_ON_ALLOC_FAIL: bool = true; // all allocations are small (see `MAX_ALLOC_LIMIT`) - - const POST_MONO_CHECKS: bool = false; // this MIR is still generic! 
- - type MemoryKind = !; - - #[inline(always)] - fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool { - false // no reason to enforce alignment - } - - #[inline(always)] - fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool { - false // for now, we don't enforce validity - } - - fn load_mir( - _ecx: &InterpCx<'mir, 'tcx, Self>, - _instance: ty::InstanceDef<'tcx>, - ) -> InterpResult<'tcx, &'tcx Body<'tcx>> { - throw_machine_stop_str!("calling functions isn't supported in ConstProp") - } - - fn panic_nounwind(_ecx: &mut InterpCx<'mir, 'tcx, Self>, _msg: &str) -> InterpResult<'tcx> { - throw_machine_stop_str!("panicking isn't supported in ConstProp") - } - - fn find_mir_or_eval_fn( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _instance: ty::Instance<'tcx>, - _abi: CallAbi, - _args: &[FnArg<'tcx>], - _destination: &PlaceTy<'tcx>, - _target: Option<BasicBlock>, - _unwind: UnwindAction, - ) -> InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> { - Ok(None) - } - - fn call_intrinsic( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _instance: ty::Instance<'tcx>, - _args: &[OpTy<'tcx>], - _destination: &PlaceTy<'tcx>, - _target: Option<BasicBlock>, - _unwind: UnwindAction, - ) -> InterpResult<'tcx> { - throw_machine_stop_str!("calling intrinsics isn't supported in ConstProp") - } - - fn assert_panic( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - _msg: &rustc_middle::mir::AssertMessage<'tcx>, - _unwind: rustc_middle::mir::UnwindAction, - ) -> InterpResult<'tcx> { - bug!("panics terminators are not evaluated in ConstProp") - } - - fn binary_ptr_op( - _ecx: &InterpCx<'mir, 'tcx, Self>, - _bin_op: BinOp, - _left: &ImmTy<'tcx>, - _right: &ImmTy<'tcx>, - ) -> InterpResult<'tcx, (ImmTy<'tcx>, bool)> { - // We can't do this because aliasing of memory can differ between const eval and llvm - throw_machine_stop_str!("pointer arithmetic or comparisons aren't supported in ConstProp") - } - - fn before_access_local_mut<'a>( - ecx: &'a mut InterpCx<'mir, 'tcx, Self>, - frame: usize, - local: Local, - ) -> InterpResult<'tcx> { - assert_eq!(frame, 0); - match ecx.machine.can_const_prop[local] { - ConstPropMode::NoPropagation => { - throw_machine_stop_str!( - "tried to write to a local that is marked as not propagatable" - ) - } - ConstPropMode::OnlyInsideOwnBlock => { - ecx.machine.written_only_inside_own_block_locals.insert(local); - } - ConstPropMode::FullConstProp => {} - } - Ok(()) - } - - fn before_access_global( - _tcx: TyCtxtAt<'tcx>, - _machine: &Self, - _alloc_id: AllocId, - alloc: ConstAllocation<'tcx>, - _static_def_id: Option<DefId>, - is_write: bool, - ) -> InterpResult<'tcx> { - if is_write { - throw_machine_stop_str!("can't write to global"); - } - // If the static allocation is mutable, then we can't const prop it as its content - // might be different at runtime. 
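The restriction above is the classic soundness argument against folding reads of mutable globals: the value visible at compile time need not be the value observed at runtime. A minimal self-contained illustration in plain Rust (not rustc internals):

```rust
// Why a const-propagation pass must not fold reads of mutable globals:
// the same expression can observe different values at runtime.
static mut COUNTER: u32 = 0;

fn bump() -> u32 {
    // SAFETY: single-threaded demo; real code should use atomics.
    unsafe {
        COUNTER += 1;
        COUNTER // folding this read to a compile-time constant would be unsound
    }
}

fn main() {
    assert_eq!(bump(), 1);
    assert_eq!(bump(), 2); // same expression, different value
}
```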
- if alloc.inner().mutability.is_mut() { - throw_machine_stop_str!("can't access mutable globals in ConstProp"); - } - - Ok(()) - } - - #[inline(always)] - fn expose_ptr(_ecx: &mut InterpCx<'mir, 'tcx, Self>, _ptr: Pointer) -> InterpResult<'tcx> { - throw_machine_stop_str!("exposing pointers isn't supported in ConstProp") - } - - #[inline(always)] - fn init_frame_extra( - _ecx: &mut InterpCx<'mir, 'tcx, Self>, - frame: Frame<'mir, 'tcx>, - ) -> InterpResult<'tcx, Frame<'mir, 'tcx>> { - Ok(frame) - } - - #[inline(always)] - fn stack<'a>( - ecx: &'a InterpCx<'mir, 'tcx, Self>, - ) -> &'a [Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>] { - &ecx.machine.stack - } - - #[inline(always)] - fn stack_mut<'a>( - ecx: &'a mut InterpCx<'mir, 'tcx, Self>, - ) -> &'a mut Vec<Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>> { - &mut ecx.machine.stack - } -} - /// The mode that `ConstProp` is allowed to run in for a given `Local`. #[derive(Clone, Copy, Debug, PartialEq)] pub enum ConstPropMode { diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs index d0bbca08a40..f8e6905282c 100644 --- a/compiler/rustc_mir_transform/src/const_prop_lint.rs +++ b/compiler/rustc_mir_transform/src/const_prop_lint.rs @@ -3,37 +3,27 @@ use std::fmt::Debug; -use either::Left; - -use rustc_const_eval::interpret::Immediate; use rustc_const_eval::interpret::{ - InterpCx, InterpResult, MemoryKind, OpTy, Scalar, StackPopCleanup, + format_interp_error, ImmTy, InterpCx, InterpResult, Projectable, Scalar, }; -use rustc_const_eval::ReportErrorExt; +use rustc_data_structures::fx::FxHashSet; use rustc_hir::def::DefKind; use rustc_hir::HirId; use rustc_index::bit_set::BitSet; +use rustc_index::{Idx, IndexVec}; use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout}; -use rustc_middle::ty::GenericArgs; -use rustc_middle::ty::{ - self, ConstInt, Instance, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisitableExt, -}; +use rustc_middle::ty::{self, ConstInt, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisitableExt}; use rustc_span::Span; -use rustc_target::abi::{HasDataLayout, Size, TargetDataLayout}; +use rustc_target::abi::{Abi, FieldIdx, HasDataLayout, Size, TargetDataLayout, VariantIdx}; use crate::const_prop::CanConstProp; -use crate::const_prop::ConstPropMachine; use crate::const_prop::ConstPropMode; -use crate::errors::AssertLint; +use crate::dataflow_const_prop::DummyMachine; +use crate::errors::{AssertLint, AssertLintKind}; use crate::MirLint; -/// The maximum number of bytes that we'll allocate space for a local or the return value. -/// Needed for #66397, because otherwise we eval into large places and that can cause OOM or just -/// Severely regress performance. -const MAX_ALLOC_LIMIT: u64 = 1024; - pub struct ConstPropLint; impl<'tcx> MirLint<'tcx> for ConstPropLint { @@ -81,11 +71,85 @@ impl<'tcx> MirLint<'tcx> for ConstPropLint { /// Finds optimization opportunities on the MIR. 
struct ConstPropagator<'mir, 'tcx> { - ecx: InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, + ecx: InterpCx<'mir, 'tcx, DummyMachine>, tcx: TyCtxt<'tcx>, param_env: ParamEnv<'tcx>, worklist: Vec<BasicBlock>, visited_blocks: BitSet<BasicBlock>, + locals: IndexVec<Local, Value<'tcx>>, + body: &'mir Body<'tcx>, + written_only_inside_own_block_locals: FxHashSet<Local>, + can_const_prop: IndexVec<Local, ConstPropMode>, +} + +#[derive(Debug, Clone)] +enum Value<'tcx> { + Immediate(ImmTy<'tcx>), + Aggregate { variant: VariantIdx, fields: IndexVec<FieldIdx, Value<'tcx>> }, + Uninit, +} + +impl<'tcx> From<ImmTy<'tcx>> for Value<'tcx> { + fn from(v: ImmTy<'tcx>) -> Self { + Self::Immediate(v) + } +} + +impl<'tcx> Value<'tcx> { + fn project( + &self, + proj: &[PlaceElem<'tcx>], + prop: &ConstPropagator<'_, 'tcx>, + ) -> Option<&Value<'tcx>> { + let mut this = self; + for proj in proj { + this = match (*proj, this) { + (PlaceElem::Field(idx, _), Value::Aggregate { fields, .. }) => { + fields.get(idx).unwrap_or(&Value::Uninit) + } + (PlaceElem::Index(idx), Value::Aggregate { fields, .. }) => { + let idx = prop.get_const(idx.into())?.immediate()?; + let idx = prop.ecx.read_target_usize(idx).ok()?; + fields.get(FieldIdx::from_u32(idx.try_into().ok()?)).unwrap_or(&Value::Uninit) + } + ( + PlaceElem::ConstantIndex { offset, min_length: _, from_end: false }, + Value::Aggregate { fields, .. }, + ) => fields + .get(FieldIdx::from_u32(offset.try_into().ok()?)) + .unwrap_or(&Value::Uninit), + _ => return None, + }; + } + Some(this) + } + + fn project_mut(&mut self, proj: &[PlaceElem<'_>]) -> Option<&mut Value<'tcx>> { + let mut this = self; + for proj in proj { + this = match (proj, this) { + (PlaceElem::Field(idx, _), Value::Aggregate { fields, .. }) => { + fields.ensure_contains_elem(*idx, || Value::Uninit) + } + (PlaceElem::Field(..), val @ Value::Uninit) => { + *val = Value::Aggregate { + variant: VariantIdx::new(0), + fields: Default::default(), + }; + val.project_mut(&[*proj])? + } + _ => return None, + }; + } + Some(this) + } + + fn immediate(&self) -> Option<&ImmTy<'tcx>> { + match self { + Value::Immediate(op) => Some(op), + _ => None, + } + } } impl<'tcx> LayoutOfHelpers<'tcx> for ConstPropagator<'_, 'tcx> { @@ -121,49 +185,10 @@ impl<'tcx> ty::layout::HasParamEnv<'tcx> for ConstPropagator<'_, 'tcx> { impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { fn new(body: &'mir Body<'tcx>, tcx: TyCtxt<'tcx>) -> ConstPropagator<'mir, 'tcx> { let def_id = body.source.def_id(); - let args = &GenericArgs::identity_for_item(tcx, def_id); let param_env = tcx.param_env_reveal_all_normalized(def_id); let can_const_prop = CanConstProp::check(tcx, param_env, body); - let mut ecx = InterpCx::new( - tcx, - tcx.def_span(def_id), - param_env, - ConstPropMachine::new(can_const_prop), - ); - - let ret_layout = ecx - .layout_of(body.bound_return_ty().instantiate(tcx, args)) - .ok() - // Don't bother allocating memory for large values. - // I don't know how return types can seem to be unsized but this happens in the - // `type/type-unsatisfiable.rs` test. 
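The new `Value` tree above replaces full interpreter locals with a small lattice: a local is a known immediate, a partially known aggregate, or unknown. A self-contained toy model of the same projection idea (simplified: a missing field yields `None` here rather than `Uninit`, and toy names are mine, not rustc's):

```rust
// Toy model of the `Value` tree: projection walks the tree and gives up
// (`None`) as soon as knowledge runs out.
#[derive(Clone, Debug, PartialEq)]
enum Value {
    Immediate(i64),
    Aggregate(Vec<Value>),
    Uninit,
}

fn project<'a>(mut this: &'a Value, path: &[usize]) -> Option<&'a Value> {
    for &idx in path {
        this = match this {
            Value::Aggregate(fields) => fields.get(idx)?,
            _ => return None, // can't project into a scalar or unknown value
        };
    }
    Some(this)
}

fn main() {
    let v = Value::Aggregate(vec![Value::Immediate(1), Value::Uninit]);
    assert_eq!(project(&v, &[0]), Some(&Value::Immediate(1)));
    assert_eq!(project(&v, &[1]), Some(&Value::Uninit));
    assert_eq!(project(&v, &[0, 0]), None); // can't project into an immediate
}
```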
- .filter(|ret_layout| { - ret_layout.is_sized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) - }) - .unwrap_or_else(|| ecx.layout_of(tcx.types.unit).unwrap()); - - let ret = ecx - .allocate(ret_layout, MemoryKind::Stack) - .expect("couldn't perform small allocation") - .into(); - - ecx.push_stack_frame( - Instance::new(def_id, args), - body, - &ret, - StackPopCleanup::Root { cleanup: false }, - ) - .expect("failed to push initial stack frame"); - - for local in body.local_decls.indices() { - // Mark everything initially live. - // This is somewhat dicey since some of them might be unsized and it is incoherent to - // mark those as live... We rely on `local_to_place`/`local_to_op` in the interpreter - // stopping us before those unsized immediates can cause issues deeper in the - // interpreter. - ecx.frame_mut().locals[local].make_live_uninit(); - } + let ecx = InterpCx::new(tcx, tcx.def_span(def_id), param_env, DummyMachine); ConstPropagator { ecx, @@ -171,61 +196,47 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { param_env, worklist: vec![START_BLOCK], visited_blocks: BitSet::new_empty(body.basic_blocks.len()), + locals: IndexVec::from_elem_n(Value::Uninit, body.local_decls.len()), + body, + can_const_prop, + written_only_inside_own_block_locals: Default::default(), } } - fn body(&self) -> &'mir Body<'tcx> { - self.ecx.frame().body - } - fn local_decls(&self) -> &'mir LocalDecls<'tcx> { - &self.body().local_decls + &self.body.local_decls } - fn get_const(&self, place: Place<'tcx>) -> Option<OpTy<'tcx>> { - let op = match self.ecx.eval_place_to_op(place, None) { - Ok(op) => { - if op - .as_mplace_or_imm() - .right() - .is_some_and(|imm| matches!(*imm, Immediate::Uninit)) - { - // Make sure nobody accidentally uses this value. - return None; - } - op - } - Err(e) => { - trace!("get_const failed: {:?}", e.into_kind().debug()); - return None; - } - }; - - // Try to read the local as an immediate so that if it is representable as a scalar, we can - // handle it as such, but otherwise, just return the value as is. - Some(match self.ecx.read_immediate_raw(&op) { - Ok(Left(imm)) => imm.into(), - _ => op, - }) + fn get_const(&self, place: Place<'tcx>) -> Option<&Value<'tcx>> { + self.locals[place.local].project(&place.projection, self) } /// Remove `local` from the pool of `Locals`. Allows writing to them, /// but not reading from them anymore. 
- fn remove_const(ecx: &mut InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, local: Local) { - ecx.frame_mut().locals[local].make_live_uninit(); - ecx.machine.written_only_inside_own_block_locals.remove(&local); + fn remove_const(&mut self, local: Local) { + self.locals[local] = Value::Uninit; + self.written_only_inside_own_block_locals.remove(&local); + } + + fn access_mut(&mut self, place: &Place<'_>) -> Option<&mut Value<'tcx>> { + match self.can_const_prop[place.local] { + ConstPropMode::NoPropagation => return None, + ConstPropMode::OnlyInsideOwnBlock => { + self.written_only_inside_own_block_locals.insert(place.local); + } + ConstPropMode::FullConstProp => {} + } + self.locals[place.local].project_mut(place.projection) } fn lint_root(&self, source_info: SourceInfo) -> Option<HirId> { - source_info.scope.lint_root(&self.body().source_scopes) + source_info.scope.lint_root(&self.body.source_scopes) } - fn use_ecx<F, T>(&mut self, location: Location, f: F) -> Option<T> + fn use_ecx<F, T>(&mut self, f: F) -> Option<T> where F: FnOnce(&mut Self) -> InterpResult<'tcx, T>, { - // Overwrite the PC -- whatever the interpreter does to it does not make any sense anyway. - self.ecx.frame_mut().loc = Left(location); match f(self) { Ok(val) => Some(val), Err(error) => { @@ -236,7 +247,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { assert!( !error.kind().formatted_string(), "const-prop encountered formatting error: {}", - self.ecx.format_error(error), + format_interp_error(self.ecx.tcx.dcx(), error), ); None } @@ -244,7 +255,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } /// Returns the value, if any, of evaluating `c`. - fn eval_constant(&mut self, c: &ConstOperand<'tcx>, location: Location) -> Option<OpTy<'tcx>> { + fn eval_constant(&mut self, c: &ConstOperand<'tcx>) -> Option<ImmTy<'tcx>> { // FIXME we need to revisit this for #67176 if c.has_param() { return None; @@ -258,46 +269,62 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // manually normalized. let val = self.tcx.try_normalize_erasing_regions(self.param_env, c.const_).ok()?; - self.use_ecx(location, |this| this.ecx.eval_mir_constant(&val, Some(c.span), None)) + self.use_ecx(|this| this.ecx.eval_mir_constant(&val, Some(c.span), None))? + .as_mplace_or_imm() + .right() } /// Returns the value, if any, of evaluating `place`. - fn eval_place(&mut self, place: Place<'tcx>, location: Location) -> Option<OpTy<'tcx>> { - trace!("eval_place(place={:?})", place); - self.use_ecx(location, |this| this.ecx.eval_place_to_op(place, None)) + #[instrument(level = "trace", skip(self), ret)] + fn eval_place(&mut self, place: Place<'tcx>) -> Option<ImmTy<'tcx>> { + match self.get_const(place)? { + Value::Immediate(imm) => Some(imm.clone()), + Value::Aggregate { .. } => None, + Value::Uninit => None, + } } /// Returns the value, if any, of evaluating `op`. Calls upon `eval_constant` /// or `eval_place`, depending on the variant of `Operand` used. 
- fn eval_operand(&mut self, op: &Operand<'tcx>, location: Location) -> Option<OpTy<'tcx>> { + fn eval_operand(&mut self, op: &Operand<'tcx>) -> Option<ImmTy<'tcx>> { match *op { - Operand::Constant(ref c) => self.eval_constant(c, location), - Operand::Move(place) | Operand::Copy(place) => self.eval_place(place, location), + Operand::Constant(ref c) => self.eval_constant(c), + Operand::Move(place) | Operand::Copy(place) => self.eval_place(place), } } - fn report_assert_as_lint(&self, source_info: &SourceInfo, lint: AssertLint<impl Debug>) { + fn report_assert_as_lint( + &self, + location: Location, + lint_kind: AssertLintKind, + assert_kind: AssertKind<impl Debug>, + ) { + let source_info = self.body.source_info(location); if let Some(lint_root) = self.lint_root(*source_info) { - self.tcx.emit_spanned_lint(lint.lint(), lint_root, source_info.span, lint); + let span = source_info.span; + self.tcx.emit_node_span_lint( + lint_kind.lint(), + lint_root, + span, + AssertLint { span, assert_kind, lint_kind }, + ); } } fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>, location: Location) -> Option<()> { - if let (val, true) = self.use_ecx(location, |this| { - let val = this.ecx.read_immediate(&this.ecx.eval_operand(arg, None)?)?; + let arg = self.eval_operand(arg)?; + if let (val, true) = self.use_ecx(|this| { + let val = this.ecx.read_immediate(&arg)?; let (_res, overflow) = this.ecx.overflowing_unary_op(op, &val)?; Ok((val, overflow)) })? { // `AssertKind` only has an `OverflowNeg` variant, so make sure that is // appropriate to use. assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow"); - let source_info = self.body().source_info(location); self.report_assert_as_lint( - source_info, - AssertLint::ArithmeticOverflow( - source_info.span, - AssertKind::OverflowNeg(val.to_const_int()), - ), + location, + AssertLintKind::ArithmeticOverflow, + AssertKind::OverflowNeg(val.to_const_int()), ); return None; } @@ -312,11 +339,10 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { right: &Operand<'tcx>, location: Location, ) -> Option<()> { - let r = self.use_ecx(location, |this| { - this.ecx.read_immediate(&this.ecx.eval_operand(right, None)?) - }); - let l = self - .use_ecx(location, |this| this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?)); + let r = + self.eval_operand(right).and_then(|r| self.use_ecx(|this| this.ecx.read_immediate(&r))); + let l = + self.eval_operand(left).and_then(|l| self.use_ecx(|this| this.ecx.read_immediate(&l))); // Check for exceeding shifts *even if* we cannot evaluate the LHS. if matches!(op, BinOp::Shr | BinOp::Shl) { let r = r.clone()?; @@ -328,7 +354,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { let r_bits = r.to_scalar().to_bits(right_size).ok(); if r_bits.is_some_and(|b| b >= left_size.bits() as u128) { debug!("check_binary_op: reporting assert for {:?}", location); - let source_info = self.body().source_info(location); let panic = AssertKind::Overflow( op, match l { @@ -342,27 +367,21 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { }, r.to_const_int(), ); - self.report_assert_as_lint( - source_info, - AssertLint::ArithmeticOverflow(source_info.span, panic), - ); + self.report_assert_as_lint(location, AssertLintKind::ArithmeticOverflow, panic); return None; } } if let (Some(l), Some(r)) = (l, r) { // The remaining operators are handled through `overflowing_binary_op`. - if self.use_ecx(location, |this| { + if self.use_ecx(|this| { let (_res, overflow) = this.ecx.overflowing_binary_op(op, &l, &r)?; Ok(overflow) })? 
{ - let source_info = self.body().source_info(location); self.report_assert_as_lint( - source_info, - AssertLint::ArithmeticOverflow( - source_info.span, - AssertKind::Overflow(op, l.to_const_int(), r.to_const_int()), - ), + location, + AssertLintKind::ArithmeticOverflow, + AssertKind::Overflow(op, l.to_const_int(), r.to_const_int()), ); return None; } @@ -411,7 +430,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // value the local has right now. // Thus, all locals that have their reference taken // must not take part in propagation. - Self::remove_const(&mut self.ecx, place.local); + self.remove_const(place.local); return None; } @@ -453,17 +472,17 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { cond: &Operand<'tcx>, location: Location, ) -> Option<!> { - let value = &self.eval_operand(cond, location)?; + let value = &self.eval_operand(cond)?; trace!("assertion on {:?} should be {:?}", value, expected); let expected = Scalar::from_bool(expected); - let value_const = self.use_ecx(location, |this| this.ecx.read_scalar(value))?; + let value_const = self.use_ecx(|this| this.ecx.read_scalar(value))?; if expected != value_const { // Poison all places this operand references so that further code // doesn't use the invalid value if let Some(place) = cond.place() { - Self::remove_const(&mut self.ecx, place.local); + self.remove_const(place.local); } enum DbgVal<T> { @@ -481,7 +500,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { let mut eval_to_int = |op| { // This can be `None` if the lhs wasn't const propagated and we just // triggered the assert on the value of the rhs. - self.eval_operand(op, location) + self.eval_operand(op) .and_then(|op| self.ecx.read_immediate(&op).ok()) .map_or(DbgVal::Underscore, |op| DbgVal::Val(op.to_const_int())) }; @@ -503,11 +522,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // Need proper const propagator for these. 
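The overflow checks above mirror Rust's `overflowing_*` semantics: the pass computes the same (value, overflowed) pair that the runtime checks would. A standalone sketch of those semantics, including the shift-amount rule that corresponds to the `r_bits >= left_size.bits()` test:

```rust
// The lint mirrors what runtime overflow checks would observe: the
// `overflowing_*` methods return the wrapped result plus an overflow flag,
// the same (value, overflowed) pair the pass inspects.
fn main() {
    let (val, overflowed) = 200u8.overflowing_add(100);
    assert!(overflowed);
    assert_eq!(val, 44); // 300 wraps to 44 in u8

    // Shifts overflow when the shift amount is >= the bit width of the
    // left-hand side, which is exactly the `r_bits >= left_size.bits()`
    // check above.
    let (_, shift_overflowed) = 1u8.overflowing_shl(9);
    assert!(shift_overflowed);
}
```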
_ => return None, }; - let source_info = self.body().source_info(location); - self.report_assert_as_lint( - source_info, - AssertLint::UnconditionalPanic(source_info.span, msg), - ); + self.report_assert_as_lint(location, AssertLintKind::UnconditionalPanic, msg); } None @@ -515,16 +530,173 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { fn ensure_not_propagated(&self, local: Local) { if cfg!(debug_assertions) { + let val = self.get_const(local.into()); assert!( - self.get_const(local.into()).is_none() + matches!(val, Some(Value::Uninit)) || self .layout_of(self.local_decls()[local].ty) .map_or(true, |layout| layout.is_zst()), - "failed to remove values for `{local:?}`, value={:?}", - self.get_const(local.into()), + "failed to remove values for `{local:?}`, value={val:?}", ) } } + + #[instrument(level = "trace", skip(self), ret)] + fn eval_rvalue(&mut self, rvalue: &Rvalue<'tcx>, dest: &Place<'tcx>) -> Option<()> { + if !dest.projection.is_empty() { + return None; + } + use rustc_middle::mir::Rvalue::*; + let layout = self.ecx.layout_of(dest.ty(self.body, self.tcx).ty).ok()?; + trace!(?layout); + + let val: Value<'_> = match *rvalue { + ThreadLocalRef(_) => return None, + + Use(ref operand) => self.eval_operand(operand)?.into(), + + CopyForDeref(place) => self.eval_place(place)?.into(), + + BinaryOp(bin_op, box (ref left, ref right)) => { + let left = self.eval_operand(left)?; + let left = self.use_ecx(|this| this.ecx.read_immediate(&left))?; + + let right = self.eval_operand(right)?; + let right = self.use_ecx(|this| this.ecx.read_immediate(&right))?; + + let val = + self.use_ecx(|this| this.ecx.wrapping_binary_op(bin_op, &left, &right))?; + val.into() + } + + CheckedBinaryOp(bin_op, box (ref left, ref right)) => { + let left = self.eval_operand(left)?; + let left = self.use_ecx(|this| this.ecx.read_immediate(&left))?; + + let right = self.eval_operand(right)?; + let right = self.use_ecx(|this| this.ecx.read_immediate(&right))?; + + let (val, overflowed) = + self.use_ecx(|this| this.ecx.overflowing_binary_op(bin_op, &left, &right))?; + let overflowed = ImmTy::from_bool(overflowed, self.tcx); + Value::Aggregate { + variant: VariantIdx::new(0), + fields: [Value::from(val), overflowed.into()].into_iter().collect(), + } + } + + UnaryOp(un_op, ref operand) => { + let operand = self.eval_operand(operand)?; + let val = self.use_ecx(|this| this.ecx.read_immediate(&operand))?; + + let val = self.use_ecx(|this| this.ecx.wrapping_unary_op(un_op, &val))?; + val.into() + } + + Aggregate(ref kind, ref fields) => Value::Aggregate { + fields: fields + .iter() + .map(|field| self.eval_operand(field).map_or(Value::Uninit, Value::Immediate)) + .collect(), + variant: match **kind { + AggregateKind::Adt(_, variant, _, _, _) => variant, + AggregateKind::Array(_) + | AggregateKind::Tuple + | AggregateKind::Closure(_, _) + | AggregateKind::Coroutine(_, _) + | AggregateKind::CoroutineClosure(_, _) => VariantIdx::new(0), + }, + }, + + Repeat(ref op, n) => { + trace!(?op, ?n); + return None; + } + + Len(place) => { + let len = match self.get_const(place)? { + Value::Immediate(src) => src.len(&self.ecx).ok()?, + Value::Aggregate { fields, .. } => fields.len() as u64, + Value::Uninit => match place.ty(self.local_decls(), self.tcx).ty.kind() { + ty::Array(_, n) => n.try_eval_target_usize(self.tcx, self.param_env)?, + _ => return None, + }, + }; + ImmTy::from_scalar(Scalar::from_target_usize(len, self), layout).into() + } + + Ref(..) | AddressOf(..) 
=> return None, + + NullaryOp(ref null_op, ty) => { + let op_layout = self.use_ecx(|this| this.ecx.layout_of(ty))?; + let val = match null_op { + NullOp::SizeOf => op_layout.size.bytes(), + NullOp::AlignOf => op_layout.align.abi.bytes(), + NullOp::OffsetOf(fields) => { + op_layout.offset_of_subfield(self, fields.iter()).bytes() + } + NullOp::DebugAssertions => return None, + }; + ImmTy::from_scalar(Scalar::from_target_usize(val, self), layout).into() + } + + ShallowInitBox(..) => return None, + + Cast(ref kind, ref value, to) => match kind { + CastKind::IntToInt | CastKind::IntToFloat => { + let value = self.eval_operand(value)?; + let value = self.ecx.read_immediate(&value).ok()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.int_to_int_or_float(&value, to).ok()?; + res.into() + } + CastKind::FloatToFloat | CastKind::FloatToInt => { + let value = self.eval_operand(value)?; + let value = self.ecx.read_immediate(&value).ok()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.float_to_float_or_int(&value, to).ok()?; + res.into() + } + CastKind::Transmute => { + let value = self.eval_operand(value)?; + let to = self.ecx.layout_of(to).ok()?; + // `offset` for immediates only supports scalar/scalar-pair ABIs, + // so bail out if the target is not one. + match (value.layout.abi, to.abi) { + (Abi::Scalar(..), Abi::Scalar(..)) => {} + (Abi::ScalarPair(..), Abi::ScalarPair(..)) => {} + _ => return None, + } + + value.offset(Size::ZERO, to, &self.ecx).ok()?.into() + } + _ => return None, + }, + + Discriminant(place) => { + let variant = match self.get_const(place)? { + Value::Immediate(op) => { + let op = op.clone(); + self.use_ecx(|this| this.ecx.read_discriminant(&op))? + } + Value::Aggregate { variant, .. } => *variant, + Value::Uninit => return None, + }; + let imm = self.use_ecx(|this| { + this.ecx.discriminant_for_variant( + place.ty(this.local_decls(), this.tcx).ty, + variant, + ) + })?; + imm.into() + } + }; + trace!(?val); + + *self.access_mut(dest)? = val; + + Some(()) + } } impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { @@ -546,7 +718,7 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { fn visit_constant(&mut self, constant: &ConstOperand<'tcx>, location: Location) { trace!("visit_constant: {:?}", constant); self.super_constant(constant, location); - self.eval_constant(constant, location); + self.eval_constant(constant); } fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) { @@ -554,15 +726,12 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { let Some(()) = self.check_rvalue(rvalue, location) else { return }; - match self.ecx.machine.can_const_prop[place.local] { + match self.can_const_prop[place.local] { // Do nothing if the place is indirect. _ if place.is_indirect() => {} ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), ConstPropMode::OnlyInsideOwnBlock | ConstPropMode::FullConstProp => { - if self - .use_ecx(location, |this| this.ecx.eval_rvalue_into_place(rvalue, *place)) - .is_none() - { + if self.eval_rvalue(rvalue, place).is_none() { // Const prop failed, so erase the destination, ensuring that whatever happens // from here on, does not know about the previous value. 
// This is important in case we have @@ -578,7 +747,7 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { Nuking the entire site from orbit, it's the only way to be sure", place, ); - Self::remove_const(&mut self.ecx, place.local); + self.remove_const(place.local); } } } @@ -592,28 +761,24 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { self.super_statement(statement, location); match statement.kind { - StatementKind::SetDiscriminant { ref place, .. } => { - match self.ecx.machine.can_const_prop[place.local] { + StatementKind::SetDiscriminant { ref place, variant_index } => { + match self.can_const_prop[place.local] { // Do nothing if the place is indirect. _ if place.is_indirect() => {} ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { - if self.use_ecx(location, |this| this.ecx.statement(statement)).is_some() { - trace!("propped discriminant into {:?}", place); - } else { - Self::remove_const(&mut self.ecx, place.local); + match self.access_mut(place) { + Some(Value::Aggregate { variant, .. }) => *variant = variant_index, + _ => self.remove_const(place.local), } } } } StatementKind::StorageLive(local) => { - let frame = self.ecx.frame_mut(); - frame.locals[local].make_live_uninit(); + self.remove_const(local); } StatementKind::StorageDead(local) => { - let frame = self.ecx.frame_mut(); - // We don't actually track liveness, so the local remains live. But forget its value. - frame.locals[local].make_live_uninit(); + self.remove_const(local); } _ => {} } @@ -626,9 +791,8 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { self.check_assertion(*expected, msg, cond, location); } TerminatorKind::SwitchInt { ref discr, ref targets } => { - if let Some(ref value) = self.eval_operand(discr, location) && let Some(value_const) = - self.use_ecx(location, |this| this.ecx.read_scalar(value)) + if let Some(ref value) = self.eval_operand(discr) && let Some(value_const) = self.use_ecx(|this| this.ecx.read_scalar(value)) && let Ok(constant) = value_const.try_to_int() && let Ok(constant) = constant.to_bits(constant.size()) { @@ -665,7 +829,7 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { // which were modified in the current block. // Take it out of `self` so we can call `remove_const`, which needs `&mut self`, while iterating. let mut written_only_inside_own_block_locals = - std::mem::take(&mut self.written_only_inside_own_block_locals); + std::mem::take(&mut self.written_only_inside_own_block_locals); // This loop can get very hot for some bodies: it checks each local in each bb. // To avoid this quadratic behaviour, we only clear the locals that were modified inside // their own block. // The order in which we remove consts does not matter. 
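A self-contained model of the `OnlyInsideOwnBlock` bookkeeping above (toy types, not rustc's): values learned for such locals are forgotten when the basic block ends, while fully propagatable locals persist across blocks.

```rust
use std::collections::{HashMap, HashSet};

#[derive(Clone, Copy, PartialEq, Eq)]
enum Mode { Full, OnlyOwnBlock }

struct State {
    values: HashMap<&'static str, i64>,
    modes: HashMap<&'static str, Mode>,
    written_this_block: HashSet<&'static str>,
}

impl State {
    fn assign(&mut self, local: &'static str, v: i64) {
        // Track block-local writes, mirroring `written_only_inside_own_block_locals`.
        if self.modes[local] == Mode::OnlyOwnBlock {
            self.written_this_block.insert(local);
        }
        self.values.insert(local, v);
    }
    fn end_block(&mut self) {
        for local in self.written_this_block.drain() {
            self.values.remove(local); // forget block-local knowledge
        }
    }
}

fn main() {
    let mut st = State {
        values: HashMap::new(),
        modes: HashMap::from([("a", Mode::Full), ("b", Mode::OnlyOwnBlock)]),
        written_this_block: HashSet::new(),
    };
    st.assign("a", 1);
    st.assign("b", 2);
    st.end_block();
    assert_eq!(st.values.get("a"), Some(&1)); // survives the block
    assert_eq!(st.values.get("b"), None);     // forgotten at block end
}
```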
#[allow(rustc::potential_query_instability)] for local in written_only_inside_own_block_locals.drain() { - debug_assert_eq!( - self.ecx.machine.can_const_prop[local], - ConstPropMode::OnlyInsideOwnBlock - ); - Self::remove_const(&mut self.ecx, local); + debug_assert_eq!(self.can_const_prop[local], ConstPropMode::OnlyInsideOwnBlock); + self.remove_const(local); } - self.ecx.machine.written_only_inside_own_block_locals = - written_only_inside_own_block_locals; + self.written_only_inside_own_block_locals = written_only_inside_own_block_locals; if cfg!(debug_assertions) { - for (local, &mode) in self.ecx.machine.can_const_prop.iter_enumerated() { + for (local, &mode) in self.can_const_prop.iter_enumerated() { match mode { ConstPropMode::FullConstProp => {} ConstPropMode::NoPropagation | ConstPropMode::OnlyInsideOwnBlock => { diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index eaa36e0cc91..a0851aa557b 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -50,6 +50,9 @@ //! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing. //! Otherwise it drops all the values in scope at the last suspension point. +mod by_move_body; +pub use by_move_body::ByMoveBody; + use crate::abort_unwinding_calls; use crate::deref_separator::deref_finder; use crate::errors; @@ -723,7 +726,7 @@ fn replace_resume_ty_local<'tcx>( /// The async lowering step and the type / lifetime inference / checking are /// still using the `resume` argument for the time being. After this transform, /// the coroutine body doesn't have the `resume` argument. -fn transform_gen_context<'tcx>(_tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { +fn transform_gen_context<'tcx>(body: &mut Body<'tcx>) { // This leaves the local representing the `resume` argument in place, // but turns it into a regular local variable. This is cheaper than // adjusting all local references in the body after removing it. @@ -1231,7 +1234,12 @@ fn create_coroutine_drop_shim<'tcx>( drop_clean: BasicBlock, ) -> Body<'tcx> { let mut body = body.clone(); - body.arg_count = 1; // make sure the resume argument is not included here + // Take the coroutine info out of the body, since the drop shim is + // not a coroutine body itself; it just has its drop built out of it. + let _ = body.coroutine.take(); + // Make sure the resume argument is not included here, since we're + // building a body for `drop_in_place`. + body.arg_count = 1; let source_info = SourceInfo::outermost(body.span); @@ -1725,7 +1733,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // Remove the context argument within generator bodies. if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) { - transform_gen_context(tcx, body); + transform_gen_context(body); } // The original arguments to the function are no longer arguments, mark them as such. @@ -2071,7 +2079,7 @@ fn check_must_not_suspend_def( span: data.source_span, reason: s.as_str().to_string(), }); - tcx.emit_spanned_lint( + tcx.emit_node_span_lint( rustc_session::lint::builtin::MUST_NOT_SUSPEND, hir_id, data.source_span, diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs new file mode 100644 index 00000000000..e40f4520671 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -0,0 +1,159 @@ +//! A MIR pass which duplicates a coroutine's body and removes any derefs which +//! 
would be present for upvars that are taken by-ref. The result of which will +//! be a coroutine body that takes all of its upvars by-move, and which we stash +//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures. + +use rustc_data_structures::fx::FxIndexSet; +use rustc_hir as hir; +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::{self, dump_mir, MirPass}; +use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt}; +use rustc_target::abi::FieldIdx; + +pub struct ByMoveBody; + +impl<'tcx> MirPass<'tcx> for ByMoveBody { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { + let Some(coroutine_def_id) = body.source.def_id().as_local() else { + return; + }; + let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) = + tcx.coroutine_kind(coroutine_def_id) + else { + return; + }; + let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + if coroutine_ty.references_error() { + return; + } + let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") }; + + let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(); + if coroutine_kind == ty::ClosureKind::FnOnce { + return; + } + + let mut by_ref_fields = FxIndexSet::default(); + let by_move_upvars = Ty::new_tup_from_iter( + tcx, + tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| { + if capture.is_by_ref() { + by_ref_fields.insert(FieldIdx::from_usize(idx)); + } + capture.place.ty() + }), + ); + let by_move_coroutine_ty = Ty::new_coroutine( + tcx, + coroutine_def_id.to_def_id(), + ty::CoroutineArgs::new( + tcx, + ty::CoroutineArgsParts { + parent_args: args.as_coroutine().parent_args(), + kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce), + resume_ty: args.as_coroutine().resume_ty(), + yield_ty: args.as_coroutine().yield_ty(), + return_ty: args.as_coroutine().return_ty(), + witness: args.as_coroutine().witness(), + tupled_upvars_ty: by_move_upvars, + }, + ) + .args, + ); + + let mut by_move_body = body.clone(); + MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body); + dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(())); + by_move_body.source = mir::MirSource { + instance: InstanceDef::CoroutineKindShim { + coroutine_def_id: coroutine_def_id.to_def_id(), + target_kind: ty::ClosureKind::FnOnce, + }, + promoted: None, + }; + body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body); + + // If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body. + // This is actually just a copy of the by-ref body, but with a different self type. + // FIXME(async_closures): We could probably unify this with the by-ref body somehow. 
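Background for the by-move/by-mut split above: it mirrors the `Fn`/`FnMut`/`FnOnce` distinction for ordinary closures, where by-ref captures allow repeated calls and a by-value capture forces a consuming, call-once body. A plain-Rust illustration:

```rust
fn main() {
    let s = String::from("upvar");

    let by_ref = || println!("{s}"); // captures `s` by reference (Fn)
    by_ref();
    by_ref(); // can run many times; `s` is only borrowed

    let by_move = move || s; // captures `s` by value (FnOnce)
    let got = by_move();     // consuming call; `by_move` can't be called again
    assert_eq!(got, "upvar");
}
```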
+ if coroutine_kind == ty::ClosureKind::Fn { + let by_mut_coroutine_ty = Ty::new_coroutine( + tcx, + coroutine_def_id.to_def_id(), + ty::CoroutineArgs::new( + tcx, + ty::CoroutineArgsParts { + parent_args: args.as_coroutine().parent_args(), + kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut), + resume_ty: args.as_coroutine().resume_ty(), + yield_ty: args.as_coroutine().yield_ty(), + return_ty: args.as_coroutine().return_ty(), + witness: args.as_coroutine().witness(), + tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(), + }, + ) + .args, + ); + let mut by_mut_body = body.clone(); + by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty; + dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(())); + by_mut_body.source = mir::MirSource { + instance: InstanceDef::CoroutineKindShim { + coroutine_def_id: coroutine_def_id.to_def_id(), + target_kind: ty::ClosureKind::FnMut, + }, + promoted: None, + }; + body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body); + } + } +} + +struct MakeByMoveBody<'tcx> { + tcx: TyCtxt<'tcx>, + by_ref_fields: FxIndexSet<FieldIdx>, + by_move_coroutine_ty: Ty<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place( + &mut self, + place: &mut mir::Place<'tcx>, + context: mir::visit::PlaceContext, + location: mir::Location, + ) { + if place.local == ty::CAPTURE_STRUCT_LOCAL + && !place.projection.is_empty() + && let mir::ProjectionElem::Field(idx, ty) = place.projection[0] + && self.by_ref_fields.contains(&idx) + { + let (begin, end) = place.projection[1..].split_first().unwrap(); + // FIXME(async_closures): I'm actually a bit surprised to see that we always + // initially deref the by-ref upvars. If this is not actually true, then we + // will at least get an ICE that explains why this isn't true :^) + assert_eq!(*begin, mir::ProjectionElem::Deref); + // Peel one ref off of the ty. + let peeled_ty = ty.builtin_deref(true).unwrap().ty; + *place = mir::Place { + local: place.local, + projection: self.tcx.mk_place_elems_from_iter( + [mir::ProjectionElem::Field(idx, peeled_ty)] + .into_iter() + .chain(end.iter().copied()), + ), + }; + } + self.super_place(place, context, location); + } + + fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) { + // Replace the type of the self arg. + if local == ty::CAPTURE_STRUCT_LOCAL { + local_decl.ty = self.by_move_coroutine_ty; + } + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs index 8c11dea5d4e..9a1d8bae6b4 100644 --- a/compiler/rustc_mir_transform/src/coverage/counters.rs +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -1,4 +1,5 @@ -use rustc_data_structures::fx::FxIndexMap; +use rustc_data_structures::captures::Captures; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::graph::WithNumNodes; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; @@ -38,19 +39,27 @@ impl Debug for BcbCounter { } } +#[derive(Debug)] +pub(super) enum CounterIncrementSite { + Node { bcb: BasicCoverageBlock }, + Edge { from_bcb: BasicCoverageBlock, to_bcb: BasicCoverageBlock }, +} + /// Generates and stores coverage counter and coverage expression information /// associated with nodes/edges in the BCB graph. 
pub(super) struct CoverageCounters { - next_counter_id: CounterId, + /// List of places where a counter-increment statement should be injected + /// into MIR, each with its corresponding counter ID. + counter_increment_sites: IndexVec<CounterId, CounterIncrementSite>, /// Coverage counters/expressions that are associated with individual BCBs. bcb_counters: IndexVec<BasicCoverageBlock, Option<BcbCounter>>, /// Coverage counters/expressions that are associated with the control-flow /// edge between two BCBs. /// - /// The iteration order of this map can affect the precise contents of MIR, - /// so we use `FxIndexMap` to avoid query stability hazards. - bcb_edge_counters: FxIndexMap<(BasicCoverageBlock, BasicCoverageBlock), BcbCounter>, + /// We currently don't iterate over this map, but if we do in the future, + /// switch it back to `FxIndexMap` to avoid query stability hazards. + bcb_edge_counters: FxHashMap<(BasicCoverageBlock, BasicCoverageBlock), BcbCounter>, /// Tracks which BCBs have a counter associated with some incoming edge. /// Only used by assertions, to verify that BCBs with incoming edge /// counters do not have their own physical counters (expressions are allowed). @@ -71,9 +80,9 @@ impl CoverageCounters { let num_bcbs = basic_coverage_blocks.num_nodes(); let mut this = Self { - next_counter_id: CounterId::START, + counter_increment_sites: IndexVec::new(), bcb_counters: IndexVec::from_elem_n(None, num_bcbs), - bcb_edge_counters: FxIndexMap::default(), + bcb_edge_counters: FxHashMap::default(), bcb_has_incoming_edge_counters: BitSet::new_empty(num_bcbs), expressions: IndexVec::new(), }; @@ -84,8 +93,8 @@ impl CoverageCounters { this } - fn make_counter(&mut self) -> BcbCounter { - let id = self.next_counter(); + fn make_counter(&mut self, site: CounterIncrementSite) -> BcbCounter { + let id = self.counter_increment_sites.push(site); BcbCounter::Counter { id } } @@ -103,15 +112,8 @@ impl CoverageCounters { self.make_expression(lhs, Op::Add, rhs) } - /// Counter IDs start from one and go up. - fn next_counter(&mut self) -> CounterId { - let next = self.next_counter_id; - self.next_counter_id = self.next_counter_id + 1; - next - } - pub(super) fn num_counters(&self) -> usize { - self.next_counter_id.as_usize() + self.counter_increment_sites.len() } #[cfg(test)] @@ -171,22 +173,26 @@ impl CoverageCounters { self.bcb_counters[bcb] } - pub(super) fn bcb_node_counters( + /// Returns an iterator over all the nodes/edges in the coverage graph that + /// should have a counter-increment statement injected into MIR, along with + /// each site's corresponding counter ID. + pub(super) fn counter_increment_sites( &self, - ) -> impl Iterator<Item = (BasicCoverageBlock, &BcbCounter)> { - self.bcb_counters - .iter_enumerated() - .filter_map(|(bcb, counter_kind)| Some((bcb, counter_kind.as_ref()?))) + ) -> impl Iterator<Item = (CounterId, &CounterIncrementSite)> { + self.counter_increment_sites.iter_enumerated() } - /// For each edge in the BCB graph that has an associated counter, yields - /// that edge's *from* and *to* nodes, and its counter. - pub(super) fn bcb_edge_counters( + /// Returns an iterator over the subset of BCB nodes that have been associated + /// with a counter *expression*, along with the ID of that expression. 
+ pub(super) fn bcb_nodes_with_coverage_expressions( &self, - ) -> impl Iterator<Item = (BasicCoverageBlock, BasicCoverageBlock, &BcbCounter)> { - self.bcb_edge_counters - .iter() - .map(|(&(from_bcb, to_bcb), counter_kind)| (from_bcb, to_bcb, counter_kind)) + ) -> impl Iterator<Item = (BasicCoverageBlock, ExpressionId)> + Captures<'_> { + self.bcb_counters.iter_enumerated().filter_map(|(bcb, &counter_kind)| match counter_kind { + // Yield the BCB along with its associated expression ID. + Some(BcbCounter::Expression { id }) => Some((bcb, id)), + // This BCB is associated with a counter or nothing, so skip it. + Some(BcbCounter::Counter { .. }) | None => None, + }) } pub(super) fn into_expressions(self) -> IndexVec<ExpressionId, Expression> { @@ -339,7 +345,8 @@ impl<'a> MakeBcbCounters<'a> { // program results in a tight infinite loop, but it should still compile. let one_path_to_target = !self.basic_coverage_blocks.bcb_has_multiple_in_edges(bcb); if one_path_to_target || self.bcb_predecessors(bcb).contains(&bcb) { - let counter_kind = self.coverage_counters.make_counter(); + let counter_kind = + self.coverage_counters.make_counter(CounterIncrementSite::Node { bcb }); if one_path_to_target { debug!("{bcb:?} gets a new counter: {counter_kind:?}"); } else { @@ -401,7 +408,8 @@ impl<'a> MakeBcbCounters<'a> { } // Make a new counter to count this edge. - let counter_kind = self.coverage_counters.make_counter(); + let counter_kind = + self.coverage_counters.make_counter(CounterIncrementSite::Edge { from_bcb, to_bcb }); debug!("Edge {from_bcb:?}->{to_bcb:?} gets a new counter: {counter_kind:?}"); self.coverage_counters.set_bcb_edge_counter(from_bcb, to_bcb, counter_kind) } diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs index a11d224e8f1..4c5be0a3f4b 100644 --- a/compiler/rustc_mir_transform/src/coverage/mod.rs +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -7,7 +7,7 @@ mod spans; #[cfg(test)] mod tests; -use self::counters::{BcbCounter, CoverageCounters}; +use self::counters::{CounterIncrementSite, CoverageCounters}; use self::graph::{BasicCoverageBlock, CoverageGraph}; use self::spans::{BcbMapping, BcbMappingKind, CoverageSpans}; @@ -59,170 +59,148 @@ impl<'tcx> MirPass<'tcx> for InstrumentCoverage { _ => {} } - trace!("InstrumentCoverage starting for {def_id:?}"); - Instrumentor::new(tcx, mir_body).inject_counters(); - trace!("InstrumentCoverage done for {def_id:?}"); + instrument_function_for_coverage(tcx, mir_body); } } -struct Instrumentor<'a, 'tcx> { - tcx: TyCtxt<'tcx>, - mir_body: &'a mut mir::Body<'tcx>, - hir_info: ExtractedHirInfo, - basic_coverage_blocks: CoverageGraph, -} +fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir::Body<'tcx>) { + let def_id = mir_body.source.def_id(); + let _span = debug_span!("instrument_function_for_coverage", ?def_id).entered(); -impl<'a, 'tcx> Instrumentor<'a, 'tcx> { - fn new(tcx: TyCtxt<'tcx>, mir_body: &'a mut mir::Body<'tcx>) -> Self { - let hir_info = extract_hir_info(tcx, mir_body.source.def_id().expect_local()); + let hir_info = extract_hir_info(tcx, def_id.expect_local()); + let basic_coverage_blocks = CoverageGraph::from_mir(mir_body); - debug!(?hir_info, "instrumenting {:?}", mir_body.source.def_id()); - - let basic_coverage_blocks = CoverageGraph::from_mir(mir_body); + //////////////////////////////////////////////////// + // Compute coverage spans from the `CoverageGraph`. 
+ let Some(coverage_spans) = + spans::generate_coverage_spans(mir_body, &hir_info, &basic_coverage_blocks) + else { + // No relevant spans were found in MIR, so skip instrumenting this function. + return; + }; - Self { tcx, mir_body, hir_info, basic_coverage_blocks } + //////////////////////////////////////////////////// + // Create an optimized mix of `Counter`s and `Expression`s for the `CoverageGraph`. Ensure + // every coverage span has a `Counter` or `Expression` assigned to its `BasicCoverageBlock` + // and all `Expression` dependencies (operands) are also generated, for any other + // `BasicCoverageBlock`s not already associated with a coverage span. + let bcb_has_coverage_spans = |bcb| coverage_spans.bcb_has_coverage_spans(bcb); + let coverage_counters = + CoverageCounters::make_bcb_counters(&basic_coverage_blocks, bcb_has_coverage_spans); + + let mappings = create_mappings(tcx, &hir_info, &coverage_spans, &coverage_counters); + if mappings.is_empty() { + // No spans could be converted into valid mappings, so skip this function. + debug!("no spans could be converted into valid mappings; skipping"); + return; } - fn inject_counters(&'a mut self) { - //////////////////////////////////////////////////// - // Compute coverage spans from the `CoverageGraph`. - let Some(coverage_spans) = CoverageSpans::generate_coverage_spans( - self.mir_body, - &self.hir_info, - &self.basic_coverage_blocks, - ) else { - // No relevant spans were found in MIR, so skip instrumenting this function. - return; - }; - - //////////////////////////////////////////////////// - // Create an optimized mix of `Counter`s and `Expression`s for the `CoverageGraph`. Ensure - // every coverage span has a `Counter` or `Expression` assigned to its `BasicCoverageBlock` - // and all `Expression` dependencies (operands) are also generated, for any other - // `BasicCoverageBlock`s not already associated with a coverage span. - let bcb_has_coverage_spans = |bcb| coverage_spans.bcb_has_coverage_spans(bcb); - let coverage_counters = CoverageCounters::make_bcb_counters( - &self.basic_coverage_blocks, - bcb_has_coverage_spans, - ); - - let mappings = self.create_mappings(&coverage_spans, &coverage_counters); - if mappings.is_empty() { - // No spans could be converted into valid mappings, so skip this function. - debug!("no spans could be converted into valid mappings; skipping"); - return; - } - - self.inject_coverage_statements(bcb_has_coverage_spans, &coverage_counters); - - self.mir_body.function_coverage_info = Some(Box::new(FunctionCoverageInfo { - function_source_hash: self.hir_info.function_source_hash, - num_counters: coverage_counters.num_counters(), - expressions: coverage_counters.into_expressions(), - mappings, - })); - } + inject_coverage_statements( + mir_body, + &basic_coverage_blocks, + bcb_has_coverage_spans, + &coverage_counters, + ); - /// For each coverage span extracted from MIR, create a corresponding - /// mapping. - /// - /// Precondition: All BCBs corresponding to those spans have been given - /// coverage counters. 
- fn create_mappings( - &self, - coverage_spans: &CoverageSpans, - coverage_counters: &CoverageCounters, - ) -> Vec<Mapping> { - let source_map = self.tcx.sess.source_map(); - let body_span = self.hir_info.body_span; - - let source_file = source_map.lookup_source_file(body_span.lo()); - use rustc_session::RemapFileNameExt; - let file_name = - Symbol::intern(&source_file.name.for_codegen(self.tcx.sess).to_string_lossy()); - - let term_for_bcb = |bcb| { - coverage_counters - .bcb_counter(bcb) - .expect("all BCBs with spans were given counters") - .as_term() - }; + mir_body.function_coverage_info = Some(Box::new(FunctionCoverageInfo { + function_source_hash: hir_info.function_source_hash, + num_counters: coverage_counters.num_counters(), + expressions: coverage_counters.into_expressions(), + mappings, + })); +} - coverage_spans - .all_bcb_mappings() - .filter_map(|&BcbMapping { kind: bcb_mapping_kind, span }| { - let kind = match bcb_mapping_kind { - BcbMappingKind::Code(bcb) => MappingKind::Code(term_for_bcb(bcb)), - }; - let code_region = make_code_region(source_map, file_name, span, body_span)?; - Some(Mapping { kind, code_region }) - }) - .collect::<Vec<_>>() - } +/// For each coverage span extracted from MIR, create a corresponding +/// mapping. +/// +/// Precondition: All BCBs corresponding to those spans have been given +/// coverage counters. +fn create_mappings<'tcx>( + tcx: TyCtxt<'tcx>, + hir_info: &ExtractedHirInfo, + coverage_spans: &CoverageSpans, + coverage_counters: &CoverageCounters, +) -> Vec<Mapping> { + let source_map = tcx.sess.source_map(); + let body_span = hir_info.body_span; + + let source_file = source_map.lookup_source_file(body_span.lo()); + use rustc_session::RemapFileNameExt; + let file_name = Symbol::intern(&source_file.name.for_codegen(tcx.sess).to_string_lossy()); + + let term_for_bcb = |bcb| { + coverage_counters + .bcb_counter(bcb) + .expect("all BCBs with spans were given counters") + .as_term() + }; - /// For each BCB node or BCB edge that has an associated coverage counter, - /// inject any necessary coverage statements into MIR. - fn inject_coverage_statements( - &mut self, - bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool, - coverage_counters: &CoverageCounters, - ) { - // Process the counters associated with BCB nodes. - for (bcb, counter_kind) in coverage_counters.bcb_node_counters() { - let do_inject = match counter_kind { - // Counter-increment statements always need to be injected. - BcbCounter::Counter { .. } => true, - // The only purpose of expression-used statements is to detect - // when a mapping is unreachable, so we only inject them for - // expressions with one or more mappings. - BcbCounter::Expression { .. } => bcb_has_coverage_spans(bcb), + coverage_spans + .all_bcb_mappings() + .filter_map(|&BcbMapping { kind: bcb_mapping_kind, span }| { + let kind = match bcb_mapping_kind { + BcbMappingKind::Code(bcb) => MappingKind::Code(term_for_bcb(bcb)), }; - if do_inject { - inject_statement( - self.mir_body, - self.make_mir_coverage_kind(counter_kind), - self.basic_coverage_blocks[bcb].leader_bb(), - ); - } - } + let code_region = make_code_region(source_map, file_name, span, body_span)?; + Some(Mapping { kind, code_region }) + }) + .collect::<Vec<_>>() +} - // Process the counters associated with BCB edges. - for (from_bcb, to_bcb, counter_kind) in coverage_counters.bcb_edge_counters() { - let do_inject = match counter_kind { - // Counter-increment statements always need to be injected. - BcbCounter::Counter { .. 
} => true, - // BCB-edge expressions never have mappings, so they never need - // a corresponding statement. - BcbCounter::Expression { .. } => false, - }; - if !do_inject { - continue; +/// For each BCB node or BCB edge that has an associated coverage counter, +/// inject any necessary coverage statements into MIR. +fn inject_coverage_statements<'tcx>( + mir_body: &mut mir::Body<'tcx>, + basic_coverage_blocks: &CoverageGraph, + bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool, + coverage_counters: &CoverageCounters, +) { + // Inject counter-increment statements into MIR. + for (id, counter_increment_site) in coverage_counters.counter_increment_sites() { + // Determine the block to inject a counter-increment statement into. + // For BCB nodes this is just their first block, but for edges we need + // to create a new block between the two BCBs, and inject into that. + let target_bb = match *counter_increment_site { + CounterIncrementSite::Node { bcb } => basic_coverage_blocks[bcb].leader_bb(), + CounterIncrementSite::Edge { from_bcb, to_bcb } => { + // Create a new block between the last block of `from_bcb` and + // the first block of `to_bcb`. + let from_bb = basic_coverage_blocks[from_bcb].last_bb(); + let to_bb = basic_coverage_blocks[to_bcb].leader_bb(); + + let new_bb = inject_edge_counter_basic_block(mir_body, from_bb, to_bb); + debug!( + "Edge {from_bcb:?} (last {from_bb:?}) -> {to_bcb:?} (leader {to_bb:?}) \ + requires a new MIR BasicBlock {new_bb:?} for counter increment {id:?}", + ); + new_bb } + }; - // We need to inject a coverage statement into a new BB between the - // last BB of `from_bcb` and the first BB of `to_bcb`. - let from_bb = self.basic_coverage_blocks[from_bcb].last_bb(); - let to_bb = self.basic_coverage_blocks[to_bcb].leader_bb(); - - let new_bb = inject_edge_counter_basic_block(self.mir_body, from_bb, to_bb); - debug!( - "Edge {from_bcb:?} (last {from_bb:?}) -> {to_bcb:?} (leader {to_bb:?}) \ - requires a new MIR BasicBlock {new_bb:?} for edge counter {counter_kind:?}", - ); - - // Inject a counter into the newly-created BB. - inject_statement(self.mir_body, self.make_mir_coverage_kind(counter_kind), new_bb); - } + inject_statement(mir_body, CoverageKind::CounterIncrement { id }, target_bb); } - fn make_mir_coverage_kind(&self, counter_kind: &BcbCounter) -> CoverageKind { - match *counter_kind { - BcbCounter::Counter { id } => CoverageKind::CounterIncrement { id }, - BcbCounter::Expression { id } => CoverageKind::ExpressionUsed { id }, - } + // For each counter expression that is directly associated with at least one + // span, we inject an "expression-used" statement, so that coverage codegen + // can check whether the injected statement survived MIR optimization. + // (BCB edges can't have spans, so we only need to process BCB nodes here.) + // + // See the code in `rustc_codegen_llvm::coverageinfo::map_data` that deals + // with "expressions seen" and "zero terms". + for (bcb, expression_id) in coverage_counters + .bcb_nodes_with_coverage_expressions() + .filter(|&(bcb, _)| bcb_has_coverage_spans(bcb)) + { + inject_statement( + mir_body, + CoverageKind::ExpressionUsed { id: expression_id }, + basic_coverage_blocks[bcb].leader_bb(), + ); } } +/// Given two basic blocks that have a control-flow edge between them, creates +/// and returns a new block that sits between those blocks. 
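Before the function itself (next hunk), a toy model of what this edge splitting does, assuming a CFG stored as plain successor lists:

```rust
// To attach a counter to the edge `from -> to`, create a fresh block that
// falls through to `to`, then redirect `from`'s matching successor to it.
fn split_edge(successors: &mut Vec<Vec<usize>>, from: usize, to: usize) -> usize {
    let new = successors.len();
    successors.push(vec![to]);
    for succ in &mut successors[from] {
        if *succ == to {
            *succ = new;
        }
    }
    new
}

fn main() {
    // Block 0 branches to blocks 1 and 2.
    let mut cfg = vec![vec![1, 2], vec![], vec![]];
    let new = split_edge(&mut cfg, 0, 2);
    assert_eq!(cfg[0], vec![1, new]); // 0 now reaches 2 via `new`
    assert_eq!(cfg[new], vec![2]);
}
```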
fn inject_edge_counter_basic_block( mir_body: &mut mir::Body<'_>, from_bb: BasicBlock, @@ -329,7 +307,7 @@ fn make_code_region( start_line = source_map.doctest_offset_line(&file.name, start_line); end_line = source_map.doctest_offset_line(&file.name, end_line); - Some(CodeRegion { + check_code_region(CodeRegion { file_name, start_line: start_line as u32, start_col: start_col as u32, @@ -338,6 +316,39 @@ fn make_code_region( }) } +/// If `llvm-cov` sees a code region that is improperly ordered (end < start), +/// it will immediately exit with a fatal error. To prevent that from happening, +/// discard regions that are improperly ordered, or might be interpreted in a +/// way that makes them improperly ordered. +fn check_code_region(code_region: CodeRegion) -> Option<CodeRegion> { + let CodeRegion { file_name: _, start_line, start_col, end_line, end_col } = code_region; + + // Line/column coordinates are supposed to be 1-based. If we ever emit + // coordinates of 0, `llvm-cov` might misinterpret them. + let all_nonzero = [start_line, start_col, end_line, end_col].into_iter().all(|x| x != 0); + // Coverage mappings use the high bit of `end_col` to indicate that a + // region is actually a "gap" region, so make sure it's unset. + let end_col_has_high_bit_unset = (end_col & (1 << 31)) == 0; + // If a region is improperly ordered (end < start), `llvm-cov` will exit + // with a fatal error, which is inconvenient for users and hard to debug. + let is_ordered = (start_line, start_col) <= (end_line, end_col); + + if all_nonzero && end_col_has_high_bit_unset && is_ordered { + Some(code_region) + } else { + debug!( + ?code_region, + ?all_nonzero, + ?end_col_has_high_bit_unset, + ?is_ordered, + "Skipping code region that would be misinterpreted or rejected by LLVM" + ); + // If this happens in a debug build, ICE to make it easier to notice. + debug_assert!(false, "Improper code region: {code_region:?}"); + None + } +} + fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { // Only instrument functions, methods, and closures (not constants since they are evaluated // at compile time by Miri). @@ -351,7 +362,18 @@ fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { return false; } + // Don't instrument functions with `#[automatically_derived]` on their + // enclosing impl block, on the assumption that most users won't care about + // coverage for derived impls. + if let Some(impl_of) = tcx.impl_of_method(def_id.to_def_id()) + && tcx.is_automatically_derived(impl_of) + { + trace!("InstrumentCoverage skipped for {def_id:?} (automatically derived)"); + return false; + } + if tcx.codegen_fn_attrs(def_id).flags.contains(CodegenFnAttrFlags::NO_COVERAGE) { + trace!("InstrumentCoverage skipped for {def_id:?} (`#[coverage(off)]`)"); return false; } @@ -363,7 +385,9 @@ fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { struct ExtractedHirInfo { function_source_hash: u64, is_async_fn: bool, - fn_sig_span: Span, + /// The span of the function's signature, extended to the start of `body_span`. + /// Must have the same context and filename as the body span. 
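The three acceptance rules of `check_code_region` above, restated on bare tuples so they can be exercised in isolation (a sketch, not the compiler's `CodeRegion` type):

```rust
fn region_ok(start: (u32, u32), end: (u32, u32)) -> bool {
    // Coordinates are 1-based, so zeroes could be misread by `llvm-cov`.
    let all_nonzero = [start.0, start.1, end.0, end.1].iter().all(|&x| x != 0);
    // The high bit of `end_col` marks a "gap" region and must stay clear.
    let high_bit_unset = (end.1 & (1 << 31)) == 0;
    // Tuples compare lexicographically, matching the (line, col) ordering.
    let is_ordered = start <= end;
    all_nonzero && high_bit_unset && is_ordered
}

fn main() {
    assert!(region_ok((3, 5), (3, 9)));
    assert!(!region_ok((0, 5), (3, 9))); // zero coordinate
    assert!(!region_ok((3, 9), (3, 5))); // end < start
    assert!(!region_ok((3, 5), (3, 1 << 31))); // would read as a gap region
}
```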
+ fn_sig_span_extended: Option<Span>, body_span: Span, } @@ -376,13 +400,25 @@ fn extract_hir_info<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> ExtractedHir hir::map::associated_body(hir_node).expect("HIR node is a function with body"); let hir_body = tcx.hir().body(fn_body_id); - let is_async_fn = hir_node.fn_sig().is_some_and(|fn_sig| fn_sig.header.is_async()); - let body_span = get_body_span(tcx, hir_body, def_id); + let maybe_fn_sig = hir_node.fn_sig(); + let is_async_fn = maybe_fn_sig.is_some_and(|fn_sig| fn_sig.header.is_async()); + + let mut body_span = hir_body.value.span; + + use rustc_hir::{Closure, Expr, ExprKind, Node}; + // Unexpand a closure's body span back to the context of its declaration. + // This helps with closure bodies that consist of just a single bang-macro, + // and also with closure bodies produced by async desugaring. + if let Node::Expr(&Expr { kind: ExprKind::Closure(&Closure { fn_decl_span, .. }), .. }) = + hir_node + { + body_span = body_span.find_ancestor_in_same_ctxt(fn_decl_span).unwrap_or(body_span); + } // The actual signature span is only used if it has the same context and // filename as the body, and precedes the body. - let maybe_fn_sig_span = hir_node.fn_sig().map(|fn_sig| fn_sig.span); - let fn_sig_span = maybe_fn_sig_span + let fn_sig_span_extended = maybe_fn_sig + .map(|fn_sig| fn_sig.span) .filter(|&fn_sig_span| { let source_map = tcx.sess.source_map(); let file_idx = |span: Span| source_map.lookup_source_file_idx(span.lo()); @@ -392,39 +428,15 @@ fn extract_hir_info<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> ExtractedHir && file_idx(fn_sig_span) == file_idx(body_span) }) // If so, extend it to the start of the body span. - .map(|fn_sig_span| fn_sig_span.with_hi(body_span.lo())) - // Otherwise, create a dummy signature span at the start of the body. - .unwrap_or_else(|| body_span.shrink_to_lo()); + .map(|fn_sig_span| fn_sig_span.with_hi(body_span.lo())); let function_source_hash = hash_mir_source(tcx, hir_body); - ExtractedHirInfo { function_source_hash, is_async_fn, fn_sig_span, body_span } -} - -fn get_body_span<'tcx>( - tcx: TyCtxt<'tcx>, - hir_body: &rustc_hir::Body<'tcx>, - def_id: LocalDefId, -) -> Span { - let mut body_span = hir_body.value.span; - - if tcx.is_closure_or_coroutine(def_id.to_def_id()) { - // If the current function is a closure, and its "body" span was created - // by macro expansion or compiler desugaring, try to walk backwards to - // the pre-expansion call site or body. - body_span = body_span.source_callsite(); - } - - body_span + ExtractedHirInfo { function_source_hash, is_async_fn, fn_sig_span_extended, body_span } } fn hash_mir_source<'tcx>(tcx: TyCtxt<'tcx>, hir_body: &'tcx rustc_hir::Body<'tcx>) -> u64 { // FIXME(cjgillot) Stop hashing HIR manually here. 
let owner = hir_body.id().hir_id.owner; - tcx.hir_owner_nodes(owner) - .unwrap() - .opt_hash_including_bodies - .unwrap() - .to_smaller_hash() - .as_u64() + tcx.hir_owner_nodes(owner).opt_hash_including_bodies.unwrap().to_smaller_hash().as_u64() } diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs index 81f6c831206..d3d0c7bcc95 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans.rs @@ -3,7 +3,7 @@ use rustc_index::bit_set::BitSet; use rustc_middle::mir; use rustc_span::{BytePos, Span, DUMMY_SP}; -use super::graph::{BasicCoverageBlock, CoverageGraph}; +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, START_BCB}; use crate::coverage::ExtractedHirInfo; mod from_mir; @@ -26,52 +26,63 @@ pub(super) struct CoverageSpans { } impl CoverageSpans { - /// Extracts coverage-relevant spans from MIR, and associates them with - /// their corresponding BCBs. - /// - /// Returns `None` if no coverage-relevant spans could be extracted. - pub(super) fn generate_coverage_spans( - mir_body: &mir::Body<'_>, - hir_info: &ExtractedHirInfo, - basic_coverage_blocks: &CoverageGraph, - ) -> Option<Self> { - let mut mappings = vec![]; - - let coverage_spans = CoverageSpansGenerator::generate_coverage_spans( + pub(super) fn bcb_has_coverage_spans(&self, bcb: BasicCoverageBlock) -> bool { + self.bcb_has_mappings.contains(bcb) + } + + pub(super) fn all_bcb_mappings(&self) -> impl Iterator<Item = &BcbMapping> { + self.mappings.iter() + } +} + +/// Extracts coverage-relevant spans from MIR, and associates them with +/// their corresponding BCBs. +/// +/// Returns `None` if no coverage-relevant spans could be extracted. +pub(super) fn generate_coverage_spans( + mir_body: &mir::Body<'_>, + hir_info: &ExtractedHirInfo, + basic_coverage_blocks: &CoverageGraph, +) -> Option<CoverageSpans> { + let mut mappings = vec![]; + + if hir_info.is_async_fn { + // An async function desugars into a function that returns a future, + // with the user code wrapped in a closure. Any spans in the desugared + // outer function will be unhelpful, so just keep the signature span + // and ignore all of the spans in the MIR body. + if let Some(span) = hir_info.fn_sig_span_extended { + mappings.push(BcbMapping { kind: BcbMappingKind::Code(START_BCB), span }); + } + } else { + let sorted_spans = from_mir::mir_to_initial_sorted_coverage_spans( mir_body, hir_info, basic_coverage_blocks, ); + let coverage_spans = SpansRefiner::refine_sorted_spans(basic_coverage_blocks, sorted_spans); mappings.extend(coverage_spans.into_iter().map(|CoverageSpan { bcb, span, .. }| { // Each span produced by the generator represents an ordinary code region. BcbMapping { kind: BcbMappingKind::Code(bcb), span } })); - - if mappings.is_empty() { - return None; - } - - // Identify which BCBs have one or more mappings. 
- let mut bcb_has_mappings = BitSet::new_empty(basic_coverage_blocks.num_nodes()); - let mut insert = |bcb| { - bcb_has_mappings.insert(bcb); - }; - for &BcbMapping { kind, span: _ } in &mappings { - match kind { - BcbMappingKind::Code(bcb) => insert(bcb), - } - } - - Some(Self { bcb_has_mappings, mappings }) } - pub(super) fn bcb_has_coverage_spans(&self, bcb: BasicCoverageBlock) -> bool { - self.bcb_has_mappings.contains(bcb) + if mappings.is_empty() { + return None; } - pub(super) fn all_bcb_mappings(&self) -> impl Iterator<Item = &BcbMapping> { - self.mappings.iter() + // Identify which BCBs have one or more mappings. + let mut bcb_has_mappings = BitSet::new_empty(basic_coverage_blocks.num_nodes()); + let mut insert = |bcb| { + bcb_has_mappings.insert(bcb); + }; + for &BcbMapping { kind, span: _ } in &mappings { + match kind { + BcbMappingKind::Code(bcb) => insert(bcb), + } } + + Some(CoverageSpans { bcb_has_mappings, mappings }) } /// A BCB is deconstructed into one or more `Span`s. Each `Span` maps to a `CoverageSpan` that @@ -130,7 +141,7 @@ impl CoverageSpan { /// * Merge spans that represent continuous (both in source code and control flow), non-branching /// execution /// * Carve out (leave uncovered) any span that will be counted by another MIR (notably, closures) -struct CoverageSpansGenerator<'a> { +struct SpansRefiner<'a> { /// The BasicCoverageBlock Control Flow Graph (BCB CFG). basic_coverage_blocks: &'a CoverageGraph, @@ -173,40 +184,15 @@ struct CoverageSpansGenerator<'a> { refined_spans: Vec<CoverageSpan>, } -impl<'a> CoverageSpansGenerator<'a> { - /// Generate a minimal set of `CoverageSpan`s, each representing a contiguous code region to be - /// counted. - /// - /// The basic steps are: - /// - /// 1. Extract an initial set of spans from the `Statement`s and `Terminator`s of each - /// `BasicCoverageBlockData`. - /// 2. Sort the spans by span.lo() (starting position). Spans that start at the same position - /// are sorted with longer spans before shorter spans; and equal spans are sorted - /// (deterministically) based on "dominator" relationship (if any). - /// 3. Traverse the spans in sorted order to identify spans that can be dropped (for instance, - /// if another span or spans are already counting the same code region), or should be merged - /// into a broader combined span (because it represents a contiguous, non-branching, and - /// uninterrupted region of source code). - /// - /// Closures are exposed in their enclosing functions as `Assign` `Rvalue`s, and since - /// closures have their own MIR, their `Span` in their enclosing function should be left - /// "uncovered". - /// - /// Note the resulting vector of `CoverageSpan`s may not be fully sorted (and does not need - /// to be). - pub(super) fn generate_coverage_spans( - mir_body: &mir::Body<'_>, - hir_info: &ExtractedHirInfo, +impl<'a> SpansRefiner<'a> { + /// Takes the initial list of (sorted) spans extracted from MIR, and "refines" + /// them by merging compatible adjacent spans, removing redundant spans, + /// and carving holes in spans when they overlap in unwanted ways. 
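A sketch of the merging half of the refinement described above, reduced to half-open byte ranges already sorted by start position; the real refiner additionally tracks BCBs, dominance, and carves holes for closures:

```rust
fn merge_sorted(spans: &[(u32, u32)]) -> Vec<(u32, u32)> {
    let mut refined: Vec<(u32, u32)> = Vec::new();
    for &(lo, hi) in spans {
        match refined.last_mut() {
            // Overlapping or exactly adjacent: widen the previous span.
            Some(prev) if lo <= prev.1 => prev.1 = prev.1.max(hi),
            _ => refined.push((lo, hi)),
        }
    }
    refined
}

fn main() {
    assert_eq!(merge_sorted(&[(0, 4), (4, 9), (12, 15)]), vec![(0, 9), (12, 15)]);
}
```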
+ fn refine_sorted_spans( basic_coverage_blocks: &'a CoverageGraph, + sorted_spans: Vec<CoverageSpan>, ) -> Vec<CoverageSpan> { - let sorted_spans = from_mir::mir_to_initial_sorted_coverage_spans( - mir_body, - hir_info, - basic_coverage_blocks, - ); - - let coverage_spans = Self { + let this = Self { basic_coverage_blocks, sorted_spans_iter: sorted_spans.into_iter(), some_curr: None, @@ -217,7 +203,7 @@ impl<'a> CoverageSpansGenerator<'a> { refined_spans: Vec::with_capacity(basic_coverage_blocks.num_nodes() * 2), }; - coverage_spans.to_refined_spans() + this.to_refined_spans() } /// Iterate through the sorted `CoverageSpan`s, and return the refined list of merged and diff --git a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs index 1b6dfccd574..01fae7c0bec 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs @@ -12,30 +12,32 @@ use crate::coverage::graph::{ use crate::coverage::spans::CoverageSpan; use crate::coverage::ExtractedHirInfo; +/// Traverses the MIR body to produce an initial collection of coverage-relevant +/// spans, each associated with a node in the coverage graph (BCB) and possibly +/// other metadata. +/// +/// The returned spans are sorted in a specific order that is expected by the +/// subsequent span-refinement step. pub(super) fn mir_to_initial_sorted_coverage_spans( mir_body: &mir::Body<'_>, hir_info: &ExtractedHirInfo, basic_coverage_blocks: &CoverageGraph, ) -> Vec<CoverageSpan> { - let &ExtractedHirInfo { is_async_fn, fn_sig_span, body_span, .. } = hir_info; - - let mut initial_spans = vec![SpanFromMir::for_fn_sig(fn_sig_span)]; - - if is_async_fn { - // An async function desugars into a function that returns a future, - // with the user code wrapped in a closure. Any spans in the desugared - // outer function will be unhelpful, so just keep the signature span - // and ignore all of the spans in the MIR body. - } else { - for (bcb, bcb_data) in basic_coverage_blocks.iter_enumerated() { - initial_spans.extend(bcb_to_initial_coverage_spans(mir_body, body_span, bcb, bcb_data)); - } + let &ExtractedHirInfo { body_span, .. } = hir_info; - // If no spans were extracted from the body, discard the signature span. - // FIXME: This preserves existing behavior; consider getting rid of it. - if initial_spans.len() == 1 { - initial_spans.clear(); - } + let mut initial_spans = vec![]; + + for (bcb, bcb_data) in basic_coverage_blocks.iter_enumerated() { + initial_spans.extend(bcb_to_initial_coverage_spans(mir_body, body_span, bcb, bcb_data)); + } + + // Only add the signature span if we found at least one span in the body. + if !initial_spans.is_empty() { + // If there is no usable signature span, add a fake one (before refinement) + // to avoid an ugly gap between the body start and the first real span. + // FIXME: Find a more principled way to solve this problem. 
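A toy byte-range span showing the two operations in play: extending the signature span up to the body (`with_hi`, from the `extract_hir_info` hunk earlier) and the empty-span fallback (`shrink_to_lo`) used on the line that follows:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
struct Span {
    lo: u32,
    hi: u32,
}

impl Span {
    fn with_hi(self, hi: u32) -> Span {
        Span { hi, ..self }
    }
    fn shrink_to_lo(self) -> Span {
        Span { hi: self.lo, ..self }
    }
}

fn main() {
    let sig = Span { lo: 40, hi: 90 };
    let body = Span { lo: 100, hi: 500 };
    // A usable signature span is extended to reach the start of the body...
    assert_eq!(sig.with_hi(body.lo), Span { lo: 40, hi: 100 });
    // ...and the fallback is an empty span at the body's start.
    assert_eq!(body.shrink_to_lo(), Span { lo: 100, hi: 100 });
}
```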
+ let fn_sig_span = hir_info.fn_sig_span_extended.unwrap_or_else(|| body_span.shrink_to_lo()); + initial_spans.push(SpanFromMir::for_fn_sig(fn_sig_span)); } initial_spans.sort_by(|a, b| basic_coverage_blocks.cmp_in_dominator_order(a.bcb, b.bcb)); @@ -154,7 +156,9 @@ fn bcb_to_initial_coverage_spans<'a, 'tcx>( fn is_closure_or_coroutine(statement: &Statement<'_>) -> bool { match statement.kind { StatementKind::Assign(box (_, Rvalue::Aggregate(box ref agg_kind, _))) => match agg_kind { - AggregateKind::Closure(_, _) | AggregateKind::Coroutine(_, _) => true, + AggregateKind::Closure(_, _) + | AggregateKind::Coroutine(_, _) + | AggregateKind::CoroutineClosure(..) => true, _ => false, }, _ => false, diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs index ad12bce9b02..86e99a8a5b5 100644 --- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -2,7 +2,9 @@ //! //! Currently, this pass only propagates scalar values. -use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable}; +use rustc_const_eval::interpret::{ + ImmTy, Immediate, InterpCx, OpTy, PlaceTy, PointerArithmetic, Projectable, +}; use rustc_data_structures::fx::FxHashMap; use rustc_hir::def::DefKind; use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar}; @@ -696,6 +698,7 @@ fn try_write_constant<'tcx>( | ty::Bound(..) | ty::Placeholder(..) | ty::Closure(..) + | ty::CoroutineClosure(..) | ty::Coroutine(..) | ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"), @@ -935,12 +938,50 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm } fn binary_ptr_op( - _ecx: &InterpCx<'mir, 'tcx, Self>, - _bin_op: BinOp, - _left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, - _right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, + ecx: &InterpCx<'mir, 'tcx, Self>, + bin_op: BinOp, + left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, + right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, ) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> { - throw_machine_stop_str!("can't do pointer arithmetic"); + use rustc_middle::mir::BinOp::*; + Ok(match bin_op { + Eq | Ne | Lt | Le | Gt | Ge => { + // Types can differ, e.g. fn ptrs with different `for`. + assert_eq!(left.layout.abi, right.layout.abi); + let size = ecx.pointer_size(); + // Just compare the bits. ScalarPairs are compared lexicographically. + // We thus always compare pairs and simply fill scalars up with 0. + // If the pointer has provenance, `to_bits` will return `Err` and we bail out. + let left = match **left { + Immediate::Scalar(l) => (l.to_bits(size)?, 0), + Immediate::ScalarPair(l1, l2) => (l1.to_bits(size)?, l2.to_bits(size)?), + Immediate::Uninit => panic!("we should never see uninit data here"), + }; + let right = match **right { + Immediate::Scalar(r) => (r.to_bits(size)?, 0), + Immediate::ScalarPair(r1, r2) => (r1.to_bits(size)?, r2.to_bits(size)?), + Immediate::Uninit => panic!("we should never see uninit data here"), + }; + let res = match bin_op { + Eq => left == right, + Ne => left != right, + Lt => left < right, + Le => left <= right, + Gt => left > right, + Ge => left >= right, + _ => bug!(), + }; + (ImmTy::from_bool(res, *ecx.tcx), false) + } + + // Some more operations are possible with atomics. 
+ // The return value always has the provenance of the *left* operand. + Add | Sub | BitOr | BitAnd | BitXor => { + throw_machine_stop_str!("pointer arithmetic is not handled") + } + + _ => span_bug!(ecx.cur_span(), "Invalid operator on pointers: {:?}", bin_op), + }) } fn expose_ptr( diff --git a/compiler/rustc_mir_transform/src/errors.rs b/compiler/rustc_mir_transform/src/errors.rs index 2ee660ddc9b..30de40e226c 100644 --- a/compiler/rustc_mir_transform/src/errors.rs +++ b/compiler/rustc_mir_transform/src/errors.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use rustc_errors::{ - Applicability, DecorateLint, DiagCtxt, DiagnosticArgValue, DiagnosticBuilder, + codes::*, Applicability, DecorateLint, DiagCtxt, DiagnosticArgValue, DiagnosticBuilder, DiagnosticMessage, EmissionGuarantee, IntoDiagnostic, Level, }; use rustc_macros::{Diagnostic, LintDiagnostic, Subdiagnostic}; @@ -33,7 +33,7 @@ pub(crate) enum ConstMutate { } #[derive(Diagnostic)] -#[diag(mir_transform_unaligned_packed_ref, code = "E0793")] +#[diag(mir_transform_unaligned_packed_ref, code = E0793)] #[note] #[note(mir_transform_note_ub)] #[help] @@ -66,7 +66,7 @@ impl<'a, G: EmissionGuarantee> IntoDiagnostic<'a, G> for RequiresUnsafe { #[track_caller] fn into_diagnostic(self, dcx: &'a DiagCtxt, level: Level) -> DiagnosticBuilder<'a, G> { let mut diag = DiagnosticBuilder::new(dcx, level, fluent::mir_transform_requires_unsafe); - diag.code("E0133".to_string()); + diag.code(E0133); diag.span(self.span); diag.span_label(self.span, self.details.label()); let desc = dcx.eagerly_translate_to_string(self.details.label(), [].into_iter()); @@ -125,7 +125,7 @@ impl RequiresUnsafeDetail { diag.arg( "missing_target_features", DiagnosticArgValue::StrListSepByAnd( - missing.iter().map(|feature| Cow::from(feature.as_str())).collect(), + missing.iter().map(|feature| Cow::from(feature.to_string())).collect(), ), ); diag.arg("missing_target_features_count", missing.len()); @@ -136,7 +136,7 @@ impl RequiresUnsafeDetail { DiagnosticArgValue::StrListSepByAnd( build_enabled .iter() - .map(|feature| Cow::from(feature.as_str())) + .map(|feature| Cow::from(feature.to_string())) .collect(), ), ); @@ -201,45 +201,39 @@ impl<'a> DecorateLint<'a, ()> for UnsafeOpInUnsafeFn { } } -pub(crate) enum AssertLint<P> { - ArithmeticOverflow(Span, AssertKind<P>), - UnconditionalPanic(Span, AssertKind<P>), +pub(crate) struct AssertLint<P> { + pub span: Span, + pub assert_kind: AssertKind<P>, + pub lint_kind: AssertLintKind, +} + +pub(crate) enum AssertLintKind { + ArithmeticOverflow, + UnconditionalPanic, } impl<'a, P: std::fmt::Debug> DecorateLint<'a, ()> for AssertLint<P> { fn decorate_lint<'b>(self, diag: &'b mut DiagnosticBuilder<'a, ()>) { - let span = self.span(); - let assert_kind = self.panic(); - let message = assert_kind.diagnostic_message(); - assert_kind.add_args(&mut |name, value| { + let message = self.assert_kind.diagnostic_message(); + self.assert_kind.add_args(&mut |name, value| { diag.arg(name, value); }); - diag.span_label(span, message); + diag.span_label(self.span, message); } fn msg(&self) -> DiagnosticMessage { - match self { - AssertLint::ArithmeticOverflow(..) => fluent::mir_transform_arithmetic_overflow, - AssertLint::UnconditionalPanic(..) 
=> fluent::mir_transform_operation_will_panic, + match self.lint_kind { + AssertLintKind::ArithmeticOverflow => fluent::mir_transform_arithmetic_overflow, + AssertLintKind::UnconditionalPanic => fluent::mir_transform_operation_will_panic, } } } -impl<P> AssertLint<P> { +impl AssertLintKind { pub fn lint(&self) -> &'static Lint { match self { - AssertLint::ArithmeticOverflow(..) => lint::builtin::ARITHMETIC_OVERFLOW, - AssertLint::UnconditionalPanic(..) => lint::builtin::UNCONDITIONAL_PANIC, - } - } - pub fn span(&self) -> Span { - match self { - AssertLint::ArithmeticOverflow(sp, _) | AssertLint::UnconditionalPanic(sp, _) => *sp, - } - } - pub fn panic(self) -> AssertKind<P> { - match self { - AssertLint::ArithmeticOverflow(_, p) | AssertLint::UnconditionalPanic(_, p) => p, + AssertLintKind::ArithmeticOverflow => lint::builtin::ARITHMETIC_OVERFLOW, + AssertLintKind::UnconditionalPanic => lint::builtin::UNCONDITIONAL_PANIC, } } } diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs index 26fcfad8287..663abbece85 100644 --- a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs +++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs @@ -26,7 +26,6 @@ fn abi_can_unwind(abi: Abi) -> bool { PtxKernel | Msp430Interrupt | X86Interrupt - | AmdGpuKernel | EfiApi | AvrInterrupt | AvrNonBlockingInterrupt @@ -58,7 +57,9 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { let body_abi = match body_ty.kind() { ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), ty::Closure(..) => Abi::RustCall, + ty::CoroutineClosure(..) => Abi::RustCall, ty::Coroutine(..) => Abi::Rust, + ty::Error(_) => return false, _ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty), }; let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi); @@ -112,7 +113,7 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { let span = terminator.source_info.span; let foreign = fn_def_id.is_some(); - tcx.emit_spanned_lint( + tcx.emit_node_span_lint( FFI_UNWIND_CALLS, lint_root, span, diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs index 61d99f1f018..f413bd9b311 100644 --- a/compiler/rustc_mir_transform/src/function_item_references.rs +++ b/compiler/rustc_mir_transform/src/function_item_references.rs @@ -185,7 +185,7 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { ret, ); - self.tcx.emit_spanned_lint( + self.tcx.emit_node_span_lint( FUNCTION_ITEM_REFERENCES, lint_root, span, diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs index 390ec3e1a36..2c7ae53055f 100644 --- a/compiler/rustc_mir_transform/src/gvn.rs +++ b/compiler/rustc_mir_transform/src/gvn.rs @@ -93,7 +93,6 @@ use rustc_index::IndexVec; use rustc_middle::mir::interpret::GlobalAlloc; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; -use rustc_middle::ty::adjustment::PointerCoercion; use rustc_middle::ty::layout::LayoutOf; use rustc_middle::ty::{self, Ty, TyCtxt, TypeAndMut}; use rustc_span::def_id::DefId; @@ -489,6 +488,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { NullOp::OffsetOf(fields) => { layout.offset_of_subfield(&self.ecx, fields.iter()).bytes() } + NullOp::DebugAssertions => return None, }; let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); let imm = ImmTy::try_from_uint(val, usize_layout)?; @@ -551,6 +551,29 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } 
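Looping back to the `AssertLint` change above: the refactor replaces an enum whose variants each carried identical payloads with a plain struct plus a payload-free kind enum. The same shape in miniature, with hypothetical `Lint`/`LintKind` names (the real messages live in Fluent, not string literals):

```rust
struct Lint {
    span: (u32, u32),
    kind: LintKind,
}

enum LintKind {
    ArithmeticOverflow,
    UnconditionalPanic,
}

impl LintKind {
    fn message(&self) -> &'static str {
        match self {
            LintKind::ArithmeticOverflow => "this arithmetic operation will overflow",
            LintKind::UnconditionalPanic => "this operation will panic at runtime",
        }
    }
}

fn main() {
    let lint = Lint { span: (10, 20), kind: LintKind::UnconditionalPanic };
    // Accessors like the old `span()` and `panic()` become plain field reads.
    assert_eq!(lint.kind.message(), "this operation will panic at runtime");
    assert_eq!(lint.span, (10, 20));
}
```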
value.offset(Size::ZERO, to, &self.ecx).ok()? } + CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize) => { + let src = self.evaluated[value].as_ref()?; + let to = self.ecx.layout_of(to).ok()?; + let dest = self.ecx.allocate(to, MemoryKind::Stack).ok()?; + self.ecx.unsize_into(src, to, &dest.clone().into()).ok()?; + self.ecx + .alloc_mark_immutable(dest.ptr().provenance.unwrap().alloc_id()) + .ok()?; + dest.into() + } + CastKind::FnPtrToPtr + | CastKind::PtrToPtr + | CastKind::PointerCoercion( + ty::adjustment::PointerCoercion::MutToConstPointer + | ty::adjustment::PointerCoercion::ArrayToPointer + | ty::adjustment::PointerCoercion::UnsafeFnPointer, + ) => { + let src = self.evaluated[value].as_ref()?; + let src = self.ecx.read_immediate(src).ok()?; + let to = self.ecx.layout_of(to).ok()?; + let ret = self.ecx.ptr_to_ptr(&src, to).ok()?; + ret.into() + } _ => return None, }, }; @@ -777,18 +800,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { // Operations. Rvalue::Len(ref mut place) => return self.simplify_len(place, location), - Rvalue::Cast(kind, ref mut value, to) => { - let from = value.ty(self.local_decls, self.tcx); - let value = self.simplify_operand(value, location)?; - if let CastKind::PointerCoercion( - PointerCoercion::ReifyFnPointer | PointerCoercion::ClosureFnPointer(_), - ) = kind - { - // Each reification of a generic fn may get a different pointer. - // Do not try to merge them. - return self.new_opaque(); - } - Value::Cast { kind, value, from, to } + Rvalue::Cast(ref mut kind, ref mut value, to) => { + return self.simplify_cast(kind, value, to, location); } Rvalue::BinaryOp(op, box (ref mut lhs, ref mut rhs)) => { let ty = lhs.ty(self.local_decls, self.tcx); @@ -861,9 +874,10 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { let tcx = self.tcx; if fields.is_empty() { let is_zst = match *kind { - AggregateKind::Array(..) | AggregateKind::Tuple | AggregateKind::Closure(..) => { - true - } + AggregateKind::Array(..) + | AggregateKind::Tuple + | AggregateKind::Closure(..) + | AggregateKind::CoroutineClosure(..) => true, // Only enums can be non-ZST. AggregateKind::Adt(did, ..) => tcx.def_kind(did) != DefKind::Enum, // Coroutines are never ZST, as they at least contain the implicit states. @@ -885,7 +899,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { assert!(!fields.is_empty()); (AggregateTy::Tuple, FIRST_VARIANT) } - AggregateKind::Closure(did, substs) | AggregateKind::Coroutine(did, substs) => { + AggregateKind::Closure(did, substs) + | AggregateKind::CoroutineClosure(did, substs) + | AggregateKind::Coroutine(did, substs) => { (AggregateTy::Def(did, substs), FIRST_VARIANT) } AggregateKind::Adt(did, variant_index, substs, _, None) => { @@ -1031,6 +1047,50 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } } + fn simplify_cast( + &mut self, + kind: &mut CastKind, + operand: &mut Operand<'tcx>, + to: Ty<'tcx>, + location: Location, + ) -> Option<VnIndex> { + use rustc_middle::ty::adjustment::PointerCoercion::*; + use CastKind::*; + + let mut from = operand.ty(self.local_decls, self.tcx); + let mut value = self.simplify_operand(operand, location)?; + if from == to { + return Some(value); + } + + if let CastKind::PointerCoercion(ReifyFnPointer | ClosureFnPointer(_)) = kind { + // Each reification of a generic fn may get a different pointer. + // Do not try to merge them. 
+ return self.new_opaque(); + } + + if let PtrToPtr | PointerCoercion(MutToConstPointer) = kind + && let Value::Cast { kind: inner_kind, value: inner_value, from: inner_from, to: _ } = + *self.get(value) + && let PtrToPtr | PointerCoercion(MutToConstPointer) = inner_kind + { + from = inner_from; + value = inner_value; + *kind = PtrToPtr; + if inner_from == to { + return Some(inner_value); + } + if let Some(const_) = self.try_as_constant(value) { + *operand = Operand::Constant(Box::new(const_)); + } else if let Some(local) = self.try_as_local(value, location) { + *operand = Operand::Copy(local.into()); + self.reused_locals.insert(local); + } + } + + Some(self.insert(Value::Cast { kind: *kind, value, from, to })) + } + fn simplify_len(&mut self, place: &mut Place<'tcx>, location: Location) -> Option<VnIndex> { // Trivial case: we are fetching a statically known length. let place_ty = place.ty(self.local_decls, self.tcx).ty; @@ -1228,8 +1288,8 @@ impl<'tcx> MutVisitor<'tcx> for StorageRemover<'tcx> { fn visit_operand(&mut self, operand: &mut Operand<'tcx>, _: Location) { if let Operand::Move(place) = *operand - && let Some(local) = place.as_local() - && self.reused_locals.contains(local) + && !place.is_indirect_first_projection() + && self.reused_locals.contains(place.local) { *operand = Operand::Copy(place); } diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 67668a216de..956d855ab81 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -2,6 +2,7 @@ use crate::deref_separator::deref_finder; use rustc_attr::InlineAttr; use rustc_const_eval::transform::validate::validate_types; +use rustc_hir::def::DefKind; use rustc_hir::def_id::DefId; use rustc_index::bit_set::BitSet; use rustc_index::Idx; @@ -317,6 +318,8 @@ impl<'tcx> Inliner<'tcx> { | InstanceDef::ReifyShim(_) | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) @@ -382,6 +385,17 @@ impl<'tcx> Inliner<'tcx> { } let fn_sig = self.tcx.fn_sig(def_id).instantiate(self.tcx, args); + + // Additionally, check that the body that we're inlining actually agrees + // with the ABI of the trait that the item comes from. 
+ if let InstanceDef::Item(instance_def_id) = callee.def + && self.tcx.def_kind(instance_def_id) == DefKind::AssocFn + && let instance_fn_sig = self.tcx.fn_sig(instance_def_id).skip_binder() + && instance_fn_sig.abi() != fn_sig.abi() + { + return None; + } + let source_info = SourceInfo { span: fn_span, ..terminator.source_info }; return Some(CallSite { callee, fn_sig, block: bb, source_info }); @@ -1025,21 +1039,16 @@ fn try_instance_mir<'tcx>( tcx: TyCtxt<'tcx>, instance: InstanceDef<'tcx>, ) -> Result<&'tcx Body<'tcx>, &'static str> { - match instance { - ty::InstanceDef::DropGlue(_, Some(ty)) => match ty.kind() { - ty::Adt(def, args) => { - let fields = def.all_fields(); - for field in fields { - let field_ty = field.ty(tcx, args); - if field_ty.has_param() && field_ty.has_projections() { - return Err("cannot build drop shim for polymorphic type"); - } - } - - Ok(tcx.instance_mir(instance)) + if let ty::InstanceDef::DropGlue(_, Some(ty)) = instance + && let ty::Adt(def, args) = ty.kind() + { + let fields = def.all_fields(); + for field in fields { + let field_ty = field.ty(tcx, args); + if field_ty.has_param() && field_ty.has_projections() { + return Err("cannot build drop shim for polymorphic type"); } - _ => Ok(tcx.instance_mir(instance)), - }, - _ => Ok(tcx.instance_mir(instance)), + } } + Ok(tcx.instance_mir(instance)) } diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index d30e0bad813..5b03bc361dd 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -87,6 +87,8 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( | InstanceDef::ReifyShim(_) | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::ThreadLocalShim { .. } | InstanceDef::CloneShim(..) 
=> {} diff --git a/compiler/rustc_mir_transform/src/instsimplify.rs b/compiler/rustc_mir_transform/src/instsimplify.rs index a28db0defc9..06df89c1037 100644 --- a/compiler/rustc_mir_transform/src/instsimplify.rs +++ b/compiler/rustc_mir_transform/src/instsimplify.rs @@ -2,10 +2,12 @@ use crate::simplify::simplify_duplicate_switch_targets; use rustc_middle::mir::*; +use rustc_middle::ty::layout; use rustc_middle::ty::layout::ValidityRequirement; use rustc_middle::ty::{self, GenericArgsRef, ParamEnv, Ty, TyCtxt}; use rustc_span::symbol::Symbol; use rustc_target::abi::FieldIdx; +use rustc_target::spec::abi::Abi; pub struct InstSimplify; @@ -27,17 +29,15 @@ impl<'tcx> MirPass<'tcx> for InstSimplify { ctx.simplify_bool_cmp(&statement.source_info, rvalue); ctx.simplify_ref_deref(&statement.source_info, rvalue); ctx.simplify_len(&statement.source_info, rvalue); - ctx.simplify_cast(&statement.source_info, rvalue); + ctx.simplify_cast(rvalue); } _ => {} } } ctx.simplify_primitive_clone(block.terminator.as_mut().unwrap(), &mut block.statements); - ctx.simplify_intrinsic_assert( - block.terminator.as_mut().unwrap(), - &mut block.statements, - ); + ctx.simplify_intrinsic_assert(block.terminator.as_mut().unwrap()); + ctx.simplify_nounwind_call(block.terminator.as_mut().unwrap()); simplify_duplicate_switch_targets(block.terminator.as_mut().unwrap()); } } @@ -140,7 +140,7 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { } } - fn simplify_cast(&self, _source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + fn simplify_cast(&self, rvalue: &mut Rvalue<'tcx>) { if let Rvalue::Cast(kind, operand, cast_ty) = rvalue { let operand_ty = operand.ty(self.local_decls, self.tcx); if operand_ty == *cast_ty { @@ -252,11 +252,29 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { terminator.kind = TerminatorKind::Goto { target: destination_block }; } - fn simplify_intrinsic_assert( - &self, - terminator: &mut Terminator<'tcx>, - _statements: &mut Vec<Statement<'tcx>>, - ) { + fn simplify_nounwind_call(&self, terminator: &mut Terminator<'tcx>) { + let TerminatorKind::Call { func, unwind, .. } = &mut terminator.kind else { + return; + }; + + let Some((def_id, _)) = func.const_fn_def() else { + return; + }; + + let body_ty = self.tcx.type_of(def_id).skip_binder(); + let body_abi = match body_ty.kind() { + ty::FnDef(..) => body_ty.fn_sig(self.tcx).abi(), + ty::Closure(..) => Abi::RustCall, + ty::Coroutine(..) => Abi::Rust, + _ => bug!("unexpected body ty: {:?}", body_ty), + }; + + if !layout::fn_can_unwind(self.tcx, Some(def_id), body_abi) { + *unwind = UnwindAction::Unreachable; + } + } + + fn simplify_intrinsic_assert(&self, terminator: &mut Terminator<'tcx>) { let TerminatorKind::Call { func, target, .. } = &mut terminator.kind else { return; }; diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs index dcab124505e..78ba166ba43 100644 --- a/compiler/rustc_mir_transform/src/jump_threading.rs +++ b/compiler/rustc_mir_transform/src/jump_threading.rs @@ -36,16 +36,21 @@ //! cost by `MAX_COST`. 
use rustc_arena::DroplessArena; +use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable}; use rustc_data_structures::fx::FxHashSet; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; +use rustc_middle::mir::interpret::Scalar; use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; -use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt}; +use rustc_middle::ty::layout::LayoutOf; +use rustc_middle::ty::{self, ScalarInt, TyCtxt}; use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem}; +use rustc_span::DUMMY_SP; use rustc_target::abi::{TagEncoding, Variants}; use crate::cost_checker::CostChecker; +use crate::dataflow_const_prop::DummyMachine; pub struct JumpThreading; @@ -55,7 +60,7 @@ const MAX_PLACES: usize = 100; impl<'tcx> MirPass<'tcx> for JumpThreading { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 4 + sess.mir_opt_level() >= 2 } #[instrument(skip_all level = "debug")] @@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading { let mut finder = TOFinder { tcx, param_env, + ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), body, arena: &arena, map: &map, @@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading { debug!(?discr, ?bb); let discr_ty = discr.ty(body, tcx).ty; - let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue }; + let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue }; let Some(discr) = finder.map.find(discr.as_ref()) else { continue }; debug!(?discr); @@ -142,6 +148,7 @@ struct ThreadingOpportunity { struct TOFinder<'tcx, 'a> { tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>, + ecx: InterpCx<'tcx, 'tcx, DummyMachine>, body: &'a Body<'tcx>, map: &'a Map, loop_headers: &'a BitSet<BasicBlock>, @@ -329,11 +336,11 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { } #[instrument(level = "trace", skip(self))] - fn process_operand( + fn process_immediate( &mut self, bb: BasicBlock, lhs: PlaceIndex, - rhs: &Operand<'tcx>, + rhs: ImmTy<'tcx>, state: &mut State<ConditionSet<'a>>, ) -> Option<!> { let register_opportunity = |c: Condition| { @@ -341,13 +348,70 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target }) }; + let conditions = state.try_get_idx(lhs, self.map)?; + if let Immediate::Scalar(Scalar::Int(int)) = *rhs { + conditions.iter_matches(int).for_each(register_opportunity); + } + + None + } + + /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. 
+ #[instrument(level = "trace", skip(self))] + fn process_constant( + &mut self, + bb: BasicBlock, + lhs: PlaceIndex, + constant: OpTy<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) { + self.map.for_each_projection_value( + lhs, + constant, + &mut |elem, op| match elem { + TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(), + TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(), + TrackElem::Discriminant => { + let variant = self.ecx.read_discriminant(op).ok()?; + let discr_value = + self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?; + Some(discr_value.into()) + } + TrackElem::DerefLen => { + let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into(); + let len_usize = op.len(&self.ecx).ok()?; + let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); + Some(ImmTy::from_uint(len_usize, layout).into()) + } + }, + &mut |place, op| { + if let Some(conditions) = state.try_get_idx(place, self.map) + && let Ok(imm) = self.ecx.read_immediate_raw(op) + && let Some(imm) = imm.right() + && let Immediate::Scalar(Scalar::Int(int)) = *imm + { + conditions.iter_matches(int).for_each(|c: Condition| { + self.opportunities + .push(ThreadingOpportunity { chain: vec![bb], target: c.target }) + }) + } + }, + ); + } + + #[instrument(level = "trace", skip(self))] + fn process_operand( + &mut self, + bb: BasicBlock, + lhs: PlaceIndex, + rhs: &Operand<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) -> Option<!> { match rhs { // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. Operand::Constant(constant) => { - let conditions = state.try_get_idx(lhs, self.map)?; - let constant = - constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?; - conditions.iter_matches(constant).for_each(register_opportunity); + let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?; + self.process_constant(bb, lhs, constant, state); } // Transfer the conditions on the copied rhs. Operand::Move(rhs) | Operand::Copy(rhs) => { @@ -360,6 +424,84 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { } #[instrument(level = "trace", skip(self))] + fn process_assign( + &mut self, + bb: BasicBlock, + lhs_place: &Place<'tcx>, + rhs: &Rvalue<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) -> Option<!> { + let lhs = self.map.find(lhs_place.as_ref())?; + match rhs { + Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?, + // Transfer the conditions on the copy rhs. + Rvalue::CopyForDeref(rhs) => { + self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)? + } + Rvalue::Discriminant(rhs) => { + let rhs = self.map.find_discr(rhs.as_ref())?; + state.insert_place_idx(rhs, lhs, self.map); + } + // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. + Rvalue::Aggregate(box ref kind, ref operands) => { + let agg_ty = lhs_place.ty(self.body, self.tcx).ty; + let lhs = match kind { + // Do not support unions. + AggregateKind::Adt(.., Some(_)) => return None, + AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => { + if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant) + && let Ok(discr_value) = + self.ecx.discriminant_for_variant(agg_ty, *variant_index) + { + self.process_immediate(bb, discr_target, discr_value, state); + } + self.map.apply(lhs, TrackElem::Variant(*variant_index))? 
+ } + _ => lhs, + }; + for (field_index, operand) in operands.iter_enumerated() { + if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) { + self.process_operand(bb, field, operand, state); + } + } + } + // Transfer the conditions on the copy rhs, after inverting polarity. + Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => { + let conditions = state.try_get_idx(lhs, self.map)?; + let place = self.map.find(place.as_ref())?; + let conds = conditions.map(self.arena, Condition::inv); + state.insert_value_idx(place, conds, self.map); + } + // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`. + // Create a condition on `rhs ?= B`. + Rvalue::BinaryOp( + op, + box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value)) + | box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)), + ) => { + let conditions = state.try_get_idx(lhs, self.map)?; + let place = self.map.find(place.as_ref())?; + let equals = match op { + BinOp::Eq => ScalarInt::TRUE, + BinOp::Ne => ScalarInt::FALSE, + _ => return None, + }; + let value = value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?; + let conds = conditions.map(self.arena, |c| Condition { + value, + polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne }, + ..c + }); + state.insert_value_idx(place, conds, self.map); + } + + _ => {} + } + + None + } + + #[instrument(level = "trace", skip(self))] fn process_statement( &mut self, bb: BasicBlock, @@ -374,18 +516,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // Below, `lhs` is the return value of `mutated_statement`, // the place to which `conditions` apply. - let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| { - let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?; - let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?; - let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?; - Some(Operand::const_from_scalar( - self.tcx, - discr.ty, - scalar.into(), - rustc_span::DUMMY_SP, - )) - }; - match &stmt.kind { // If we expect `discriminant(place) ?= A`, // we have an opportunity if `variant_index ?= A`. @@ -395,7 +525,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant // of a niche encoding. If we cannot ensure that we write to the discriminant, do // nothing. - let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?; + let enum_layout = self.ecx.layout_of(enum_ty).ok()?; let writes_discriminant = match enum_layout.variants { Variants::Single { index } => { assert_eq!(index, *variant_index); @@ -408,8 +538,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { } => *variant_index != untagged_variant, }; if writes_discriminant { - let discr = discriminant_for_variant(enum_ty, *variant_index)?; - self.process_operand(bb, discr_target, &discr, state)?; + let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?; + self.process_immediate(bb, discr_target, discr, state)?; } } // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`. @@ -420,89 +550,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity); } StatementKind::Assign(box (lhs_place, rhs)) => { - if let Some(lhs) = self.map.find(lhs_place.as_ref()) { - match rhs { - Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?, - // Transfer the conditions on the copy rhs.
- Rvalue::CopyForDeref(rhs) => { - self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)? - } - Rvalue::Discriminant(rhs) => { - let rhs = self.map.find_discr(rhs.as_ref())?; - state.insert_place_idx(rhs, lhs, self.map); - } - // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. - Rvalue::Aggregate(box ref kind, ref operands) => { - let agg_ty = lhs_place.ty(self.body, self.tcx).ty; - let lhs = match kind { - // Do not support unions. - AggregateKind::Adt(.., Some(_)) => return None, - AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => { - if let Some(discr_target) = - self.map.apply(lhs, TrackElem::Discriminant) - && let Some(discr_value) = - discriminant_for_variant(agg_ty, *variant_index) - { - self.process_operand(bb, discr_target, &discr_value, state); - } - self.map.apply(lhs, TrackElem::Variant(*variant_index))? - } - _ => lhs, - }; - for (field_index, operand) in operands.iter_enumerated() { - if let Some(field) = - self.map.apply(lhs, TrackElem::Field(field_index)) - { - self.process_operand(bb, field, operand, state); - } - } - } - // Transfer the conditions on the copy rhs, after inversing polarity. - Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => { - let conditions = state.try_get_idx(lhs, self.map)?; - let place = self.map.find(place.as_ref())?; - let conds = conditions.map(self.arena, Condition::inv); - state.insert_value_idx(place, conds, self.map); - } - // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`. - // Create a condition on `rhs ?= B`. - Rvalue::BinaryOp( - op, - box ( - Operand::Move(place) | Operand::Copy(place), - Operand::Constant(value), - ) - | box ( - Operand::Constant(value), - Operand::Move(place) | Operand::Copy(place), - ), - ) => { - let conditions = state.try_get_idx(lhs, self.map)?; - let place = self.map.find(place.as_ref())?; - let equals = match op { - BinOp::Eq => ScalarInt::TRUE, - BinOp::Ne => ScalarInt::FALSE, - _ => return None, - }; - let value = value - .const_ - .normalize(self.tcx, self.param_env) - .try_to_scalar_int()?; - let conds = conditions.map(self.arena, |c| Condition { - value, - polarity: if c.matches(equals) { - Polarity::Eq - } else { - Polarity::Ne - }, - ..c - }); - state.insert_value_idx(place, conds, self.map); - } - - _ => {} - } - } + self.process_assign(bb, lhs_place, rhs, state)?; } _ => {} } @@ -518,11 +566,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { cost: &CostChecker<'_, 'tcx>, depth: usize, ) { - let register_opportunity = |c: Condition| { - debug!(?bb, ?c.target, "register"); - self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target }) - }; - let term = self.body.basic_blocks[bb].terminator(); let place_to_flood = match term.kind { // We come from a target, so those are not possible. @@ -544,16 +587,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // Flood the overwritten place, and progress through. TerminatorKind::Drop { place: destination, .. } | TerminatorKind::Call { destination, .. } => Some(destination), - // Treat as an `assume(cond == expected)`. - TerminatorKind::Assert { ref cond, expected, .. } => { - if let Some(place) = cond.place() - && let Some(conditions) = state.try_get(place.as_ref(), self.map) - { - let expected = if expected { ScalarInt::TRUE } else { ScalarInt::FALSE }; - conditions.iter_matches(expected).for_each(register_opportunity); - } - None - } + // Ignore, as this can be a no-op at codegen time. + TerminatorKind::Assert { .. } => None, }; // We can recurse through this terminator. 
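The `Rvalue::BinaryOp` arm of `process_assign` above turns an expectation on a boolean into an expectation on its operand. A reduced model of that polarity bookkeeping, with hypothetical simplified types:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Polarity {
    Eq,
    Ne,
}

#[derive(Clone, Copy, Debug, PartialEq)]
struct Condition {
    value: i64,
    polarity: Polarity,
}

// Given `lhs = (rhs == b)`, an expectation `lhs == expected` becomes a
// condition on `rhs`: equal to `b` when `expected` is true, else not-equal.
fn transfer(expected: bool, b: i64) -> Condition {
    Condition { value: b, polarity: if expected { Polarity::Eq } else { Polarity::Ne } }
}

fn main() {
    // Wanting `lhs == true` where `lhs = (rhs == 7)` means wanting `rhs == 7`.
    assert_eq!(transfer(true, 7), Condition { value: 7, polarity: Polarity::Eq });
    // Wanting `lhs == false` means wanting `rhs != 7`.
    assert_eq!(transfer(false, 7), Condition { value: 7, polarity: Polarity::Ne });
}
```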
@@ -577,7 +612,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { let discr = discr.place()?; let discr_ty = discr.ty(self.body, self.tcx).ty; - let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?; + let discr_layout = self.ecx.layout_of(discr_ty).ok()?; let conditions = state.try_get(discr.as_ref(), self.map)?; if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) { diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 19bfed4333c..72d9ffe8ca5 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -1,22 +1,20 @@ -#![deny(rustc::untranslatable_diagnostic)] -#![deny(rustc::diagnostic_outside_of_impl)] #![feature(assert_matches)] #![feature(box_patterns)] +#![feature(const_type_name)] #![feature(cow_is_borrowed)] #![feature(decl_macro)] #![feature(impl_trait_in_assoc_type)] +#![feature(inline_const)] #![feature(is_sorted)] #![feature(let_chains)] #![feature(map_try_insert)] -#![feature(min_specialization)] +#![cfg_attr(bootstrap, feature(min_specialization))] #![feature(never_type)] #![feature(option_get_or_insert_default)] #![feature(round_char_boundary)] -#![feature(trusted_step)] #![feature(try_blocks)] #![feature(yeet_expr)] #![feature(if_let_guard)] -#![recursion_limit = "256"] #[macro_use] extern crate tracing; @@ -61,7 +59,6 @@ mod remove_place_mention; mod add_subtyping_projections; pub mod cleanup_post_borrowck; mod const_debuginfo; -mod const_goto; mod const_prop; mod const_prop_lint; mod copy_prop; @@ -105,7 +102,6 @@ mod remove_unneeded_drops; mod remove_zsts; mod required_consts; mod reveal_all; -mod separate_const_switch; mod shim; mod ssa; // This pass is public to allow external drivers to perform MIR cleanup @@ -307,6 +303,10 @@ fn mir_const(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> { &Lint(check_packed_ref::CheckPackedRef), &Lint(check_const_item_mutation::CheckConstItemMutation), &Lint(function_item_references::FunctionItemReferences), + // If this is an async closure's output coroutine, generate + // by-move and by-mut bodies if needed. We do this first so + // they can be optimized in lockstep with their parent bodies. + &coroutine::ByMoveBody, // What we need to do constant evaluation. &simplify::SimplifyCfg::Initial, &rustc_peek::SanityCheck, // Just a lint @@ -588,7 +588,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { // Has to run after `slice::len` lowering &normalize_array_len::NormalizeArrayLen, - &const_goto::ConstGoto, &ref_prop::ReferencePropagation, &sroa::ScalarReplacementOfAggregates, &match_branches::MatchBranchSimplification, @@ -599,10 +598,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &dead_store_elimination::DeadStoreElimination::Initial, &gvn::GVN, &simplify::SimplifyLocals::AfterGVN, - // Perform `SeparateConstSwitch` after SSA-based analyses, as cloning blocks may - // destroy the SSA property. It should still happen before const-propagation, so the - // latter pass will leverage the created opportunities. 
diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs
index 897375e0e16..f43b85173d4 100644
--- a/compiler/rustc_mir_transform/src/lower_intrinsics.rs
+++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs
@@ -21,6 +21,17 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics {
 sym::unreachable => {
 terminator.kind = TerminatorKind::Unreachable;
 }
+ sym::debug_assertions => {
+ let target = target.unwrap();
+ block.statements.push(Statement {
+ source_info: terminator.source_info,
+ kind: StatementKind::Assign(Box::new((
+ *destination,
+ Rvalue::NullaryOp(NullOp::DebugAssertions, tcx.types.bool),
+ ))),
+ });
+ terminator.kind = TerminatorKind::Goto { target };
+ }
 sym::forget => {
 if let Some(target) = *target {
 block.statements.push(Statement {
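
The new arm lowers the `debug_assertions` intrinsic to `Rvalue::NullaryOp(NullOp::DebugAssertions, tcx.types.bool)`, which keeps the debug-assertions question open through the MIR pipeline instead of answering it at macro-expansion time. For contrast, a sketch of the eager source-level equivalent using the stable `cfg!` macro, which is frozen per crate at expansion:

```rust
// With `cfg!`, the branch below is decided when this crate is
// expanded; the NullOp form produced by the pass above is instead
// resolved later, at const-eval or codegen time.
fn add_with_debug_checks(a: u32, b: u32) -> u32 {
    if cfg!(debug_assertions) {
        a.checked_add(b).expect("attempt to add with overflow")
    } else {
        a.wrapping_add(b)
    }
}
```
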
diff --git a/compiler/rustc_mir_transform/src/nrvo.rs b/compiler/rustc_mir_transform/src/nrvo.rs
index ff309bd10ec..c3a92911bbf 100644
--- a/compiler/rustc_mir_transform/src/nrvo.rs
+++ b/compiler/rustc_mir_transform/src/nrvo.rs
@@ -1,7 +1,7 @@
 //! See the docs for [`RenameReturnPlace`].

 use rustc_hir::Mutability;
-use rustc_index::bit_set::HybridBitSet;
+use rustc_index::bit_set::BitSet;
 use rustc_middle::mir::visit::{MutVisitor, NonUseContext, PlaceContext, Visitor};
 use rustc_middle::mir::{self, BasicBlock, Local, Location};
 use rustc_middle::ty::TyCtxt;
@@ -123,7 +123,7 @@ fn find_local_assigned_to_return_place(
 body: &mut mir::Body<'_>,
 ) -> Option<Local> {
 let mut block = start;
- let mut seen = HybridBitSet::new_empty(body.basic_blocks.len());
+ let mut seen = BitSet::new_empty(body.basic_blocks.len());

 // Iterate as long as `block` has exactly one predecessor that we have not yet visited.
 while seen.insert(block) {
diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs
index f4c572aec12..77478cc741d 100644
--- a/compiler/rustc_mir_transform/src/pass_manager.rs
+++ b/compiler/rustc_mir_transform/src/pass_manager.rs
@@ -7,8 +7,12 @@ use crate::{lint::lint_body, validate, MirPass};
 /// Just like `MirPass`, except it cannot mutate `Body`.
 pub trait MirLint<'tcx> {
 fn name(&self) -> &'static str {
- let name = std::any::type_name::<Self>();
- if let Some((_, tail)) = name.rsplit_once(':') { tail } else { name }
+ // FIXME Simplify the implementation once more `str` methods get const-stable.
+ // See copypaste in `MirPass`
+ const {
+ let name = std::any::type_name::<Self>();
+ rustc_middle::util::common::c_name(name)
+ }
 }

 fn is_enabled(&self, _sess: &Session) -> bool {
@@ -177,6 +181,15 @@ fn run_passes_inner<'tcx>(
 body.pass_count = 1;
 }
+
+ if let Some(coroutine) = body.coroutine.as_mut() {
+ if let Some(by_move_body) = coroutine.by_move_body.as_mut() {
+ run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
+ }
+ if let Some(by_mut_body) = coroutine.by_mut_body.as_mut() {
+ run_passes_inner(tcx, by_mut_body, passes, phase_change, validate_each);
+ }
+ }
 }

 pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: String) {
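
The `const {}` block above relies on the `inline_const` and `const_type_name` features enabled in lib.rs earlier in this diff, and delegates the actual trimming to `rustc_middle::util::common::c_name`. A standalone sketch of the same tail extraction, done at runtime with stable `str` methods since the const machinery is nightly-only (`tail_name` is an illustrative stand-in, not the compiler's helper):

```rust
// Mimics what `c_name` computes: the final path segment of
// `std::any::type_name`, e.g. "SimplifyCfg" out of
// "rustc_mir_transform::simplify::SimplifyCfg".
fn tail_name<T>() -> &'static str {
    let name = std::any::type_name::<T>();
    match name.rsplit_once("::") {
        Some((_, tail)) => tail,
        None => name,
    }
}

fn main() {
    struct SimplifyCfg;
    // Prints "SimplifyCfg" regardless of module nesting.
    println!("{}", tail_name::<SimplifyCfg>());
}
```
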
diff --git a/compiler/rustc_mir_transform/src/promote_consts.rs b/compiler/rustc_mir_transform/src/promote_consts.rs
index c00093ea27e..577b8f2080f 100644
--- a/compiler/rustc_mir_transform/src/promote_consts.rs
+++ b/compiler/rustc_mir_transform/src/promote_consts.rs
@@ -446,6 +446,7 @@ impl<'tcx> Validator<'_, 'tcx> {
 NullOp::SizeOf => {}
 NullOp::AlignOf => {}
 NullOp::OffsetOf(_) => {}
+ NullOp::DebugAssertions => {}
 },

 Rvalue::ShallowInitBox(_, _) => return Err(Unpromotable),
diff --git a/compiler/rustc_mir_transform/src/remove_storage_markers.rs b/compiler/rustc_mir_transform/src/remove_storage_markers.rs
index 795f5232ee3..f68e592db15 100644
--- a/compiler/rustc_mir_transform/src/remove_storage_markers.rs
+++ b/compiler/rustc_mir_transform/src/remove_storage_markers.rs
@@ -7,14 +7,10 @@ pub struct RemoveStorageMarkers;

 impl<'tcx> MirPass<'tcx> for RemoveStorageMarkers {
 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
- sess.mir_opt_level() > 0
+ sess.mir_opt_level() > 0 && !sess.emit_lifetime_markers()
 }

- fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- if tcx.sess.emit_lifetime_markers() {
- return;
- }
-
+ fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
 trace!("Running RemoveStorageMarkers on {:?}", body.source);
 for data in body.basic_blocks.as_mut_preserves_cfg() {
 data.statements.retain(|statement| match statement.kind {
diff --git a/compiler/rustc_mir_transform/src/remove_zsts.rs b/compiler/rustc_mir_transform/src/remove_zsts.rs
index 34d57a45301..9a94cae3382 100644
--- a/compiler/rustc_mir_transform/src/remove_zsts.rs
+++ b/compiler/rustc_mir_transform/src/remove_zsts.rs
@@ -46,6 +46,7 @@ fn maybe_zst(ty: Ty<'_>) -> bool {
 ty::Adt(..)
 | ty::Array(..)
 | ty::Closure(..)
+ | ty::CoroutineClosure(..)
 | ty::Tuple(..)
 | ty::Alias(ty::Opaque, ..) => true,
 // definitely ZST
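
For the `maybe_zst` change just above: the listed kinds are those whose size cannot be judged without computing a layout, and closure-like types are the classic case, since a closure is zero-sized exactly when it captures nothing. A quick stable-Rust check of that fact:

```rust
fn main() {
    let x = 7u32;
    let no_captures = || 42u32; // captures nothing: a ZST
    let captures = move || x; // captures a u32 by value: 4 bytes
    assert_eq!(std::mem::size_of_val(&no_captures), 0);
    assert_eq!(std::mem::size_of_val(&captures), 4);
}
```
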
diff --git a/compiler/rustc_mir_transform/src/separate_const_switch.rs b/compiler/rustc_mir_transform/src/separate_const_switch.rs
deleted file mode 100644
index 7120ef72142..00000000000
--- a/compiler/rustc_mir_transform/src/separate_const_switch.rs
+++ /dev/null
@@ -1,343 +0,0 @@
-//! A pass that duplicates switch-terminated blocks
-//! into a new copy for each predecessor, provided
-//! the predecessor sets the value being switched
-//! over to a constant.
-//!
-//! The purpose of this pass is to help constant
-//! propagation passes to simplify the switch terminator
-//! of the copied blocks into gotos when some predecessors
-//! statically determine the output of switches.
-//!
-//! ```text
-//! x = 12 --- ---> something
-//! \ / 12
-//! --> switch x
-//! / \ otherwise
-//! x = y --- ---> something else
-//! ```
-//! becomes
-//! ```text
-//! x = 12 ---> switch x ------> something
-//! \ / 12
-//! X
-//! / \ otherwise
-//! x = y ---> switch x ------> something else
-//! ```
-//! so it can hopefully later be turned by another pass into
-//! ```text
-//! x = 12 --------------------> something
-//! / 12
-//! /
-//! / otherwise
-//! x = y ---- switch x ------> something else
-//! ```
-//!
-//! This optimization is meant to cover simple cases
-//! like `?` desugaring. For now, it thus focuses on
-//! simplicity rather than completeness (it notably
-//! sometimes duplicates abusively).
-
-use rustc_middle::mir::*;
-use rustc_middle::ty::TyCtxt;
-use smallvec::SmallVec;
-
-pub struct SeparateConstSwitch;
-
-impl<'tcx> MirPass<'tcx> for SeparateConstSwitch {
- fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
- // This pass participates in some as-of-yet untested unsoundness found
- // in https://github.com/rust-lang/rust/issues/112460
- sess.mir_opt_level() >= 2 && sess.opts.unstable_opts.unsound_mir_opts
- }
-
- fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- // If execution did something, applying a simplification layer
- // helps later passes optimize the copy away.
- if separate_const_switch(body) > 0 {
- super::simplify::simplify_cfg(body);
- }
- }
-}
-
-/// Returns the amount of blocks that were duplicated
-pub fn separate_const_switch(body: &mut Body<'_>) -> usize {
- let mut new_blocks: SmallVec<[(BasicBlock, BasicBlock); 6]> = SmallVec::new();
- let predecessors = body.basic_blocks.predecessors();
- 'block_iter: for (block_id, block) in body.basic_blocks.iter_enumerated() {
- if let TerminatorKind::SwitchInt {
- discr: Operand::Copy(switch_place) | Operand::Move(switch_place),
- ..
- } = block.terminator().kind
- {
- // If the block is on an unwind path, do not
- // apply the optimization as unwind paths
- // rely on a unique parent invariant
- if block.is_cleanup {
- continue 'block_iter;
- }
-
- // If the block has fewer than 2 predecessors, ignore it
- // we could maybe chain blocks that have exactly one
- // predecessor, but for now we ignore that
- if predecessors[block_id].len() < 2 {
- continue 'block_iter;
- }
-
- // First, let's find a non-const place
- // that determines the result of the switch
- if let Some(switch_place) = find_determining_place(switch_place, block) {
- // We now have an input place for which it would
- // be interesting if predecessors assigned it from a const
-
- let mut predecessors_left = predecessors[block_id].len();
- 'predec_iter: for predecessor_id in predecessors[block_id].iter().copied() {
- let predecessor = &body.basic_blocks[predecessor_id];
-
- // First we make sure the predecessor jumps
- // in a reasonable way
- match &predecessor.terminator().kind {
- // The following terminators are
- // unconditionally valid
- TerminatorKind::Goto { .. } | TerminatorKind::SwitchInt { .. } => {}
-
- TerminatorKind::FalseEdge { real_target, .. } => {
- if *real_target != block_id {
- continue 'predec_iter;
- }
- }
-
- // The following terminators are not allowed
- TerminatorKind::UnwindResume
- | TerminatorKind::Drop { .. }
- | TerminatorKind::Call { .. }
- | TerminatorKind::Assert { .. }
- | TerminatorKind::FalseUnwind { .. }
- | TerminatorKind::Yield { .. }
- | TerminatorKind::UnwindTerminate(_)
- | TerminatorKind::Return
- | TerminatorKind::Unreachable
- | TerminatorKind::InlineAsm { .. }
- | TerminatorKind::CoroutineDrop => {
- continue 'predec_iter;
- }
- }
-
- if is_likely_const(switch_place, predecessor) {
- new_blocks.push((predecessor_id, block_id));
- predecessors_left -= 1;
- if predecessors_left < 2 {
- // If the original block only has one predecessor left,
- // we have nothing left to do
- break 'predec_iter;
- }
- }
- }
- }
- }
- }
-
- // Once the analysis is done, perform the duplication
- let body_span = body.span;
- let copied_blocks = new_blocks.len();
- let blocks = body.basic_blocks_mut();
- for (pred_id, target_id) in new_blocks {
- let new_block = blocks[target_id].clone();
- let new_block_id = blocks.push(new_block);
- let terminator = blocks[pred_id].terminator_mut();
-
- match terminator.kind {
- TerminatorKind::Goto { ref mut target } => {
- *target = new_block_id;
- }
-
- TerminatorKind::FalseEdge { ref mut real_target, .. } => {
- if *real_target == target_id {
- *real_target = new_block_id;
- }
- }
-
- TerminatorKind::SwitchInt { ref mut targets, .. } => {
- targets.all_targets_mut().iter_mut().for_each(|x| {
- if *x == target_id {
- *x = new_block_id;
- }
- });
- }
-
- TerminatorKind::UnwindResume
- | TerminatorKind::UnwindTerminate(_)
- | TerminatorKind::Return
- | TerminatorKind::Unreachable
- | TerminatorKind::CoroutineDrop
- | TerminatorKind::Assert { .. }
- | TerminatorKind::FalseUnwind { .. }
- | TerminatorKind::Drop { .. }
- | TerminatorKind::Call { .. }
- | TerminatorKind::InlineAsm { .. }
- | TerminatorKind::Yield { .. } => {
- span_bug!(
- body_span,
- "basic block terminator had unexpected kind {:?}",
- &terminator.kind
- )
- }
- }
- }
-
- copied_blocks
-}
-
-/// This function describes a rough heuristic guessing
-/// whether a place is last set with a const within the block.
-/// Notably, it will be overly pessimistic in cases that are already
-/// not handled by `separate_const_switch`.
-fn is_likely_const<'tcx>(mut tracked_place: Place<'tcx>, block: &BasicBlockData<'tcx>) -> bool {
- for statement in block.statements.iter().rev() {
- match &statement.kind {
- StatementKind::Assign(assign) => {
- if assign.0 == tracked_place {
- match assign.1 {
- // These rvalues are definitely constant
- Rvalue::Use(Operand::Constant(_))
- | Rvalue::Ref(_, _, _)
- | Rvalue::AddressOf(_, _)
- | Rvalue::Cast(_, Operand::Constant(_), _)
- | Rvalue::NullaryOp(_, _)
- | Rvalue::ShallowInitBox(_, _)
- | Rvalue::UnaryOp(_, Operand::Constant(_)) => return true,
-
- // These rvalues make things ambiguous
- Rvalue::Repeat(_, _)
- | Rvalue::ThreadLocalRef(_)
- | Rvalue::Len(_)
- | Rvalue::BinaryOp(_, _)
- | Rvalue::CheckedBinaryOp(_, _)
- | Rvalue::Aggregate(_, _) => return false,
-
- // These rvalues move the place to track
- Rvalue::Cast(_, Operand::Copy(place) | Operand::Move(place), _)
- | Rvalue::Use(Operand::Copy(place) | Operand::Move(place))
- | Rvalue::CopyForDeref(place)
- | Rvalue::UnaryOp(_, Operand::Copy(place) | Operand::Move(place))
- | Rvalue::Discriminant(place) => tracked_place = place,
- }
- }
- }
-
- // If the discriminant is set, it is always set
- // as a constant, so the job is done.
- // As we are **ignoring projections**, if the place
- // we are tracking sees its discriminant be set,
- // that means we had to be tracking the discriminant
- // specifically (as it is impossible to switch over
- // an enum directly, and if we were switching over
- // its content, we would have had to at least cast it to
- // some variant first)
- StatementKind::SetDiscriminant { place, .. } => {
- if **place == tracked_place {
- return true;
- }
- }
-
- // These statements have no influence on the place
- // we are interested in
- StatementKind::FakeRead(_)
- | StatementKind::Deinit(_)
- | StatementKind::StorageLive(_)
- | StatementKind::Retag(_, _)
- | StatementKind::AscribeUserType(_, _)
- | StatementKind::PlaceMention(..)
- | StatementKind::Coverage(_)
- | StatementKind::StorageDead(_)
- | StatementKind::Intrinsic(_)
- | StatementKind::ConstEvalCounter
- | StatementKind::Nop => {}
- }
- }
-
- // If no good reason for the place to be const is found,
- // give up. We could maybe go up predecessors, but in
- // most cases giving up now should be sufficient.
- false
-}
-
-/// Finds a unique place that entirely determines the value
-/// of `switch_place`, if it exists. This is only a heuristic.
-/// Ideally we would like to track multiple determining places
-/// for some edge cases, but one is enough for a lot of situations.
-fn find_determining_place<'tcx>(
- mut switch_place: Place<'tcx>,
- block: &BasicBlockData<'tcx>,
-) -> Option<Place<'tcx>> {
- for statement in block.statements.iter().rev() {
- match &statement.kind {
- StatementKind::Assign(op) => {
- if op.0 != switch_place {
- continue;
- }
-
- match op.1 {
- // The following rvalues move the place
- // that may be const in the predecessor
- Rvalue::Use(Operand::Move(new) | Operand::Copy(new))
- | Rvalue::UnaryOp(_, Operand::Copy(new) | Operand::Move(new))
- | Rvalue::CopyForDeref(new)
- | Rvalue::Cast(_, Operand::Move(new) | Operand::Copy(new), _)
- | Rvalue::Repeat(Operand::Move(new) | Operand::Copy(new), _)
- | Rvalue::Discriminant(new)
- => switch_place = new,
-
- // The following rvalues might still make the block
- // be valid but for now we reject them
- Rvalue::Len(_)
- | Rvalue::Ref(_, _, _)
- | Rvalue::BinaryOp(_, _)
- | Rvalue::CheckedBinaryOp(_, _)
- | Rvalue::Aggregate(_, _)
-
- // The following rvalues definitely mean we cannot
- // or should not apply this optimization
- | Rvalue::Use(Operand::Constant(_))
- | Rvalue::Repeat(Operand::Constant(_), _)
- | Rvalue::ThreadLocalRef(_)
- | Rvalue::AddressOf(_, _)
- | Rvalue::NullaryOp(_, _)
- | Rvalue::ShallowInitBox(_, _)
- | Rvalue::UnaryOp(_, Operand::Constant(_))
- | Rvalue::Cast(_, Operand::Constant(_), _) => return None,
- }
- }
-
- // These statements have no influence on the place
- // we are interested in
- StatementKind::FakeRead(_)
- | StatementKind::Deinit(_)
- | StatementKind::StorageLive(_)
- | StatementKind::StorageDead(_)
- | StatementKind::Retag(_, _)
- | StatementKind::AscribeUserType(_, _)
- | StatementKind::PlaceMention(..)
- | StatementKind::Coverage(_)
- | StatementKind::Intrinsic(_)
- | StatementKind::ConstEvalCounter
- | StatementKind::Nop => {}
-
- // If the discriminant is set, it is always set
- // as a constant, so the job is already done.
- // As we are **ignoring projections**, if the place
- // we are tracking sees its discriminant be set,
- // that means we had to be tracking the discriminant
- // specifically (as it is impossible to switch over
- // an enum directly, and if we were switching over
- // its content, we would have had to at least cast it to
- // some variant first)
- StatementKind::SetDiscriminant { place, .. } => {
- if **place == switch_place {
- return None;
- }
- }
- }
- }
-
- Some(switch_place)
-}
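
The doc comment of the deleted pass names `?` desugaring as its motivating case: after desugaring, every predecessor of the rejoin block writes the switched-on discriminant from a statically known variant, which is exactly the pattern `find_determining_place` and `is_likely_const` looked for. A minimal function of that shape, for illustration (the duplication work is presumably now left to the remaining passes such as `JumpThreading` and GVN):

```rust
// `s.parse::<u32>()?` desugars to a match on the `Result`: the `Ok`
// arm and the `Err` arm both feed a block that switches on a
// discriminant each predecessor has just set to a constant.
fn parse_and_double(s: &str) -> Result<u32, std::num::ParseIntError> {
    let n = s.parse::<u32>()?;
    Ok(n * 2)
}
```
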
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index 89414ce940e..860d280be29 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -3,8 +3,8 @@ use rustc_hir::def_id::DefId;
 use rustc_hir::lang_items::LangItem;
 use rustc_middle::mir::*;
 use rustc_middle::query::Providers;
-use rustc_middle::ty::GenericArgs;
 use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt};
+use rustc_middle::ty::{GenericArgs, CAPTURE_STRUCT_LOCAL};
 use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};

 use rustc_index::{Idx, IndexVec};
@@ -66,11 +66,76 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
 build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
 }

+ ty::InstanceDef::ConstructCoroutineInClosureShim {
+ coroutine_closure_def_id,
+ target_kind,
+ } => match target_kind {
+ ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"),
+ ty::ClosureKind::FnMut => {
+ // No need to optimize the body, it has already been optimized
+ // since we steal it from the `AsyncFn::call` body and just fix
+ // the return type.
+ return build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id);
+ }
+ ty::ClosureKind::FnOnce => {
+ build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
+ }
+ },
+
+ ty::InstanceDef::CoroutineKindShim { coroutine_def_id, target_kind } => match target_kind {
+ ty::ClosureKind::Fn => unreachable!(),
+ ty::ClosureKind::FnMut => {
+ return tcx
+ .optimized_mir(coroutine_def_id)
+ .coroutine_by_mut_body()
+ .unwrap()
+ .clone();
+ }
+ ty::ClosureKind::FnOnce => {
+ return tcx
+ .optimized_mir(coroutine_def_id)
+ .coroutine_by_move_body()
+ .unwrap()
+ .clone();
+ }
+ },
+
 ty::InstanceDef::DropGlue(def_id, ty) => {
 // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
 // of this function. Is this intentional?
 if let Some(ty::Coroutine(coroutine_def_id, args)) = ty.map(Ty::kind) {
- let body = tcx.optimized_mir(*coroutine_def_id).coroutine_drop().unwrap();
+ let coroutine_body = tcx.optimized_mir(*coroutine_def_id);
+
+ let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind()
+ else {
+ bug!()
+ };
+
+ // If this is a regular coroutine, grab its drop shim. If this is a coroutine
+ // that comes from a coroutine-closure, and the kind ty differs from the "maximum"
+ // kind that it supports, then grab the appropriate drop shim. This ensures that
+ // the future returned by `<[coroutine-closure] as AsyncFnOnce>::call_once` will
+ // drop the coroutine-closure's upvars.
+ let body = if id_args.as_coroutine().kind_ty() == args.as_coroutine().kind_ty() {
+ coroutine_body.coroutine_drop().unwrap()
+ } else {
+ match args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() {
+ ty::ClosureKind::Fn => {
+ unreachable!()
+ }
+ ty::ClosureKind::FnMut => coroutine_body
+ .coroutine_by_mut_body()
+ .unwrap()
+ .coroutine_drop()
+ .unwrap(),
+ ty::ClosureKind::FnOnce => coroutine_body
+ .coroutine_by_move_body()
+ .unwrap()
+ .coroutine_drop()
+ .unwrap(),
+ }
+ };
+
 let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args);
 debug!("make_shim({:?}) = {:?}", instance, body);
@@ -382,16 +447,13 @@ fn build_thread_local_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'t
 fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Body<'tcx> {
 debug!("build_clone_shim(def_id={:?})", def_id);

- let param_env = tcx.param_env_reveal_all_normalized(def_id);
-
 let mut builder = CloneShimBuilder::new(tcx, def_id, self_ty);
- let is_copy = self_ty.is_copy_modulo_regions(tcx, param_env);

 let dest = Place::return_place();
 let src = tcx.mk_place_deref(Place::from(Local::new(1 + 0)));

 match self_ty.kind() {
- _ if is_copy => builder.copy_shim(),
+ ty::FnDef(..) | ty::FnPtr(_) => builder.copy_shim(),
 ty::Closure(_, args) => builder.tuple_like_shim(dest, src, args.as_closure().upvar_tys()),
 ty::Tuple(..) => builder.tuple_like_shim(dest, src, self_ty.tuple_fields()),
 ty::Coroutine(coroutine_def_id, args) => {
@@ -981,3 +1043,114 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t
 let source = MirSource::from_instance(ty::InstanceDef::FnPtrAddrShim(def_id, self_ty));
 new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span)
 }
+
+fn build_construct_coroutine_by_move_shim<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ coroutine_closure_def_id: DefId,
+) -> Body<'tcx> {
+ let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
+ let ty::CoroutineClosure(_, args) = *self_ty.kind() else {
+ bug!();
+ };
+
+ let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
+ tcx.mk_fn_sig(
+ [self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()),
+ sig.to_coroutine_given_kind_and_upvars(
+ tcx,
+ args.as_coroutine_closure().parent_args(),
+ tcx.coroutine_for_closure(coroutine_closure_def_id),
+ ty::ClosureKind::FnOnce,
+ tcx.lifetimes.re_erased,
+ args.as_coroutine_closure().tupled_upvars_ty(),
+ args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
+ ),
+ sig.c_variadic,
+ sig.unsafety,
+ sig.abi,
+ )
+ });
+ let sig = tcx.liberate_late_bound_regions(coroutine_closure_def_id, poly_sig);
+ let ty::Coroutine(coroutine_def_id, coroutine_args) = *sig.output().kind() else {
+ bug!();
+ };
+
+ let span = tcx.def_span(coroutine_closure_def_id);
+ let locals = local_decls_for_sig(&sig, span);
+
+ let mut fields = vec![];
+ for idx in 1..sig.inputs().len() {
+ fields.push(Operand::Move(Local::from_usize(idx + 1).into()));
+ }
+ for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() {
+ fields.push(Operand::Move(tcx.mk_place_field(
+ Local::from_usize(1).into(),
+ FieldIdx::from_usize(idx),
+ ty,
+ )));
+ }
+
+ let source_info = SourceInfo::outermost(span);
+ let rvalue = Rvalue::Aggregate(
+ Box::new(AggregateKind::Coroutine(coroutine_def_id, coroutine_args)),
+ IndexVec::from_raw(fields),
+ );
+ let stmt = Statement {
+ source_info,
+ kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
+ };
+ let statements = vec![stmt];
+ let start_block = BasicBlockData {
+ statements,
+ terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
+ is_cleanup: false,
+ };
+
+ let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
+ coroutine_closure_def_id,
+ target_kind: ty::ClosureKind::FnOnce,
+ });
+
+ let body =
+ new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span);
+ dump_mir(tcx, false, "coroutine_closure_by_move", &0, &body, |_, _| Ok(()));
+
+ body
+}
+
+fn build_construct_coroutine_by_mut_shim<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ coroutine_closure_def_id: DefId,
+) -> Body<'tcx> {
+ let mut body = tcx.optimized_mir(coroutine_closure_def_id).clone();
+ let coroutine_closure_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
+ let ty::CoroutineClosure(_, args) = *coroutine_closure_ty.kind() else {
+ bug!();
+ };
+ let args = args.as_coroutine_closure();
+
+ body.local_decls[RETURN_PLACE].ty =
+ tcx.instantiate_bound_regions_with_erased(args.coroutine_closure_sig().map_bound(|sig| {
+ sig.to_coroutine_given_kind_and_upvars(
+ tcx,
+ args.parent_args(),
+ tcx.coroutine_for_closure(coroutine_closure_def_id),
+ ty::ClosureKind::FnMut,
+ tcx.lifetimes.re_erased,
+ args.tupled_upvars_ty(),
+ args.coroutine_captures_by_ref_ty(),
+ )
+ }));
+ body.local_decls[CAPTURE_STRUCT_LOCAL].ty =
+ Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_closure_ty);
+
+ body.source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
+ coroutine_closure_def_id,
+ target_kind: ty::ClosureKind::FnMut,
+ });
+
+ body.pass_count = 0;
+ dump_mir(tcx, false, "coroutine_closure_by_mut", &0, &body, |_, _| Ok(()));
+
+ body
+}
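
Taken together, the two builders above give an async closure one coroutine-construction path per calling convention: `call_once` moves the captures into the returned coroutine, while `call_mut` lends them through `&mut self` (hence the retyped `CAPTURE_STRUCT_LOCAL`). A surface-level sketch of the distinction, assuming a nightly toolchain of this diff's era with the `async_closure` feature (gate and trait names may have shifted since):

```rust
#![feature(async_closure)]

async fn run() {
    let mut count = 0;
    let mut tick = async || count += 1;
    // Repeated calls go through `&mut self` (`AsyncFnMut`), so each
    // returned coroutine borrows the captures: the "by-mut" body that
    // `build_construct_coroutine_by_mut_shim` steals and retypes.
    tick().await;
    tick().await;
    // Passing `tick` by value to something expecting `AsyncFnOnce`
    // would instead construct the coroutine through the "by-move"
    // shim, which moves the upvars into the coroutine itself.
}
```
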
