Diffstat (limited to 'compiler/rustc_mir_transform/src')
92 files changed, 32153 insertions, 0 deletions
diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs new file mode 100644 index 00000000000..5bd6fdcf485 --- /dev/null +++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs @@ -0,0 +1,123 @@ +use rustc_abi::ExternAbi; +use rustc_ast::InlineAsmOptions; +use rustc_middle::mir::*; +use rustc_middle::span_bug; +use rustc_middle::ty::{self, TyCtxt, layout}; +use rustc_target::spec::PanicStrategy; + +/// A pass that runs which is targeted at ensuring that codegen guarantees about +/// unwinding are upheld for compilations of panic=abort programs. +/// +/// When compiling with panic=abort codegen backends generally want to assume +/// that all Rust-defined functions do not unwind, and it's UB if they actually +/// do unwind. Foreign functions, however, can be declared as "may unwind" via +/// their ABI (e.g. `extern "C-unwind"`). To uphold the guarantees that +/// Rust-defined functions never unwind a well-behaved Rust program needs to +/// catch unwinding from foreign functions and force them to abort. +/// +/// This pass walks over all functions calls which may possibly unwind, +/// and if any are found sets their cleanup to a block that aborts the process. +/// This forces all unwinds, in panic=abort mode happening in foreign code, to +/// trigger a process abort. +#[derive(PartialEq)] +pub(super) struct AbortUnwindingCalls; + +impl<'tcx> crate::MirPass<'tcx> for AbortUnwindingCalls { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let def_id = body.source.def_id(); + let kind = tcx.def_kind(def_id); + + // We don't simplify the MIR of constants at this time because that + // namely results in a cyclic query when we call `tcx.type_of` below. + if !kind.is_fn_like() { + return; + } + + // Here we test for this function itself whether its ABI allows + // unwinding or not. + let body_ty = tcx.type_of(def_id).skip_binder(); + let body_abi = match body_ty.kind() { + ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), + ty::Closure(..) => ExternAbi::RustCall, + ty::CoroutineClosure(..) => ExternAbi::RustCall, + ty::Coroutine(..) => ExternAbi::Rust, + ty::Error(_) => return, + _ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty), + }; + let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi); + + // Look in this function body for any basic blocks which are terminated + // with a function call, and whose function we're calling may unwind. + // This will filter to functions with `extern "C-unwind"` ABIs, for + // example. + for block in body.basic_blocks.as_mut() { + let Some(terminator) = &mut block.terminator else { continue }; + let span = terminator.source_info.span; + + // If we see an `UnwindResume` terminator inside a function that cannot unwind, we need + // to replace it with `UnwindTerminate`. + if let TerminatorKind::UnwindResume = &terminator.kind + && !body_can_unwind + { + terminator.kind = TerminatorKind::UnwindTerminate(UnwindTerminateReason::Abi); + } + + if block.is_cleanup { + continue; + } + + let call_can_unwind = match &terminator.kind { + TerminatorKind::Call { func, .. } => { + let ty = func.ty(&body.local_decls, tcx); + let sig = ty.fn_sig(tcx); + let fn_def_id = match ty.kind() { + ty::FnPtr(..) => None, + &ty::FnDef(def_id, _) => Some(def_id), + _ => span_bug!(span, "invalid callee of type {:?}", ty), + }; + layout::fn_can_unwind(tcx, fn_def_id, sig.abi()) + } + TerminatorKind::Drop { .. 
} => { + tcx.sess.opts.unstable_opts.panic_in_drop == PanicStrategy::Unwind + && layout::fn_can_unwind(tcx, None, ExternAbi::Rust) + } + TerminatorKind::Assert { .. } | TerminatorKind::FalseUnwind { .. } => { + layout::fn_can_unwind(tcx, None, ExternAbi::Rust) + } + TerminatorKind::InlineAsm { options, .. } => { + options.contains(InlineAsmOptions::MAY_UNWIND) + } + _ if terminator.unwind().is_some() => { + span_bug!(span, "unexpected terminator that may unwind {:?}", terminator) + } + _ => continue, + }; + + if !call_can_unwind { + // If this function call can't unwind, then there's no need for it + // to have a landing pad. This means that we can remove any cleanup + // registered for it (and turn it into `UnwindAction::Unreachable`). + let cleanup = block.terminator_mut().unwind_mut().unwrap(); + *cleanup = UnwindAction::Unreachable; + } else if !body_can_unwind + && matches!(terminator.unwind(), Some(UnwindAction::Continue)) + { + // Otherwise if this function can unwind, then if the outer function + // can also unwind there's nothing to do. If the outer function + // can't unwind, however, we need to ensure that any `UnwindAction::Continue` + // is replaced with terminate. For those with `UnwindAction::Cleanup`, + // cleanup will still happen, and terminate will happen afterwards handled by + // the `UnwindResume` -> `UnwindTerminate` terminator replacement. + let cleanup = block.terminator_mut().unwind_mut().unwrap(); + *cleanup = UnwindAction::Terminate(UnwindTerminateReason::Abi); + } + } + + // We may have invalidated some `cleanup` blocks so clean those up now. + super::simplify::remove_dead_blocks(body); + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/add_call_guards.rs b/compiler/rustc_mir_transform/src/add_call_guards.rs new file mode 100644 index 00000000000..bacff287859 --- /dev/null +++ b/compiler/rustc_mir_transform/src/add_call_guards.rs @@ -0,0 +1,112 @@ +use rustc_index::{Idx, IndexVec}; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use tracing::debug; + +#[derive(PartialEq)] +pub(super) enum AddCallGuards { + AllCallEdges, + CriticalCallEdges, +} +pub(super) use self::AddCallGuards::*; + +/** + * Breaks outgoing critical edges for call terminators in the MIR. + * + * Critical edges are edges that are neither the only edge leaving a + * block, nor the only edge entering one. + * + * When you want something to happen "along" an edge, you can either + * do at the end of the predecessor block, or at the start of the + * successor block. Critical edges have to be broken in order to prevent + * "edge actions" from affecting other edges. We need this for calls that are + * codegened to LLVM invoke instructions, because invoke is a block terminator + * in LLVM so we can't insert any code to handle the call's result into the + * block that performs the call. + * + * This function will break those edges by inserting new blocks along them. + * + * NOTE: Simplify CFG will happily undo most of the work this pass does. 
+ * + */ + +impl<'tcx> crate::MirPass<'tcx> for AddCallGuards { + fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let mut pred_count = IndexVec::from_elem(0u8, &body.basic_blocks); + for (_, data) in body.basic_blocks.iter_enumerated() { + for succ in data.terminator().successors() { + pred_count[succ] = pred_count[succ].saturating_add(1); + } + } + + // We need a place to store the new blocks generated + let mut new_blocks = Vec::new(); + + let cur_len = body.basic_blocks.len(); + let mut new_block = |source_info: SourceInfo, is_cleanup: bool, target: BasicBlock| { + let block = BasicBlockData { + statements: vec![], + is_cleanup, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Goto { target } }), + }; + let idx = cur_len + new_blocks.len(); + new_blocks.push(block); + BasicBlock::new(idx) + }; + + for block in body.basic_blocks_mut() { + match block.terminator { + Some(Terminator { + kind: TerminatorKind::Call { target: Some(ref mut destination), unwind, .. }, + source_info, + }) if pred_count[*destination] > 1 + && (generates_invoke(unwind) || self == &AllCallEdges) => + { + // It's a critical edge, break it + *destination = new_block(source_info, block.is_cleanup, *destination); + } + Some(Terminator { + kind: + TerminatorKind::InlineAsm { + asm_macro: InlineAsmMacro::Asm, + ref mut targets, + ref operands, + unwind, + .. + }, + source_info, + }) if self == &CriticalCallEdges => { + let has_outputs = operands.iter().any(|op| { + matches!(op, InlineAsmOperand::InOut { .. } | InlineAsmOperand::Out { .. }) + }); + let has_labels = + operands.iter().any(|op| matches!(op, InlineAsmOperand::Label { .. })); + if has_outputs && (has_labels || generates_invoke(unwind)) { + for target in targets.iter_mut() { + if pred_count[*target] > 1 { + *target = new_block(source_info, block.is_cleanup, *target); + } + } + } + } + _ => {} + } + } + + debug!("Broke {} N edges", new_blocks.len()); + + body.basic_blocks_mut().extend(new_blocks); + } + + fn is_required(&self) -> bool { + true + } +} + +/// Returns true if this unwind action is code generated as an invoke as opposed to a call. +fn generates_invoke(unwind: UnwindAction) -> bool { + match unwind { + UnwindAction::Continue | UnwindAction::Unreachable => false, + UnwindAction::Cleanup(_) | UnwindAction::Terminate(_) => true, + } +} diff --git a/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs b/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs new file mode 100644 index 00000000000..a414d120e68 --- /dev/null +++ b/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs @@ -0,0 +1,115 @@ +use rustc_middle::mir::*; +use rustc_middle::ty::{self, TyCtxt}; +use tracing::debug; + +use crate::patch::MirPatch; +use crate::util; + +/// This pass moves values being dropped that are within a packed +/// struct to a separate local before dropping them, to ensure that +/// they are dropped from an aligned address. +/// +/// For example, if we have something like +/// ```ignore (illustrative) +/// #[repr(packed)] +/// struct Foo { +/// dealign: u8, +/// data: Vec<u8> +/// } +/// +/// let foo = ...; +/// ``` +/// +/// We want to call `drop_in_place::<Vec<u8>>` on `data` from an aligned +/// address. This means we can't simply drop `foo.data` directly, because +/// its address is not aligned. 
+/// +/// Instead, we move `foo.data` to a local and drop that: +/// ```ignore (illustrative) +/// storage.live(drop_temp) +/// drop_temp = foo.data; +/// drop(drop_temp) -> next +/// next: +/// storage.dead(drop_temp) +/// ``` +/// +/// The storage instructions are required to avoid stack space +/// blowup. +pub(super) struct AddMovesForPackedDrops; + +impl<'tcx> crate::MirPass<'tcx> for AddMovesForPackedDrops { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!("add_moves_for_packed_drops({:?} @ {:?})", body.source, body.span); + let mut patch = MirPatch::new(body); + // FIXME(#132279): This is used during the phase transition from analysis + // to runtime, so we have to manually specify the correct typing mode. + let typing_env = ty::TypingEnv::post_analysis(tcx, body.source.def_id()); + + for (bb, data) in body.basic_blocks.iter_enumerated() { + let loc = Location { block: bb, statement_index: data.statements.len() }; + let terminator = data.terminator(); + + match terminator.kind { + TerminatorKind::Drop { place, .. } + if util::is_disaligned(tcx, body, typing_env, place) => + { + add_move_for_packed_drop( + tcx, + body, + &mut patch, + terminator, + loc, + data.is_cleanup, + ); + } + _ => {} + } + } + + patch.apply(body); + } + + fn is_required(&self) -> bool { + true + } +} + +fn add_move_for_packed_drop<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + patch: &mut MirPatch<'tcx>, + terminator: &Terminator<'tcx>, + loc: Location, + is_cleanup: bool, +) { + debug!("add_move_for_packed_drop({:?} @ {:?})", terminator, loc); + let TerminatorKind::Drop { ref place, target, unwind, replace, drop, async_fut } = + terminator.kind + else { + unreachable!(); + }; + + let source_info = terminator.source_info; + let ty = place.ty(body, tcx).ty; + let temp = patch.new_temp(ty, source_info.span); + + let storage_dead_block = patch.new_block(BasicBlockData { + statements: vec![Statement { source_info, kind: StatementKind::StorageDead(temp) }], + terminator: Some(Terminator { source_info, kind: TerminatorKind::Goto { target } }), + is_cleanup, + }); + + patch.add_statement(loc, StatementKind::StorageLive(temp)); + patch.add_assign(loc, Place::from(temp), Rvalue::Use(Operand::Move(*place))); + patch.patch_terminator( + loc.block, + TerminatorKind::Drop { + place: Place::from(temp), + target: storage_dead_block, + unwind, + replace, + drop, + async_fut, + }, + ); +} diff --git a/compiler/rustc_mir_transform/src/add_retag.rs b/compiler/rustc_mir_transform/src/add_retag.rs new file mode 100644 index 00000000000..e5a28d1b66c --- /dev/null +++ b/compiler/rustc_mir_transform/src/add_retag.rs @@ -0,0 +1,189 @@ +//! This pass adds validation calls (AcquireValid, ReleaseValid) where appropriate. +//! It has to be run really early, before transformations like inlining, because +//! introducing these calls *adds* UB -- so, conceptually, this pass is actually part +//! of MIR building, and only after this pass we think of the program has having the +//! normal MIR semantics. + +use rustc_hir::LangItem; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; + +pub(super) struct AddRetag; + +/// Determine whether this type may contain a reference (or box), and thus needs retagging. +/// We will only recurse `depth` times into Tuples/ADTs to bound the cost of this. 
+fn may_contain_reference<'tcx>(ty: Ty<'tcx>, depth: u32, tcx: TyCtxt<'tcx>) -> bool { + match ty.kind() { + // Primitive types that are not references + ty::Bool + | ty::Char + | ty::Float(_) + | ty::Int(_) + | ty::Uint(_) + | ty::RawPtr(..) + | ty::FnPtr(..) + | ty::Str + | ty::FnDef(..) + | ty::Never => false, + // References and Boxes (`noalias` sources) + ty::Ref(..) => true, + ty::Adt(..) if ty.is_box() => true, + ty::Adt(adt, _) if tcx.is_lang_item(adt.did(), LangItem::PtrUnique) => true, + // Compound types: recurse + ty::Array(ty, _) | ty::Slice(ty) => { + // This does not branch so we keep the depth the same. + may_contain_reference(*ty, depth, tcx) + } + ty::Tuple(tys) => { + depth == 0 || tys.iter().any(|ty| may_contain_reference(ty, depth - 1, tcx)) + } + ty::Adt(adt, args) => { + depth == 0 + || adt.variants().iter().any(|v| { + v.fields.iter().any(|f| may_contain_reference(f.ty(tcx, args), depth - 1, tcx)) + }) + } + // Conservative fallback + _ => true, + } +} + +impl<'tcx> crate::MirPass<'tcx> for AddRetag { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.opts.unstable_opts.mir_emit_retag + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // We need an `AllCallEdges` pass before we can do any work. + super::add_call_guards::AllCallEdges.run_pass(tcx, body); + + let basic_blocks = body.basic_blocks.as_mut(); + let local_decls = &body.local_decls; + let needs_retag = |place: &Place<'tcx>| { + // We're not really interested in stores to "outside" locations, they are hard to keep + // track of anyway. + !place.is_indirect_first_projection() + && may_contain_reference(place.ty(&*local_decls, tcx).ty, /*depth*/ 3, tcx) + && !local_decls[place.local].is_deref_temp() + }; + + // PART 1 + // Retag arguments at the beginning of the start block. + { + // Gather all arguments, skip return value. + let places = local_decls.iter_enumerated().skip(1).take(body.arg_count).filter_map( + |(local, decl)| { + let place = Place::from(local); + needs_retag(&place).then_some((place, decl.source_info)) + }, + ); + + // Emit their retags. + basic_blocks[START_BLOCK].statements.splice( + 0..0, + places.map(|(place, source_info)| Statement { + source_info, + kind: StatementKind::Retag(RetagKind::FnEntry, Box::new(place)), + }), + ); + } + + // PART 2 + // Retag return values of functions. + // We collect the return destinations because we cannot mutate while iterating. + let returns = basic_blocks + .iter_mut() + .filter_map(|block_data| { + match block_data.terminator().kind { + TerminatorKind::Call { target: Some(target), destination, .. } + if needs_retag(&destination) => + { + // Remember the return destination for later + Some((block_data.terminator().source_info, destination, target)) + } + + // `Drop` is also a call, but it doesn't return anything so we are good. + TerminatorKind::Drop { .. } => None, + // Not a block ending in a Call -> ignore. + _ => None, + } + }) + .collect::<Vec<_>>(); + // Now we go over the returns we collected to retag the return values. + for (source_info, dest_place, dest_block) in returns { + basic_blocks[dest_block].statements.insert( + 0, + Statement { + source_info, + kind: StatementKind::Retag(RetagKind::Default, Box::new(dest_place)), + }, + ); + } + + // PART 3 + // Add retag after assignments. + for block_data in basic_blocks { + // We want to insert statements as we iterate. To this end, we + // iterate backwards using indices. 
+ for i in (0..block_data.statements.len()).rev() { + let (retag_kind, place) = match block_data.statements[i].kind { + // Retag after assignments of reference type. + StatementKind::Assign(box (ref place, ref rvalue)) => { + let add_retag = match rvalue { + // Ptr-creating operations already do their own internal retagging, no + // need to also add a retag statement. *Except* if we are deref'ing a + // Box, because those get desugared to directly working with the inner + // raw pointer! That's relevant for `RawPtr` as Miri otherwise makes it + // a NOP when the original pointer is already raw. + Rvalue::RawPtr(_mutbl, place) => { + // Using `is_box_global` here is a bit sketchy: if this code is + // generic over the allocator, we'll not add a retag! This is a hack + // to make Stacked Borrows compatible with custom allocator code. + // It means the raw pointer inherits the tag of the box, which mostly works + // but can sometimes lead to unexpected aliasing errors. + // Long-term, we'll want to move to an aliasing model where "cast to + // raw pointer" is a complete NOP, and then this will no longer be + // an issue. + if place.is_indirect_first_projection() + && body.local_decls[place.local].ty.is_box_global(tcx) + { + Some(RetagKind::Raw) + } else { + None + } + } + Rvalue::Ref(..) => None, + _ => { + if needs_retag(place) { + Some(RetagKind::Default) + } else { + None + } + } + }; + if let Some(kind) = add_retag { + (kind, *place) + } else { + continue; + } + } + // Do nothing for the rest + _ => continue, + }; + // Insert a retag after the statement. + let source_info = block_data.statements[i].source_info; + block_data.statements.insert( + i + 1, + Statement { + source_info, + kind: StatementKind::Retag(retag_kind, Box::new(place)), + }, + ); + } + } + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/add_subtyping_projections.rs b/compiler/rustc_mir_transform/src/add_subtyping_projections.rs new file mode 100644 index 00000000000..92ee80eaa35 --- /dev/null +++ b/compiler/rustc_mir_transform/src/add_subtyping_projections.rs @@ -0,0 +1,69 @@ +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +use crate::patch::MirPatch; + +pub(super) struct Subtyper; + +struct SubTypeChecker<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + patcher: MirPatch<'tcx>, + local_decls: &'a LocalDecls<'tcx>, +} + +impl<'a, 'tcx> MutVisitor<'tcx> for SubTypeChecker<'a, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_assign( + &mut self, + place: &mut Place<'tcx>, + rvalue: &mut Rvalue<'tcx>, + location: Location, + ) { + // We don't need to do anything for deref temps as they are + // not part of the source code, but used for desugaring purposes. + if self.local_decls[place.local].is_deref_temp() { + return; + } + let mut place_ty = place.ty(self.local_decls, self.tcx).ty; + let mut rval_ty = rvalue.ty(self.local_decls, self.tcx); + // Not erasing this causes `Free Regions` errors in validator, + // when rval is `ReStatic`. 
+ rval_ty = self.tcx.erase_regions(rval_ty); + place_ty = self.tcx.erase_regions(place_ty); + if place_ty != rval_ty { + let temp = self + .patcher + .new_temp(rval_ty, self.local_decls[place.as_ref().local].source_info.span); + let new_place = Place::from(temp); + self.patcher.add_assign(location, new_place, rvalue.clone()); + let subtyped = new_place.project_deeper(&[ProjectionElem::Subtype(place_ty)], self.tcx); + *rvalue = Rvalue::Use(Operand::Move(subtyped)); + } + } +} + +// Aim here is to do this kind of transformation: +// +// let place: place_ty = rval; +// // gets transformed to +// let temp: rval_ty = rval; +// let place: place_ty = temp as place_ty; +impl<'tcx> crate::MirPass<'tcx> for Subtyper { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let patch = MirPatch::new(body); + let mut checker = SubTypeChecker { tcx, patcher: patch, local_decls: &body.local_decls }; + + for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { + checker.visit_basic_block_data(bb, data); + } + checker.patcher.apply(body); + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs new file mode 100644 index 00000000000..8f88613b79f --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_alignment.rs @@ -0,0 +1,152 @@ +use rustc_abi::Align; +use rustc_index::IndexVec; +use rustc_middle::mir::interpret::Scalar; +use rustc_middle::mir::visit::PlaceContext; +use rustc_middle::mir::*; +use rustc_middle::ty::{Ty, TyCtxt}; +use rustc_session::Session; + +use crate::check_pointers::{BorrowedFieldProjectionMode, PointerCheck, check_pointers}; + +pub(super) struct CheckAlignment; + +impl<'tcx> crate::MirPass<'tcx> for CheckAlignment { + fn is_enabled(&self, sess: &Session) -> bool { + sess.ub_checks() + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // Skip trivially aligned place types. + let excluded_pointees = [tcx.types.bool, tcx.types.i8, tcx.types.u8]; + + // When checking the alignment of references to field projections (`&(*ptr).a`), + // we need to make sure that the reference is aligned according to the field type + // and not to the pointer type. + check_pointers( + tcx, + body, + &excluded_pointees, + insert_alignment_check, + BorrowedFieldProjectionMode::FollowProjections, + ); + } + + fn is_required(&self) -> bool { + true + } +} + +/// Inserts the actual alignment check's logic. Returns a +/// [AssertKind::MisalignedPointerDereference] on failure. +fn insert_alignment_check<'tcx>( + tcx: TyCtxt<'tcx>, + pointer: Place<'tcx>, + pointee_ty: Ty<'tcx>, + _context: PlaceContext, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + stmts: &mut Vec<Statement<'tcx>>, + source_info: SourceInfo, +) -> PointerCheck<'tcx> { + // Cast the pointer to a *const (). + let const_raw_ptr = Ty::new_imm_ptr(tcx, tcx.types.unit); + let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr); + let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into(); + stmts + .push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) }); + + // Transmute the pointer to a usize (equivalent to `ptr.addr()`). 
+ let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize); + let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) }); + + // Get the alignment of the pointee + let alignment = + local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((alignment, rvalue))), + }); + + // Subtract 1 from the alignment to get the alignment mask + let alignment_mask = + local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + let one = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)), tcx.types.usize), + })); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + alignment_mask, + Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(alignment), one))), + ))), + }); + + // If this target does not have reliable alignment, further limit the mask by anding it with + // the mask for the highest reliable alignment. + #[allow(irrefutable_let_patterns)] + if let max_align = tcx.sess.target.max_reliable_alignment() + && max_align < Align::MAX + { + let max_mask = max_align.bytes() - 1; + let max_mask = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val( + ConstValue::Scalar(Scalar::from_target_usize(max_mask, &tcx)), + tcx.types.usize, + ), + })); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + alignment_mask, + Rvalue::BinaryOp( + BinOp::BitAnd, + Box::new((Operand::Copy(alignment_mask), max_mask)), + ), + ))), + }); + } + + // BitAnd the alignment mask with the pointer + let alignment_bits = + local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + alignment_bits, + Rvalue::BinaryOp( + BinOp::BitAnd, + Box::new((Operand::Copy(addr), Operand::Copy(alignment_mask))), + ), + ))), + }); + + // Check if the alignment bits are all zero + let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + let zero = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize), + })); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_ok, + Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))), + ))), + }); + + // Emit a check that asserts on the alignment and otherwise triggers a + // AssertKind::MisalignedPointerDereference. 
+ PointerCheck { + cond: Operand::Copy(is_ok), + assert_kind: Box::new(AssertKind::MisalignedPointerDereference { + required: Operand::Copy(alignment), + found: Operand::Copy(addr), + }), + } +} diff --git a/compiler/rustc_mir_transform/src/check_call_recursion.rs b/compiler/rustc_mir_transform/src/check_call_recursion.rs new file mode 100644 index 00000000000..cace4cd6bba --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_call_recursion.rs @@ -0,0 +1,268 @@ +use std::ops::ControlFlow; + +use rustc_data_structures::graph::iterate::{ + NodeStatus, TriColorDepthFirstSearch, TriColorVisitor, +}; +use rustc_hir::LangItem; +use rustc_hir::def::DefKind; +use rustc_middle::mir::{self, BasicBlock, BasicBlocks, Body, Terminator, TerminatorKind}; +use rustc_middle::ty::{self, GenericArg, GenericArgs, Instance, Ty, TyCtxt}; +use rustc_session::lint::builtin::UNCONDITIONAL_RECURSION; +use rustc_span::Span; + +use crate::errors::UnconditionalRecursion; +use crate::pass_manager::MirLint; + +pub(super) struct CheckCallRecursion; + +impl<'tcx> MirLint<'tcx> for CheckCallRecursion { + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + let def_id = body.source.def_id().expect_local(); + + if let DefKind::Fn | DefKind::AssocFn = tcx.def_kind(def_id) { + // If this is trait/impl method, extract the trait's args. + let trait_args = match tcx.trait_of_item(def_id.to_def_id()) { + Some(trait_def_id) => { + let trait_args_count = tcx.generics_of(trait_def_id).count(); + &GenericArgs::identity_for_item(tcx, def_id)[..trait_args_count] + } + _ => &[], + }; + + check_recursion(tcx, body, CallRecursion { trait_args }) + } + } +} + +/// Requires drop elaboration to have been performed. +pub(super) struct CheckDropRecursion; + +impl<'tcx> MirLint<'tcx> for CheckDropRecursion { + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + let def_id = body.source.def_id().expect_local(); + + // First check if `body` is an `fn drop()` of `Drop` + if let DefKind::AssocFn = tcx.def_kind(def_id) + && let Some(trait_ref) = + tcx.impl_of_method(def_id.to_def_id()).and_then(|def_id| tcx.impl_trait_ref(def_id)) + && tcx.is_lang_item(trait_ref.instantiate_identity().def_id, LangItem::Drop) + // avoid erroneous `Drop` impls from causing ICEs below + && let sig = tcx.fn_sig(def_id).instantiate_identity() + && sig.inputs().skip_binder().len() == 1 + { + // It was. Now figure out for what type `Drop` is implemented and then + // check for recursion. 
+ if let ty::Ref(_, dropped_ty, _) = + tcx.liberate_late_bound_regions(def_id.to_def_id(), sig.input(0)).kind() + { + check_recursion(tcx, body, RecursiveDrop { drop_for: *dropped_ty }); + } + } + } +} + +fn check_recursion<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + classifier: impl TerminatorClassifier<'tcx>, +) { + let def_id = body.source.def_id().expect_local(); + + if let DefKind::Fn | DefKind::AssocFn = tcx.def_kind(def_id) { + let mut vis = Search { tcx, body, classifier, reachable_recursive_calls: vec![] }; + if let Some(NonRecursive) = + TriColorDepthFirstSearch::new(&body.basic_blocks).run_from_start(&mut vis) + { + return; + } + if vis.reachable_recursive_calls.is_empty() { + return; + } + + vis.reachable_recursive_calls.sort(); + + let sp = tcx.def_span(def_id); + let hir_id = tcx.local_def_id_to_hir_id(def_id); + tcx.emit_node_span_lint( + UNCONDITIONAL_RECURSION, + hir_id, + sp, + UnconditionalRecursion { span: sp, call_sites: vis.reachable_recursive_calls }, + ); + } +} + +trait TerminatorClassifier<'tcx> { + fn is_recursive_terminator( + &self, + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + terminator: &Terminator<'tcx>, + ) -> bool; +} + +struct NonRecursive; + +struct Search<'mir, 'tcx, C: TerminatorClassifier<'tcx>> { + tcx: TyCtxt<'tcx>, + body: &'mir Body<'tcx>, + classifier: C, + + reachable_recursive_calls: Vec<Span>, +} + +struct CallRecursion<'tcx> { + trait_args: &'tcx [GenericArg<'tcx>], +} + +struct RecursiveDrop<'tcx> { + /// The type that `Drop` is implemented for. + drop_for: Ty<'tcx>, +} + +impl<'tcx> TerminatorClassifier<'tcx> for CallRecursion<'tcx> { + /// Returns `true` if `func` refers to the function we are searching in. + fn is_recursive_terminator( + &self, + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + terminator: &Terminator<'tcx>, + ) -> bool { + let TerminatorKind::Call { func, args, .. } = &terminator.kind else { + return false; + }; + + // Resolving function type to a specific instance that is being called is expensive. To + // avoid the cost we check the number of arguments first, which is sufficient to reject + // most of calls as non-recursive. + if args.len() != body.arg_count { + return false; + } + let caller = body.source.def_id(); + let typing_env = body.typing_env(tcx); + + let func_ty = func.ty(body, tcx); + if let ty::FnDef(callee, args) = *func_ty.kind() { + let Ok(normalized_args) = tcx.try_normalize_erasing_regions(typing_env, args) else { + return false; + }; + let (callee, call_args) = if let Ok(Some(instance)) = + Instance::try_resolve(tcx, typing_env, callee, normalized_args) + { + (instance.def_id(), instance.args) + } else { + (callee, normalized_args) + }; + + // FIXME(#57965): Make this work across function boundaries + + // If this is a trait fn, the args on the trait have to match, or we might be + // calling into an entirely different method (for example, a call from the default + // method in the trait to `<A as Trait<B>>::method`, where `A` and/or `B` are + // specific types). + return callee == caller && &call_args[..self.trait_args.len()] == self.trait_args; + } + + false + } +} + +impl<'tcx> TerminatorClassifier<'tcx> for RecursiveDrop<'tcx> { + fn is_recursive_terminator( + &self, + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + terminator: &Terminator<'tcx>, + ) -> bool { + let TerminatorKind::Drop { place, .. 
} = &terminator.kind else { return false }; + + let dropped_ty = place.ty(body, tcx).ty; + dropped_ty == self.drop_for + } +} + +impl<'mir, 'tcx, C: TerminatorClassifier<'tcx>> TriColorVisitor<BasicBlocks<'tcx>> + for Search<'mir, 'tcx, C> +{ + type BreakVal = NonRecursive; + + fn node_examined( + &mut self, + bb: BasicBlock, + prior_status: Option<NodeStatus>, + ) -> ControlFlow<Self::BreakVal> { + // Back-edge in the CFG (loop). + if let Some(NodeStatus::Visited) = prior_status { + return ControlFlow::Break(NonRecursive); + } + + match self.body[bb].terminator().kind { + // These terminators return control flow to the caller. + TerminatorKind::UnwindTerminate(_) + | TerminatorKind::CoroutineDrop + | TerminatorKind::UnwindResume + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::Yield { .. } => ControlFlow::Break(NonRecursive), + + // A InlineAsm without targets (diverging and contains no labels) + // is treated as non-recursing. + TerminatorKind::InlineAsm { ref targets, .. } => { + if !targets.is_empty() { + ControlFlow::Continue(()) + } else { + ControlFlow::Break(NonRecursive) + } + } + + // These do not. + TerminatorKind::Assert { .. } + | TerminatorKind::Call { .. } + | TerminatorKind::Drop { .. } + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::Goto { .. } + | TerminatorKind::SwitchInt { .. } => ControlFlow::Continue(()), + + // Note that tail call terminator technically returns to the caller, + // but for purposes of this lint it makes sense to count it as possibly recursive, + // since it's still a call. + // + // If this'll be repurposed for something else, this might need to be changed. + TerminatorKind::TailCall { .. } => ControlFlow::Continue(()), + } + } + + fn node_settled(&mut self, bb: BasicBlock) -> ControlFlow<Self::BreakVal> { + // When we examine a node for the last time, remember it if it is a recursive call. + let terminator = self.body[bb].terminator(); + + // FIXME(explicit_tail_calls): highlight tail calls as "recursive call site" + // + // We don't want to lint functions that recurse only through tail calls + // (such as `fn g() { become () }`), so just adding `| TailCall { ... }` + // here won't work. + // + // But at the same time we would like to highlight both calls in a function like + // `fn f() { if false { become f() } else { f() } }`, so we need to figure something out. + if self.classifier.is_recursive_terminator(self.tcx, self.body, terminator) { + self.reachable_recursive_calls.push(terminator.source_info.span); + } + + ControlFlow::Continue(()) + } + + fn ignore_edge(&mut self, bb: BasicBlock, target: BasicBlock) -> bool { + let terminator = self.body[bb].terminator(); + let ignore_unwind = terminator.unwind() == Some(&mir::UnwindAction::Cleanup(target)) + && terminator.successors().count() > 1; + if ignore_unwind || self.classifier.is_recursive_terminator(self.tcx, self.body, terminator) + { + return true; + } + match &terminator.kind { + TerminatorKind::FalseEdge { imaginary_target, .. 
} => imaginary_target == &target, + _ => false, + } + } +} diff --git a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs new file mode 100644 index 00000000000..375db17fb73 --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs @@ -0,0 +1,164 @@ +use rustc_hir::HirId; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_session::lint::builtin::CONST_ITEM_MUTATION; +use rustc_span::Span; +use rustc_span::def_id::DefId; + +use crate::errors; + +pub(super) struct CheckConstItemMutation; + +impl<'tcx> crate::MirLint<'tcx> for CheckConstItemMutation { + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + let mut checker = ConstMutationChecker { body, tcx, target_local: None }; + checker.visit_body(body); + } +} + +struct ConstMutationChecker<'a, 'tcx> { + body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, + target_local: Option<Local>, +} + +impl<'tcx> ConstMutationChecker<'_, 'tcx> { + fn is_const_item(&self, local: Local) -> Option<DefId> { + if let LocalInfo::ConstRef { def_id } = *self.body.local_decls[local].local_info() { + Some(def_id) + } else { + None + } + } + + fn is_const_item_without_destructor(&self, local: Local) -> Option<DefId> { + let def_id = self.is_const_item(local)?; + + // We avoid linting mutation of a const item if the const's type has a + // Drop impl. The Drop logic observes the mutation which was performed. + // + // pub struct Log { msg: &'static str } + // pub const LOG: Log = Log { msg: "" }; + // impl Drop for Log { + // fn drop(&mut self) { println!("{}", self.msg); } + // } + // + // LOG.msg = "wow"; // prints "wow" + // + // FIXME(https://github.com/rust-lang/rust/issues/77425): + // Drop this exception once there is a stable attribute to suppress the + // const item mutation lint for a single specific const only. Something + // equivalent to: + // + // #[const_mutation_allowed] + // pub const LOG: Log = Log { msg: "" }; + // FIXME: this should not be checking for `Drop` impls, + // but whether it or any field has a Drop impl (`needs_drop`) + // as fields' Drop impls may make this observable, too. + match self.tcx.type_of(def_id).skip_binder().ty_adt_def().map(|adt| adt.has_dtor(self.tcx)) + { + Some(true) => None, + Some(false) | None => Some(def_id), + } + } + + /// If we should lint on this usage, return the [`HirId`], source [`Span`] + /// and [`Span`] of the const item to use in the lint. + fn should_lint_const_item_usage( + &self, + place: &Place<'tcx>, + const_item: DefId, + location: Location, + ) -> Option<(HirId, Span, Span)> { + // Don't lint on borrowing/assigning when a dereference is involved. 
+ // If we 'leave' the temporary via a dereference, we must + // be modifying something else + // + // `unsafe { *FOO = 0; *BAR.field = 1; }` + // `unsafe { &mut *FOO }` + // `unsafe { (*ARRAY)[0] = val; }` + if !place.projection.iter().any(|p| matches!(p, PlaceElem::Deref)) { + let source_info = self.body.source_info(location); + let lint_root = self.body.source_scopes[source_info.scope] + .local_data + .as_ref() + .unwrap_crate_local() + .lint_root; + + Some((lint_root, source_info.span, self.tcx.def_span(const_item))) + } else { + None + } + } +} + +impl<'tcx> Visitor<'tcx> for ConstMutationChecker<'_, 'tcx> { + fn visit_statement(&mut self, stmt: &Statement<'tcx>, loc: Location) { + if let StatementKind::Assign(box (lhs, _)) = &stmt.kind { + // Check for assignment to fields of a constant + // Assigning directly to a constant (e.g. `FOO = true;`) is a hard error, + // so emitting a lint would be redundant. + if !lhs.projection.is_empty() + && let Some(def_id) = self.is_const_item_without_destructor(lhs.local) + && let Some((lint_root, span, item)) = + self.should_lint_const_item_usage(lhs, def_id, loc) + { + self.tcx.emit_node_span_lint( + CONST_ITEM_MUTATION, + lint_root, + span, + errors::ConstMutate::Modify { konst: item }, + ); + } + + // We are looking for MIR of the form: + // + // ``` + // _1 = const FOO; + // _2 = &mut _1; + // method_call(_2, ..) + // ``` + // + // Record our current LHS, so that we can detect this + // pattern in `visit_rvalue` + self.target_local = lhs.as_local(); + } + self.super_statement(stmt, loc); + self.target_local = None; + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, loc: Location) { + if let Rvalue::Ref(_, BorrowKind::Mut { .. }, place) = rvalue { + let local = place.local; + if let Some(def_id) = self.is_const_item(local) { + // If this Rvalue is being used as the right-hand side of a + // `StatementKind::Assign`, see if it ends up getting used as + // the `self` parameter of a method call (as the terminator of our current + // BasicBlock). If so, we emit a more specific lint. + let method_did = self.target_local.and_then(|target_local| { + find_self_call(self.tcx, self.body, target_local, loc.block) + }); + let lint_loc = + if method_did.is_some() { self.body.terminator_loc(loc.block) } else { loc }; + + let method_call = if let Some((method_did, _)) = method_did { + Some(self.tcx.def_span(method_did)) + } else { + None + }; + if let Some((lint_root, span, item)) = + self.should_lint_const_item_usage(place, def_id, lint_loc) + { + self.tcx.emit_node_span_lint( + CONST_ITEM_MUTATION, + lint_root, + span, + errors::ConstMutate::MutBorrow { method_call, konst: item }, + ); + } + } + } + self.super_rvalue(rvalue, loc); + } +} diff --git a/compiler/rustc_mir_transform/src/check_enums.rs b/compiler/rustc_mir_transform/src/check_enums.rs new file mode 100644 index 00000000000..e06e0c6122e --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_enums.rs @@ -0,0 +1,501 @@ +use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange}; +use rustc_hir::LangItem; +use rustc_index::IndexVec; +use rustc_middle::bug; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; +use rustc_middle::ty::layout::PrimitiveExt; +use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv}; +use rustc_session::Session; +use tracing::debug; + +/// This pass inserts checks for a valid enum discriminant where they are most +/// likely to find UB, because checking everywhere like Miri would generate too +/// much MIR. 
+pub(super) struct CheckEnums; + +impl<'tcx> crate::MirPass<'tcx> for CheckEnums { + fn is_enabled(&self, sess: &Session) -> bool { + sess.ub_checks() + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // This pass emits new panics. If for whatever reason we do not have a panic + // implementation, running this pass may cause otherwise-valid code to not compile. + if tcx.lang_items().get(LangItem::PanicImpl).is_none() { + return; + } + + let typing_env = body.typing_env(tcx); + let basic_blocks = body.basic_blocks.as_mut(); + let local_decls = &mut body.local_decls; + + // This operation inserts new blocks. Each insertion changes the Location for all + // statements/blocks after. Iterating or visiting the MIR in order would require updating + // our current location after every insertion. By iterating backwards, we dodge this issue: + // The only Locations that an insertion changes have already been handled. + for block in basic_blocks.indices().rev() { + for statement_index in (0..basic_blocks[block].statements.len()).rev() { + let location = Location { block, statement_index }; + let statement = &basic_blocks[block].statements[statement_index]; + let source_info = statement.source_info; + + let mut finder = EnumFinder::new(tcx, local_decls, typing_env); + finder.visit_statement(statement, location); + + for check in finder.into_found_enums() { + debug!("Inserting enum check"); + let new_block = split_block(basic_blocks, location); + + match check { + EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => { + insert_direct_enum_check( + tcx, + local_decls, + basic_blocks, + block, + source_op, + discr, + op_size, + valid_discrs, + source_info, + new_block, + ) + } + EnumCheckType::Uninhabited => insert_uninhabited_enum_check( + tcx, + local_decls, + &mut basic_blocks[block], + source_info, + new_block, + ), + EnumCheckType::WithNiche { + source_op, + discr, + op_size, + offset, + valid_range, + } => insert_niche_check( + tcx, + local_decls, + &mut basic_blocks[block], + source_op, + valid_range, + discr, + op_size, + offset, + source_info, + new_block, + ), + } + } + } + } + } + + fn is_required(&self) -> bool { + true + } +} + +/// Represent the different kind of enum checks we can insert. +enum EnumCheckType<'tcx> { + /// We know we try to create an uninhabited enum from an inhabited variant. + Uninhabited, + /// We know the enum does no niche optimizations and can thus easily compute + /// the valid discriminants. + Direct { + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + valid_discrs: Vec<u128>, + }, + /// We try to construct an enum that has a niche. + WithNiche { + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + offset: Size, + valid_range: WrappingRange, + }, +} + +struct TyAndSize<'tcx> { + pub ty: Ty<'tcx>, + pub size: Size, +} + +/// A [Visitor] that finds the construction of enums and evaluates which checks +/// we should apply. +struct EnumFinder<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + local_decls: &'a mut LocalDecls<'tcx>, + typing_env: TypingEnv<'tcx>, + enums: Vec<EnumCheckType<'tcx>>, +} + +impl<'a, 'tcx> EnumFinder<'a, 'tcx> { + fn new( + tcx: TyCtxt<'tcx>, + local_decls: &'a mut LocalDecls<'tcx>, + typing_env: TypingEnv<'tcx>, + ) -> Self { + EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() } + } + + /// Returns the found enum creations and which checks should be inserted. 
+ fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> { + self.enums + } +} + +impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> { + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue { + let ty::Adt(adt_def, _) = ty.kind() else { + return; + }; + if !adt_def.is_enum() { + return; + } + + let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else { + return; + }; + let Ok(op_layout) = self + .tcx + .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx))) + else { + return; + }; + + match enum_layout.variants { + Variants::Empty if op_layout.is_uninhabited() => return, + // An empty enum that tries to be constructed from an inhabited value, this + // is never correct. + Variants::Empty => { + // The enum layout is uninhabited but we construct it from sth inhabited. + // This is always UB. + self.enums.push(EnumCheckType::Uninhabited); + } + // Construction of Single value enums is always fine. + Variants::Single { .. } => {} + // Construction of an enum with multiple variants but no niche optimizations. + Variants::Multiple { + tag_encoding: TagEncoding::Direct, + tag: Scalar::Initialized { value, .. }, + .. + } => { + let valid_discrs = + adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect(); + + let discr = + TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) }; + self.enums.push(EnumCheckType::Direct { + source_op: op.to_copy(), + discr, + op_size: op_layout.size, + valid_discrs, + }); + } + // Construction of an enum with multiple variants and niche optimizations. + Variants::Multiple { + tag_encoding: TagEncoding::Niche { .. }, + tag: Scalar::Initialized { value, valid_range, .. }, + tag_field, + .. + } => { + let discr = + TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) }; + self.enums.push(EnumCheckType::WithNiche { + source_op: op.to_copy(), + discr, + op_size: op_layout.size, + offset: enum_layout.fields.offset(tag_field.as_usize()), + valid_range, + }); + } + _ => return, + } + + self.super_rvalue(rvalue, location); + } + } +} + +fn split_block( + basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>, + location: Location, +) -> BasicBlock { + let block_data = &mut basic_blocks[location.block]; + + // Drain every statement after this one and move the current terminator to a new basic block. + let new_block = BasicBlockData { + statements: block_data.statements.split_off(location.statement_index), + terminator: block_data.terminator.take(), + is_cleanup: block_data.is_cleanup, + }; + + basic_blocks.push(new_block) +} + +/// Inserts the cast of an operand (any type) to a u128 value that holds the discriminant value. 
+fn insert_discr_cast_to_u128<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + block_data: &mut BasicBlockData<'tcx>, + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + offset: Option<Size>, + source_info: SourceInfo, +) -> Place<'tcx> { + let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> { + match size.bytes() { + 1 => tcx.types.u8, + 2 => tcx.types.u16, + 4 => tcx.types.u32, + 8 => tcx.types.u64, + 16 => tcx.types.u128, + invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid), + } + }; + + let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() { + // The discriminant is less wide than the operand, cast the operand into + // [MaybeUninit; N] and then index into it. + let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8); + let array_len = op_size.bytes(); + let mu_array_ty = Ty::new_array(tcx, mu, array_len); + let mu_array = + local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into(); + let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((mu_array, rvalue))), + }); + + // Index into the array of MaybeUninit to get something that is actually + // as wide as the discriminant. + let offset = offset.unwrap_or(Size::ZERO); + let smaller_mu_array = mu_array.project_deeper( + &[ProjectionElem::Subslice { + from: offset.bytes(), + to: offset.bytes() + discr.size.bytes(), + from_end: false, + }], + tcx, + ); + + (CastKind::Transmute, Operand::Copy(smaller_mu_array)) + } else { + let operand_int_ty = get_ty_for_size(tcx, op_size); + + let op_as_int = + local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into(); + let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((op_as_int, rvalue))), + }); + + (CastKind::IntToInt, Operand::Copy(op_as_int)) + }; + + // Cast the resulting value to the actual discriminant integer type. + let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty); + let discr_in_discr_ty = + local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))), + }); + + // Cast the discriminant to a u128 (base for comparisions of enum discriminants). + let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128); + let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128); + let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into(); + block_data + .statements + .push(Statement { source_info, kind: StatementKind::Assign(Box::new((discr, rvalue))) }); + + discr +} + +fn insert_direct_enum_check<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>, + current_block: BasicBlock, + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + discriminants: Vec<u128>, + source_info: SourceInfo, + new_block: BasicBlock, +) { + // Insert a new target block that is branched to in case of an invalid discriminant. 
+ let invalid_discr_block_data = BasicBlockData::new(None, false); + let invalid_discr_block = basic_blocks.push(invalid_discr_block_data); + let block_data = &mut basic_blocks[current_block]; + let discr = insert_discr_cast_to_u128( + tcx, + local_decls, + block_data, + source_op, + discr, + op_size, + None, + source_info, + ); + + // Branch based on the discriminant value. + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::SwitchInt { + discr: Operand::Copy(discr), + targets: SwitchTargets::new( + discriminants.into_iter().map(|discr| (discr, new_block)), + invalid_discr_block, + ), + }, + }); + + // Abort in case of an invalid enum discriminant. + basic_blocks[invalid_discr_block].terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool), + })), + expected: true, + target: new_block, + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))), + // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. + // We never want to insert an unwind into unsafe code, because unwinding could + // make a failing UB check turn into much worse UB when we start unwinding. + unwind: UnwindAction::Unreachable, + }, + }); +} + +fn insert_uninhabited_enum_check<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + block_data: &mut BasicBlockData<'tcx>, + source_info: SourceInfo, + new_block: BasicBlock, +) { + let is_ok: Place<'_> = + local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_ok, + Rvalue::Use(Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool), + }))), + ))), + }); + + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: Operand::Copy(is_ok), + expected: true, + target: new_block, + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new( + ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128), + }, + )))), + // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. + // We never want to insert an unwind into unsafe code, because unwinding could + // make a failing UB check turn into much worse UB when we start unwinding. + unwind: UnwindAction::Unreachable, + }, + }); +} + +fn insert_niche_check<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + block_data: &mut BasicBlockData<'tcx>, + source_op: Operand<'tcx>, + valid_range: WrappingRange, + discr: TyAndSize<'tcx>, + op_size: Size, + offset: Size, + source_info: SourceInfo, + new_block: BasicBlock, +) { + let discr = insert_discr_cast_to_u128( + tcx, + local_decls, + block_data, + source_op, + discr, + op_size, + Some(offset), + source_info, + ); + + // Compare the discriminant agains the valid_range. 
+ let start_const = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128), + })); + let end_start_diff_const = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val( + ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)), + tcx.types.u128, + ), + })); + + let discr_diff: Place<'_> = + local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + discr_diff, + Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))), + ))), + }); + + let is_ok: Place<'_> = + local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_ok, + Rvalue::BinaryOp( + // This is a `WrappingRange`, so make sure to get the wrapping right. + BinOp::Le, + Box::new((Operand::Copy(discr_diff), end_start_diff_const)), + ), + ))), + }); + + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: Operand::Copy(is_ok), + expected: true, + target: new_block, + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))), + // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. + // We never want to insert an unwind into unsafe code, because unwinding could + // make a failing UB check turn into much worse UB when we start unwinding. + unwind: UnwindAction::Unreachable, + }, + }); +} diff --git a/compiler/rustc_mir_transform/src/check_inline.rs b/compiler/rustc_mir_transform/src/check_inline.rs new file mode 100644 index 00000000000..14d9532894f --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_inline.rs @@ -0,0 +1,91 @@ +//! Check that a body annotated with `#[rustc_force_inline]` will not fail to inline based on its +//! definition alone (irrespective of any specific caller). + +use rustc_attr_data_structures::InlineAttr; +use rustc_hir::def_id::DefId; +use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; +use rustc_middle::mir::{Body, TerminatorKind}; +use rustc_middle::ty; +use rustc_middle::ty::TyCtxt; +use rustc_span::sym; + +use crate::pass_manager::MirLint; + +pub(super) struct CheckForceInline; + +impl<'tcx> MirLint<'tcx> for CheckForceInline { + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + let def_id = body.source.def_id(); + if !tcx.hir_body_owner_kind(def_id).is_fn_or_closure() || !def_id.is_local() { + return; + } + let InlineAttr::Force { attr_span, .. } = tcx.codegen_fn_attrs(def_id).inline else { + return; + }; + + if let Err(reason) = + is_inline_valid_on_fn(tcx, def_id).and_then(|_| is_inline_valid_on_body(tcx, body)) + { + tcx.dcx().emit_err(crate::errors::InvalidForceInline { + attr_span, + callee_span: tcx.def_span(def_id), + callee: tcx.def_path_str(def_id), + reason, + }); + } + } +} + +pub(super) fn is_inline_valid_on_fn<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: DefId, +) -> Result<(), &'static str> { + let codegen_attrs = tcx.codegen_fn_attrs(def_id); + if tcx.has_attr(def_id, sym::rustc_no_mir_inline) { + return Err("#[rustc_no_mir_inline]"); + } + + // FIXME(#127234): Coverage instrumentation currently doesn't handle inlined + // MIR correctly when Modified Condition/Decision Coverage is enabled. 
+ if tcx.sess.instrument_coverage_mcdc() { + return Err("incompatible with MC/DC coverage"); + } + + let ty = tcx.type_of(def_id); + if match ty.instantiate_identity().kind() { + ty::FnDef(..) => tcx.fn_sig(def_id).instantiate_identity().c_variadic(), + ty::Closure(_, args) => args.as_closure().sig().c_variadic(), + _ => false, + } { + return Err("C variadic"); + } + + if codegen_attrs.flags.contains(CodegenFnAttrFlags::COLD) { + return Err("cold"); + } + + // Intrinsic fallback bodies are automatically made cross-crate inlineable, + // but at this stage we don't know whether codegen knows the intrinsic, + // so just conservatively don't inline it. This also ensures that we do not + // accidentally inline the body of an intrinsic that *must* be overridden. + if tcx.has_attr(def_id, sym::rustc_intrinsic) { + return Err("callee is an intrinsic"); + } + + Ok(()) +} + +pub(super) fn is_inline_valid_on_body<'tcx>( + _: TyCtxt<'tcx>, + body: &Body<'tcx>, +) -> Result<(), &'static str> { + if body + .basic_blocks + .iter() + .any(|bb| matches!(bb.terminator().kind, TerminatorKind::TailCall { .. })) + { + return Err("can't inline functions with tail calls"); + } + + Ok(()) +} diff --git a/compiler/rustc_mir_transform/src/check_null.rs b/compiler/rustc_mir_transform/src/check_null.rs new file mode 100644 index 00000000000..ad74e335bd9 --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_null.rs @@ -0,0 +1,139 @@ +use rustc_index::IndexVec; +use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext}; +use rustc_middle::mir::*; +use rustc_middle::ty::{Ty, TyCtxt}; +use rustc_session::Session; + +use crate::check_pointers::{BorrowedFieldProjectionMode, PointerCheck, check_pointers}; + +pub(super) struct CheckNull; + +impl<'tcx> crate::MirPass<'tcx> for CheckNull { + fn is_enabled(&self, sess: &Session) -> bool { + sess.ub_checks() + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + check_pointers( + tcx, + body, + &[], + insert_null_check, + BorrowedFieldProjectionMode::NoFollowProjections, + ); + } + + fn is_required(&self) -> bool { + true + } +} + +fn insert_null_check<'tcx>( + tcx: TyCtxt<'tcx>, + pointer: Place<'tcx>, + pointee_ty: Ty<'tcx>, + context: PlaceContext, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + stmts: &mut Vec<Statement<'tcx>>, + source_info: SourceInfo, +) -> PointerCheck<'tcx> { + // Cast the pointer to a *const (). + let const_raw_ptr = Ty::new_imm_ptr(tcx, tcx.types.unit); + let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr); + let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into(); + stmts + .push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) }); + + // Transmute the pointer to a usize (equivalent to `ptr.addr()`). + let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize); + let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) }); + + let zero = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_target_usize(0, &tcx), tcx.types.usize), + })); + + let pointee_should_be_checked = match context { + // Borrows pointing to "null" are UB even if the pointee is a ZST. 
+ PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow) + | PlaceContext::MutatingUse(MutatingUseContext::Borrow) => { + // Pointer should be checked unconditionally. + Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_bool(true), tcx.types.bool), + })) + } + // Other usages of null pointers only are UB if the pointee is not a ZST. + _ => { + let rvalue = Rvalue::NullaryOp(NullOp::SizeOf, pointee_ty); + let sizeof_pointee = + local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((sizeof_pointee, rvalue))), + }); + + // Check that the pointee is not a ZST. + let is_pointee_not_zst = + local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_pointee_not_zst, + Rvalue::BinaryOp( + BinOp::Ne, + Box::new((Operand::Copy(sizeof_pointee), zero.clone())), + ), + ))), + }); + + // Pointer needs to be checked only if pointee is not a ZST. + Operand::Copy(is_pointee_not_zst) + } + }; + + // Check whether the pointer is null. + let is_null = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_null, + Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(addr), zero))), + ))), + }); + + // We want to throw an exception if the pointer is null and the pointee is not unconditionally + // allowed (which for all non-borrow place uses, is when the pointee is ZST). + let should_throw_exception = + local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + should_throw_exception, + Rvalue::BinaryOp( + BinOp::BitAnd, + Box::new((Operand::Copy(is_null), pointee_should_be_checked)), + ), + ))), + }); + + // The final condition whether this pointer usage is ok or not. + let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + stmts.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_ok, + Rvalue::UnaryOp(UnOp::Not, Operand::Copy(should_throw_exception)), + ))), + }); + + // Emit a PointerCheck that asserts on the condition and otherwise triggers + // a AssertKind::NullPointerDereference. 
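// Illustrative sketch (not rustc's actual code): taken together, the statements above
// compute the predicate below over the pointer's address, the kind of use, and the
// pointee size (the helper name is hypothetical). The check passes unless the address
// is null and the use is one where a null pointer is actually UB: any borrow, or any
// other use of a pointee with non-zero size.
fn null_check_is_ok(addr: usize, is_borrow: bool, pointee_size: usize) -> bool {
    let must_check = is_borrow || pointee_size != 0;
    !(addr == 0 && must_check)
}
// For example, reading a zero-sized value through a null raw pointer is not flagged
// (null_check_is_ok(0, false, 0) == true), while borrowing through the same null
// pointer is (null_check_is_ok(0, true, 0) == false).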
+ PointerCheck { + cond: Operand::Copy(is_ok), + assert_kind: Box::new(AssertKind::NullPointerDereference), + } +} diff --git a/compiler/rustc_mir_transform/src/check_packed_ref.rs b/compiler/rustc_mir_transform/src/check_packed_ref.rs new file mode 100644 index 00000000000..e9b85ba6e9d --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_packed_ref.rs @@ -0,0 +1,55 @@ +use rustc_middle::mir::visit::{PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::span_bug; +use rustc_middle::ty::{self, TyCtxt}; + +use crate::{errors, util}; + +pub(super) struct CheckPackedRef; + +impl<'tcx> crate::MirLint<'tcx> for CheckPackedRef { + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + let typing_env = body.typing_env(tcx); + let source_info = SourceInfo::outermost(body.span); + let mut checker = PackedRefChecker { body, tcx, typing_env, source_info }; + checker.visit_body(body); + } +} + +struct PackedRefChecker<'a, 'tcx> { + body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + source_info: SourceInfo, +} + +impl<'tcx> Visitor<'tcx> for PackedRefChecker<'_, 'tcx> { + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + // Make sure we know where in the MIR we are. + self.source_info = terminator.source_info; + self.super_terminator(terminator, location); + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + // Make sure we know where in the MIR we are. + self.source_info = statement.source_info; + self.super_statement(statement, location); + } + + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { + if context.is_borrow() && util::is_disaligned(self.tcx, self.body, self.typing_env, *place) + { + let def_id = self.body.source.instance.def_id(); + if let Some(impl_def_id) = self.tcx.impl_of_method(def_id) + && self.tcx.is_builtin_derived(impl_def_id) + { + // If we ever reach here it means that the generated derive + // code is somehow doing an unaligned reference, which it + // shouldn't do. + span_bug!(self.source_info.span, "builtin derive created an unaligned reference"); + } else { + self.tcx.dcx().emit_err(errors::UnalignedPackedRef { span: self.source_info.span }); + } + } + } +} diff --git a/compiler/rustc_mir_transform/src/check_pointers.rs b/compiler/rustc_mir_transform/src/check_pointers.rs new file mode 100644 index 00000000000..bf94f1aad24 --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_pointers.rs @@ -0,0 +1,245 @@ +use rustc_hir::lang_items::LangItem; +use rustc_index::IndexVec; +use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use tracing::{debug, trace}; + +/// Details of a pointer check, the condition on which we decide whether to +/// fail the assert and an [AssertKind] that defines the behavior on failure. +pub(crate) struct PointerCheck<'tcx> { + pub(crate) cond: Operand<'tcx>, + pub(crate) assert_kind: Box<AssertKind<Operand<'tcx>>>, +} + +/// When checking for borrows of field projections (`&(*ptr).a`), we might want +/// to check for the field type (type of `.a` in the example). This enum defines +/// the variations (pass the pointer [Ty] or the field [Ty]). +#[derive(Copy, Clone)] +pub(crate) enum BorrowedFieldProjectionMode { + FollowProjections, + NoFollowProjections, +} + +/// Utility for adding a check for read/write on every sized, raw pointer. 
+/// +/// Visits every read/write access to a [Sized], raw pointer and inserts a +/// new basic block directly before the pointer access. (Read/write accesses +/// are determined by the `PlaceContext` of the MIR visitor.) Then calls +/// `on_finding` to insert the actual logic for a pointer check (e.g. check for +/// alignment). A check can choose to follow borrows of field projections via +/// the `field_projection_mode` parameter. +/// +/// This utility takes care of the right order of blocks, the only thing a +/// caller must do in `on_finding` is: +/// - Append [Statement]s to `stmts`. +/// - Append [LocalDecl]s to `local_decls`. +/// - Return a [PointerCheck] that contains the condition and an [AssertKind]. +/// The AssertKind must be a panic with `#[rustc_nounwind]`. The condition +/// should always return the boolean `is_ok`, so evaluate to true in case of +/// success and fail the check otherwise. +/// This utility will insert a terminator block that asserts on the condition +/// and panics on failure. +pub(crate) fn check_pointers<'tcx, F>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + excluded_pointees: &[Ty<'tcx>], + on_finding: F, + field_projection_mode: BorrowedFieldProjectionMode, +) where + F: Fn( + /* tcx: */ TyCtxt<'tcx>, + /* pointer: */ Place<'tcx>, + /* pointee_ty: */ Ty<'tcx>, + /* context: */ PlaceContext, + /* local_decls: */ &mut IndexVec<Local, LocalDecl<'tcx>>, + /* stmts: */ &mut Vec<Statement<'tcx>>, + /* source_info: */ SourceInfo, + ) -> PointerCheck<'tcx>, +{ + // This pass emits new panics. If for whatever reason we do not have a panic + // implementation, running this pass may cause otherwise-valid code to not compile. + if tcx.lang_items().get(LangItem::PanicImpl).is_none() { + return; + } + + let typing_env = body.typing_env(tcx); + let basic_blocks = body.basic_blocks.as_mut(); + let local_decls = &mut body.local_decls; + + // This operation inserts new blocks. Each insertion changes the Location for all + // statements/blocks after. Iterating or visiting the MIR in order would require updating + // our current location after every insertion. By iterating backwards, we dodge this issue: + // The only Locations that an insertion changes have already been handled. + for block in basic_blocks.indices().rev() { + for statement_index in (0..basic_blocks[block].statements.len()).rev() { + let location = Location { block, statement_index }; + let statement = &basic_blocks[block].statements[statement_index]; + let source_info = statement.source_info; + + let mut finder = PointerFinder::new( + tcx, + local_decls, + typing_env, + excluded_pointees, + field_projection_mode, + ); + finder.visit_statement(statement, location); + + for (local, ty, context) in finder.into_found_pointers() { + debug!("Inserting check for {:?}", ty); + let new_block = split_block(basic_blocks, location); + + // Invoke `on_finding` which appends to `local_decls` and the + // blocks statements. It returns information about the assert + // we're performing in the Terminator. + let block_data = &mut basic_blocks[block]; + let pointer_check = on_finding( + tcx, + local, + ty, + context, + local_decls, + &mut block_data.statements, + source_info, + ); + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: pointer_check.cond, + expected: true, + target: new_block, + msg: pointer_check.assert_kind, + // This calls a panic function associated with the pointer check, which + // is #[rustc_nounwind]. 
We never want to insert an unwind into unsafe + // code, because unwinding could make a failing UB check turn into much + // worse UB when we start unwinding. + unwind: UnwindAction::Unreachable, + }, + }); + } + } + } +} + +struct PointerFinder<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + local_decls: &'a mut LocalDecls<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + pointers: Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)>, + excluded_pointees: &'a [Ty<'tcx>], + field_projection_mode: BorrowedFieldProjectionMode, +} + +impl<'a, 'tcx> PointerFinder<'a, 'tcx> { + fn new( + tcx: TyCtxt<'tcx>, + local_decls: &'a mut LocalDecls<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + excluded_pointees: &'a [Ty<'tcx>], + field_projection_mode: BorrowedFieldProjectionMode, + ) -> Self { + PointerFinder { + tcx, + local_decls, + typing_env, + excluded_pointees, + pointers: Vec::new(), + field_projection_mode, + } + } + + fn into_found_pointers(self) -> Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)> { + self.pointers + } + + /// Whether or not we should visit a [Place] with [PlaceContext]. + /// + /// We generally only visit Reads/Writes to a place and only Borrows if + /// requested. + fn should_visit_place(&self, context: PlaceContext) -> bool { + match context { + PlaceContext::MutatingUse( + MutatingUseContext::Store + | MutatingUseContext::Call + | MutatingUseContext::Yield + | MutatingUseContext::Drop + | MutatingUseContext::Borrow, + ) => true, + PlaceContext::NonMutatingUse( + NonMutatingUseContext::Copy + | NonMutatingUseContext::Move + | NonMutatingUseContext::SharedBorrow, + ) => true, + _ => false, + } + } +} + +impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> { + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) { + if !self.should_visit_place(context) || !place.is_indirect() { + return; + } + + // Get the place and type we visit. + let pointer = Place::from(place.local); + let pointer_ty = pointer.ty(self.local_decls, self.tcx).ty; + + // We only want to check places based on raw pointers + let &ty::RawPtr(mut pointee_ty, _) = pointer_ty.kind() else { + trace!("Indirect, but not based on an raw ptr, not checking {:?}", place); + return; + }; + + // If we see a borrow of a field projection, we want to pass the field type to the + // check and not the pointee type. + if matches!(self.field_projection_mode, BorrowedFieldProjectionMode::FollowProjections) + && matches!( + context, + PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow) + | PlaceContext::MutatingUse(MutatingUseContext::Borrow) + ) + { + // Naturally, the field type is type of the initial place we look at. + pointee_ty = place.ty(self.local_decls, self.tcx).ty; + } + + // Ideally we'd support this in the future, but for now we are limited to sized types. + if !pointee_ty.is_sized(self.tcx, self.typing_env) { + trace!("Raw pointer, but pointee is not known to be sized: {:?}", pointer_ty); + return; + } + + // We don't need to look for slices, we already rejected unsized types above. + let element_ty = match pointee_ty.kind() { + ty::Array(ty, _) => *ty, + _ => pointee_ty, + }; + // Check if we excluded this pointee type from the check. 
+ if self.excluded_pointees.contains(&element_ty) { + trace!("Skipping pointer for type: {:?}", pointee_ty); + return; + } + + self.pointers.push((pointer, pointee_ty, context)); + + self.super_place(place, context, location); + } +} + +fn split_block( + basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>, + location: Location, +) -> BasicBlock { + let block_data = &mut basic_blocks[location.block]; + + // Drain every statement after this one and move the current terminator to a new basic block. + let new_block = BasicBlockData { + statements: block_data.statements.split_off(location.statement_index), + terminator: block_data.terminator.take(), + is_cleanup: block_data.is_cleanup, + }; + + basic_blocks.push(new_block) +} diff --git a/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs b/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs new file mode 100644 index 00000000000..4be67b873f7 --- /dev/null +++ b/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs @@ -0,0 +1,80 @@ +//! This module provides a pass that removes parts of MIR that are no longer relevant after +//! analysis phase and borrowck. In particular, it removes false edges, user type annotations and +//! replaces following statements with [`Nop`]s: +//! +//! - [`AscribeUserType`] +//! - [`FakeRead`] +//! - [`Assign`] statements with a [`Fake`] borrow +//! - [`Coverage`] statements of kind [`BlockMarker`] or [`SpanMarker`] +//! +//! [`AscribeUserType`]: rustc_middle::mir::StatementKind::AscribeUserType +//! [`Assign`]: rustc_middle::mir::StatementKind::Assign +//! [`FakeRead`]: rustc_middle::mir::StatementKind::FakeRead +//! [`Nop`]: rustc_middle::mir::StatementKind::Nop +//! [`Fake`]: rustc_middle::mir::BorrowKind::Fake +//! [`Coverage`]: rustc_middle::mir::StatementKind::Coverage +//! [`BlockMarker`]: rustc_middle::mir::coverage::CoverageKind::BlockMarker +//! [`SpanMarker`]: rustc_middle::mir::coverage::CoverageKind::SpanMarker + +use rustc_middle::mir::coverage::CoverageKind; +use rustc_middle::mir::{Body, BorrowKind, CastKind, Rvalue, StatementKind, TerminatorKind}; +use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::adjustment::PointerCoercion; + +pub(super) struct CleanupPostBorrowck; + +impl<'tcx> crate::MirPass<'tcx> for CleanupPostBorrowck { + fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + for basic_block in body.basic_blocks.as_mut() { + for statement in basic_block.statements.iter_mut() { + match statement.kind { + StatementKind::AscribeUserType(..) + | StatementKind::Assign(box (_, Rvalue::Ref(_, BorrowKind::Fake(_), _))) + | StatementKind::Coverage( + // These kinds of coverage statements are markers inserted during + // MIR building, and are not needed after InstrumentCoverage. + CoverageKind::BlockMarker { .. } | CoverageKind::SpanMarker { .. }, + ) + | StatementKind::FakeRead(..) + | StatementKind::BackwardIncompatibleDropHint { .. } => statement.make_nop(), + StatementKind::Assign(box ( + _, + Rvalue::Cast( + ref mut cast_kind @ CastKind::PointerCoercion( + PointerCoercion::ArrayToPointer + | PointerCoercion::MutToConstPointer, + _, + ), + .., + ), + )) => { + // BorrowCk needed to track whether these cases were coercions or casts, + // to know whether to check lifetimes in their pointees, + // but from now on that distinction doesn't matter, + // so just make them ordinary pointer casts instead. 
+ *cast_kind = CastKind::PtrToPtr; + } + _ => (), + } + } + let terminator = basic_block.terminator_mut(); + match terminator.kind { + TerminatorKind::FalseEdge { real_target, .. } + | TerminatorKind::FalseUnwind { real_target, .. } => { + terminator.kind = TerminatorKind::Goto { target: real_target }; + } + _ => {} + } + } + + body.user_type_annotations.raw.clear(); + + for decl in &mut body.local_decls { + decl.user_ty = None; + } + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/copy_prop.rs b/compiler/rustc_mir_transform/src/copy_prop.rs new file mode 100644 index 00000000000..cddeefca681 --- /dev/null +++ b/compiler/rustc_mir_transform/src/copy_prop.rs @@ -0,0 +1,156 @@ +use rustc_index::IndexSlice; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use tracing::{debug, instrument}; + +use crate::ssa::SsaLocals; + +/// Unify locals that copy each other. +/// +/// We consider patterns of the form +/// _a = rvalue +/// _b = move? _a +/// _c = move? _a +/// _d = move? _c +/// where each of the locals is only assigned once. +/// +/// We want to replace all those locals by `_a`, either copied or moved. +pub(super) struct CopyProp; + +impl<'tcx> crate::MirPass<'tcx> for CopyProp { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 1 + } + + #[instrument(level = "trace", skip(self, tcx, body))] + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!(def_id = ?body.source.def_id()); + + let typing_env = body.typing_env(tcx); + let ssa = SsaLocals::new(tcx, body, typing_env); + debug!(borrowed_locals = ?ssa.borrowed_locals()); + debug!(copy_classes = ?ssa.copy_classes()); + + let mut any_replacement = false; + let mut storage_to_remove = DenseBitSet::new_empty(body.local_decls.len()); + for (local, &head) in ssa.copy_classes().iter_enumerated() { + if local != head { + any_replacement = true; + storage_to_remove.insert(head); + } + } + + if !any_replacement { + return; + } + + let fully_moved = fully_moved_locals(&ssa, body); + debug!(?fully_moved); + + Replacer { tcx, copy_classes: ssa.copy_classes(), fully_moved, storage_to_remove } + .visit_body_preserves_cfg(body); + + crate::simplify::remove_unused_definitions(body); + } + + fn is_required(&self) -> bool { + false + } +} + +/// `SsaLocals` computed equivalence classes between locals considering copy/move assignments. +/// +/// This function also returns whether all the `move?` in the pattern are `move` and not copies. +/// A local which is in the bitset can be replaced by `move _a`. Otherwise, it must be +/// replaced by `copy _a`, as we cannot move multiple times from `_a`. +/// +/// If an operand copies `_c`, it must happen before the assignment `_d = _c`, otherwise it is UB. +/// This means that replacing it by a copy of `_a` if ok, since this copy happens before `_c` is +/// moved, and therefore that `_d` is moved. 
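// Illustrative sketch (not the real `SsaLocals` API): the `copy_classes` table consumed
// by `run_pass` above conceptually sends every local in a chain such as
//   _a = rvalue; _b = _a; _c = _a; _d = _c;
// to the head `_a`. A minimal, self-contained version of that resolution, with a
// hypothetical encoding of the assignments:
fn resolve_copy_classes(copy_of: &[Option<usize>]) -> Vec<usize> {
    // `copy_of[l] == Some(r)` means local `l` is a plain copy/move of local `r`;
    // `None` means `l` is assigned some other rvalue and is its own head.
    (0..copy_of.len())
        .map(|mut l| {
            while let Some(r) = copy_of[l] {
                l = r;
            }
            l
        })
        .collect()
}
// With _0 = rvalue, _1 = _0, _2 = _0, _3 = _2, every local resolves to _0:
// resolve_copy_classes(&[None, Some(0), Some(0), Some(2)]) == vec![0, 0, 0, 0].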
+#[instrument(level = "trace", skip(ssa, body))] +fn fully_moved_locals(ssa: &SsaLocals, body: &Body<'_>) -> DenseBitSet<Local> { + let mut fully_moved = DenseBitSet::new_filled(body.local_decls.len()); + + for (_, rvalue, _) in ssa.assignments(body) { + let (Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) + | Rvalue::CopyForDeref(place)) = rvalue + else { + continue; + }; + + let Some(rhs) = place.as_local() else { continue }; + if !ssa.is_ssa(rhs) { + continue; + } + + if let Rvalue::Use(Operand::Copy(_)) | Rvalue::CopyForDeref(_) = rvalue { + fully_moved.remove(rhs); + } + } + + ssa.meet_copy_equivalence(&mut fully_moved); + + fully_moved +} + +/// Utility to help performing substitution of `*pattern` by `target`. +struct Replacer<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + fully_moved: DenseBitSet<Local>, + storage_to_remove: DenseBitSet<Local>, + copy_classes: &'a IndexSlice<Local, Local>, +} + +impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + #[tracing::instrument(level = "trace", skip(self))] + fn visit_local(&mut self, local: &mut Local, ctxt: PlaceContext, _: Location) { + let new_local = self.copy_classes[*local]; + match ctxt { + // Do not modify the local in storage statements. + PlaceContext::NonUse(NonUseContext::StorageLive | NonUseContext::StorageDead) => {} + // We access the value. + _ => *local = new_local, + } + } + + #[tracing::instrument(level = "trace", skip(self))] + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, loc: Location) { + if let Operand::Move(place) = *operand + // A move out of a projection of a copy is equivalent to a copy of the original + // projection. + && !place.is_indirect_first_projection() + && !self.fully_moved.contains(place.local) + { + *operand = Operand::Copy(place); + } + self.super_operand(operand, loc); + } + + #[tracing::instrument(level = "trace", skip(self))] + fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, loc: Location) { + // When removing storage statements, we need to remove both (#107511). + if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = stmt.kind + && self.storage_to_remove.contains(l) + { + stmt.make_nop(); + return; + } + + self.super_statement(stmt, loc); + + // Do not leave tautological assignments around. + if let StatementKind::Assign(box (lhs, ref rhs)) = stmt.kind + && let Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)) | Rvalue::CopyForDeref(rhs) = + *rhs + && lhs == rhs + { + stmt.make_nop(); + } + } +} diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs new file mode 100644 index 00000000000..06c6b46a9c2 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -0,0 +1,1986 @@ +//! This is the implementation of the pass which transforms coroutines into state machines. +//! +//! MIR generation for coroutines creates a function which has a self argument which +//! passes by value. This argument is effectively a coroutine type which only contains upvars and +//! is only used for this argument inside the MIR for the coroutine. +//! It is passed by value to enable upvars to be moved out of it. Drop elaboration runs on that +//! MIR before this pass and creates drop flags for MIR locals. +//! It will also drop the coroutine argument (which only consists of upvars) if any of the upvars +//! are moved out of. This pass elaborates the drops of upvars / coroutine argument in the case +//! that none of the upvars were moved out of. 
This is because we cannot have any drops of this +//! coroutine in the MIR, since it is used to create the drop glue for the coroutine. We'd get +//! infinite recursion otherwise. +//! +//! This pass creates the implementation for either the `Coroutine::resume` or `Future::poll` +//! function and the drop shim for the coroutine based on the MIR input. +//! It converts the coroutine argument from Self to &mut Self adding derefs in the MIR as needed. +//! It computes the final layout of the coroutine struct which looks like this: +//! First upvars are stored +//! It is followed by the coroutine state field. +//! Then finally the MIR locals which are live across a suspension point are stored. +//! ```ignore (illustrative) +//! struct Coroutine { +//! upvars..., +//! state: u32, +//! mir_locals..., +//! } +//! ``` +//! This pass computes the meaning of the state field and the MIR locals which are live +//! across a suspension point. There are however three hardcoded coroutine states: +//! 0 - Coroutine have not been resumed yet +//! 1 - Coroutine has returned / is completed +//! 2 - Coroutine has been poisoned +//! +//! It also rewrites `return x` and `yield y` as setting a new coroutine state and returning +//! `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`, +//! or `Poll::Ready(x)` and `Poll::Pending` respectively. +//! MIR locals which are live across a suspension point are moved to the coroutine struct +//! with references to them being updated with references to the coroutine struct. +//! +//! The pass creates two functions which have a switch on the coroutine state giving +//! the action to take. +//! +//! One of them is the implementation of `Coroutine::resume` / `Future::poll`. +//! For coroutines with state 0 (unresumed) it starts the execution of the coroutine. +//! For coroutines with state 1 (returned) and state 2 (poisoned) it panics. +//! Otherwise it continues the execution from the last suspension point. +//! +//! The other function is the drop glue for the coroutine. +//! For coroutines with state 0 (unresumed) it drops the upvars of the coroutine. +//! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing. +//! Otherwise it drops all the values in scope at the last suspension point. 
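// Illustrative sketch (not rustc's actual output): as a rough picture of what this pass
// produces, a coroutine like `|| { yield 1; 2 }` becomes a struct holding a state
// discriminant (plus upvars and the locals live across the yield, as in the layout
// shown above) and a resume function that switches on that discriminant. The enum,
// struct and panic messages below are stand-ins; only the reserved states 0/1/2 come
// from the documentation above:
enum SketchState<Y, R> {
    Yielded(Y),
    Complete(R),
}

struct SketchCoroutine {
    // 0 = unresumed, 1 = returned, 2 = poisoned, 3.. = suspension points.
    state: u32,
}

impl SketchCoroutine {
    fn resume(&mut self) -> SketchState<i32, i32> {
        match self.state {
            0 => {
                // Run the body up to the first `yield 1`, then suspend.
                self.state = 3;
                SketchState::Yielded(1)
            }
            3 => {
                // Resume after the yield and run to the end; `2` is the return value.
                self.state = 1;
                SketchState::Complete(2)
            }
            1 => panic!("coroutine resumed after completion"),
            2 => panic!("coroutine resumed after panicking"),
            _ => unreachable!("invalid coroutine state"),
        }
    }
}
// The first call to `resume` yields 1, the second completes with 2, and any further
// call panics, mirroring the returned/poisoned behaviour described above.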
+ +mod by_move_body; +mod drop; +use std::{iter, ops}; + +pub(super) use by_move_body::coroutine_by_move_body_def_id; +use drop::{ + cleanup_async_drops, create_coroutine_drop_shim, create_coroutine_drop_shim_async, + create_coroutine_drop_shim_proxy_async, elaborate_coroutine_drops, expand_async_drops, + has_expandable_async_drops, insert_clean_drop, +}; +use rustc_abi::{FieldIdx, VariantIdx}; +use rustc_data_structures::fx::FxHashSet; +use rustc_errors::pluralize; +use rustc_hir as hir; +use rustc_hir::lang_items::LangItem; +use rustc_hir::{CoroutineDesugaring, CoroutineKind}; +use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet}; +use rustc_index::{Idx, IndexVec}; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::util::Discr; +use rustc_middle::ty::{ + self, CoroutineArgs, CoroutineArgsExt, GenericArgsRef, InstanceKind, Ty, TyCtxt, TypingMode, +}; +use rustc_middle::{bug, span_bug}; +use rustc_mir_dataflow::impls::{ + MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive, + always_storage_live_locals, +}; +use rustc_mir_dataflow::{ + Analysis, Results, ResultsCursor, ResultsVisitor, visit_reachable_results, +}; +use rustc_span::def_id::{DefId, LocalDefId}; +use rustc_span::source_map::dummy_spanned; +use rustc_span::symbol::sym; +use rustc_span::{DUMMY_SP, Span}; +use rustc_target::spec::PanicStrategy; +use rustc_trait_selection::error_reporting::InferCtxtErrorExt; +use rustc_trait_selection::infer::TyCtxtInferExt as _; +use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode, ObligationCtxt}; +use tracing::{debug, instrument, trace}; + +use crate::deref_separator::deref_finder; +use crate::{abort_unwinding_calls, errors, pass_manager as pm, simplify}; + +pub(super) struct StateTransform; + +struct RenameLocalVisitor<'tcx> { + from: Local, + to: Local, + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { + if *local == self.from { + *local = self.to; + } + } + + fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) { + match terminator.kind { + TerminatorKind::Return => { + // Do not replace the implicit `_0` access here, as that's not possible. The + // transform already handles `return` correctly. 
+ } + _ => self.super_terminator(terminator, location), + } + } +} + +struct SelfArgVisitor<'tcx> { + tcx: TyCtxt<'tcx>, + new_base: Place<'tcx>, +} + +impl<'tcx> SelfArgVisitor<'tcx> { + fn new(tcx: TyCtxt<'tcx>, elem: ProjectionElem<Local, Ty<'tcx>>) -> Self { + Self { tcx, new_base: Place { local: SELF_ARG, projection: tcx.mk_place_elems(&[elem]) } } + } +} + +impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { + assert_ne!(*local, SELF_ARG); + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { + if place.local == SELF_ARG { + replace_base(place, self.new_base, self.tcx); + } else { + self.visit_local(&mut place.local, context, location); + + for elem in place.projection.iter() { + if let PlaceElem::Index(local) = elem { + assert_ne!(local, SELF_ARG); + } + } + } + } +} + +fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) { + place.local = new_base.local; + + let mut new_projection = new_base.projection.to_vec(); + new_projection.append(&mut place.projection.to_vec()); + + place.projection = tcx.mk_place_elems(&new_projection); +} + +const SELF_ARG: Local = Local::from_u32(1); +const CTX_ARG: Local = Local::from_u32(2); + +/// A `yield` point in the coroutine. +struct SuspensionPoint<'tcx> { + /// State discriminant used when suspending or resuming at this point. + state: usize, + /// The block to jump to after resumption. + resume: BasicBlock, + /// Where to move the resume argument after resumption. + resume_arg: Place<'tcx>, + /// Which block to jump to if the coroutine is dropped in this state. + drop: Option<BasicBlock>, + /// Set of locals that have live storage while at this suspension point. + storage_liveness: GrowableBitSet<Local>, +} + +struct TransformVisitor<'tcx> { + tcx: TyCtxt<'tcx>, + coroutine_kind: hir::CoroutineKind, + + // The type of the discriminant in the coroutine struct + discr_ty: Ty<'tcx>, + + // Mapping from Local to (type of local, coroutine struct index) + remap: IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>, + + // A map from a suspension point in a block to the locals which have live storage at that point + storage_liveness: IndexVec<BasicBlock, Option<DenseBitSet<Local>>>, + + // A list of suspension points, generated during the transform + suspension_points: Vec<SuspensionPoint<'tcx>>, + + // The set of locals that have no `StorageLive`/`StorageDead` annotations. 
+ always_live_locals: DenseBitSet<Local>, + + // The original RETURN_PLACE local + old_ret_local: Local, + + old_yield_ty: Ty<'tcx>, + + old_ret_ty: Ty<'tcx>, +} + +impl<'tcx> TransformVisitor<'tcx> { + fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock { + let block = body.basic_blocks.next_index(); + let source_info = SourceInfo::outermost(body.span); + + let none_value = match self.coroutine_kind { + CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => { + span_bug!(body.span, "`Future`s are not fused inherently") + } + CoroutineKind::Coroutine(_) => span_bug!(body.span, "`Coroutine`s cannot be fused"), + // `gen` continues return `None` + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { + let option_def_id = self.tcx.require_lang_item(LangItem::Option, body.span); + make_aggregate_adt( + option_def_id, + VariantIdx::ZERO, + self.tcx.mk_args(&[self.old_yield_ty.into()]), + IndexVec::new(), + ) + } + // `async gen` continues to return `Poll::Ready(None)` + CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => { + let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() }; + let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() }; + let yield_ty = args.type_at(0); + Rvalue::Use(Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + const_: Const::Unevaluated( + UnevaluatedConst::new( + self.tcx.require_lang_item(LangItem::AsyncGenFinished, body.span), + self.tcx.mk_args(&[yield_ty.into()]), + ), + self.old_yield_ty, + ), + user_ty: None, + }))) + } + }; + + let statements = vec![Statement { + kind: StatementKind::Assign(Box::new((Place::return_place(), none_value))), + source_info, + }]; + + body.basic_blocks_mut().push(BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }); + + block + } + + // Make a `CoroutineState` or `Poll` variant assignment. + // + // `core::ops::CoroutineState` only has single element tuple variants, + // so we can just write to the downcasted first field and then set the + // discriminant to the appropriate variant. 
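// Illustrative sketch (plain strings, not rustc types): `make_state` below writes the
// value that `resume`/`poll`/`next` ultimately returns for each coroutine flavour. The
// mapping it implements, summarised as a hypothetical helper:
fn state_value_sketch(flavour: &str, is_return: bool) -> &'static str {
    match (flavour, is_return) {
        ("async", true) => "Poll::Ready(val)",
        ("async", false) => "Poll::Pending",
        ("gen", true) => "None",
        ("gen", false) => "Some(val)",
        ("async gen", true) => "the AsyncGenFinished constant (finished marker)",
        ("async gen", false) => "val as-is (already a Poll)",
        (_, true) => "CoroutineState::Complete(val)", // plain coroutines
        (_, false) => "CoroutineState::Yielded(val)",
    }
}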
+ fn make_state( + &self, + val: Operand<'tcx>, + source_info: SourceInfo, + is_return: bool, + statements: &mut Vec<Statement<'tcx>>, + ) { + const ZERO: VariantIdx = VariantIdx::ZERO; + const ONE: VariantIdx = VariantIdx::from_usize(1); + let rvalue = match self.coroutine_kind { + CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => { + let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, source_info.span); + let args = self.tcx.mk_args(&[self.old_ret_ty.into()]); + let (variant_idx, operands) = if is_return { + (ZERO, IndexVec::from_raw(vec![val])) // Poll::Ready(val) + } else { + (ONE, IndexVec::new()) // Poll::Pending + }; + make_aggregate_adt(poll_def_id, variant_idx, args, operands) + } + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { + let option_def_id = self.tcx.require_lang_item(LangItem::Option, source_info.span); + let args = self.tcx.mk_args(&[self.old_yield_ty.into()]); + let (variant_idx, operands) = if is_return { + (ZERO, IndexVec::new()) // None + } else { + (ONE, IndexVec::from_raw(vec![val])) // Some(val) + }; + make_aggregate_adt(option_def_id, variant_idx, args, operands) + } + CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => { + if is_return { + let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() }; + let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() }; + let yield_ty = args.type_at(0); + Rvalue::Use(Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + const_: Const::Unevaluated( + UnevaluatedConst::new( + self.tcx.require_lang_item( + LangItem::AsyncGenFinished, + source_info.span, + ), + self.tcx.mk_args(&[yield_ty.into()]), + ), + self.old_yield_ty, + ), + user_ty: None, + }))) + } else { + Rvalue::Use(val) + } + } + CoroutineKind::Coroutine(_) => { + let coroutine_state_def_id = + self.tcx.require_lang_item(LangItem::CoroutineState, source_info.span); + let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]); + let variant_idx = if is_return { + ONE // CoroutineState::Complete(val) + } else { + ZERO // CoroutineState::Yielded(val) + }; + make_aggregate_adt( + coroutine_state_def_id, + variant_idx, + args, + IndexVec::from_raw(vec![val]), + ) + } + }; + + statements.push(Statement { + kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))), + source_info, + }); + } + + // Create a Place referencing a coroutine struct field + fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> { + let self_place = Place::from(SELF_ARG); + let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index); + let mut projection = base.projection.to_vec(); + projection.push(ProjectionElem::Field(idx, ty)); + + Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) } + } + + // Create a statement which changes the discriminant + fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> { + let self_place = Place::from(SELF_ARG); + Statement { + source_info, + kind: StatementKind::SetDiscriminant { + place: Box::new(self_place), + variant_index: state_disc, + }, + } + } + + // Create a statement which reads the discriminant into a temporary + fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) { + let temp_decl = LocalDecl::new(self.discr_ty, body.span); + let local_decls_len = body.local_decls.push(temp_decl); + let temp = Place::from(local_decls_len); + + let self_place = Place::from(SELF_ARG); + let assign = Statement { + 
source_info: SourceInfo::outermost(body.span), + kind: StatementKind::Assign(Box::new((temp, Rvalue::Discriminant(self_place)))), + }; + (assign, temp) + } +} + +impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { + assert!(!self.remap.contains(*local)); + } + + fn visit_place( + &mut self, + place: &mut Place<'tcx>, + _context: PlaceContext, + _location: Location, + ) { + // Replace an Local in the remap with a coroutine struct access + if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) { + replace_base(place, self.make_field(variant_index, idx, ty), self.tcx); + } + } + + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) { + // Remove StorageLive and StorageDead statements for remapped locals + for s in &mut data.statements { + if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = s.kind + && self.remap.contains(l) + { + s.make_nop(); + } + } + + let ret_val = match data.terminator().kind { + TerminatorKind::Return => { + Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None)) + } + TerminatorKind::Yield { ref value, resume, resume_arg, drop } => { + Some((false, Some((resume, resume_arg)), value.clone(), drop)) + } + _ => None, + }; + + if let Some((is_return, resume, v, drop)) = ret_val { + let source_info = data.terminator().source_info; + // We must assign the value first in case it gets declared dead below + self.make_state(v, source_info, is_return, &mut data.statements); + let state = if let Some((resume, mut resume_arg)) = resume { + // Yield + let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len(); + + // The resume arg target location might itself be remapped if its base local is + // live across a yield. 
+ if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) { + replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx); + } + + let storage_liveness: GrowableBitSet<Local> = + self.storage_liveness[block].clone().unwrap().into(); + + for i in 0..self.always_live_locals.domain_size() { + let l = Local::new(i); + let needs_storage_dead = storage_liveness.contains(l) + && !self.remap.contains(l) + && !self.always_live_locals.contains(l); + if needs_storage_dead { + data.statements + .push(Statement { source_info, kind: StatementKind::StorageDead(l) }); + } + } + + self.suspension_points.push(SuspensionPoint { + state, + resume, + resume_arg, + drop, + storage_liveness, + }); + + VariantIdx::new(state) + } else { + // Return + VariantIdx::new(CoroutineArgs::RETURNED) // state for returned + }; + data.statements.push(self.set_discr(state, source_info)); + data.terminator_mut().kind = TerminatorKind::Return; + } + + self.super_basic_block_data(block, data); + } +} + +fn make_aggregate_adt<'tcx>( + def_id: DefId, + variant_idx: VariantIdx, + args: GenericArgsRef<'tcx>, + operands: IndexVec<FieldIdx, Operand<'tcx>>, +) -> Rvalue<'tcx> { + Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands) +} + +fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let coroutine_ty = body.local_decls.raw[1].ty; + + let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty); + + // Replace the by value coroutine argument + body.local_decls.raw[1].ty = ref_coroutine_ty; + + // Add a deref to accesses of the coroutine state + SelfArgVisitor::new(tcx, ProjectionElem::Deref).visit_body(body); +} + +fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let ref_coroutine_ty = body.local_decls.raw[1].ty; + + let pin_did = tcx.require_lang_item(LangItem::Pin, body.span); + let pin_adt_ref = tcx.adt_def(pin_did); + let args = tcx.mk_args(&[ref_coroutine_ty.into()]); + let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args); + + // Replace the by ref coroutine argument + body.local_decls.raw[1].ty = pin_ref_coroutine_ty; + + // Add the Pin field access to accesses of the coroutine state + SelfArgVisitor::new(tcx, ProjectionElem::Field(FieldIdx::ZERO, ref_coroutine_ty)) + .visit_body(body); +} + +/// Allocates a new local and replaces all references of `local` with it. Returns the new local. +/// +/// `local` will be changed to a new local decl with type `ty`. +/// +/// Note that the new local will be uninitialized. It is the caller's responsibility to assign some +/// valid value to it before its first use. +fn replace_local<'tcx>( + local: Local, + ty: Ty<'tcx>, + body: &mut Body<'tcx>, + tcx: TyCtxt<'tcx>, +) -> Local { + let new_decl = LocalDecl::new(ty, body.span); + let new_local = body.local_decls.push(new_decl); + body.local_decls.swap(local, new_local); + + RenameLocalVisitor { from: local, to: new_local, tcx }.visit_body(body); + + new_local +} + +/// Transforms the `body` of the coroutine applying the following transforms: +/// +/// - Eliminates all the `get_context` calls that async lowering created. +/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`). +/// +/// The `Local`s that have their types replaced are: +/// - The `resume` argument itself. +/// - The argument to `get_context`. +/// - The yielded value of a `yield`. 
+/// +/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the +/// `get_context` function is being used to convert that back to a `&mut Context<'_>`. +/// +/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection, +/// but rather directly use `&mut Context<'_>`, however that would currently +/// lead to higher-kinded lifetime errors. +/// See <https://github.com/rust-lang/rust/issues/105501>. +/// +/// The async lowering step and the type / lifetime inference / checking are +/// still using the `ResumeTy` indirection for the time being, and that indirection +/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`. +fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> { + let context_mut_ref = Ty::new_task_context(tcx); + + // replace the type of the `resume` argument + replace_resume_ty_local(tcx, body, CTX_ARG, context_mut_ref); + + let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span); + + for bb in body.basic_blocks.indices() { + let bb_data = &body[bb]; + if bb_data.is_cleanup { + continue; + } + + match &bb_data.terminator().kind { + TerminatorKind::Call { func, .. } => { + let func_ty = func.ty(body, tcx); + if let ty::FnDef(def_id, _) = *func_ty.kind() + && def_id == get_context_def_id + { + let local = eliminate_get_context_call(&mut body[bb]); + replace_resume_ty_local(tcx, body, local, context_mut_ref); + } + } + TerminatorKind::Yield { resume_arg, .. } => { + replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref); + } + _ => {} + } + } + context_mut_ref +} + +fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local { + let terminator = bb_data.terminator.take().unwrap(); + let TerminatorKind::Call { args, destination, target, .. } = terminator.kind else { + bug!(); + }; + let [arg] = *Box::try_from(args).unwrap(); + let local = arg.node.place().unwrap().local; + + let arg = Rvalue::Use(arg.node); + let assign = Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new((destination, arg))), + }; + bb_data.statements.push(assign); + bb_data.terminator = Some(Terminator { + source_info: terminator.source_info, + kind: TerminatorKind::Goto { target: target.unwrap() }, + }); + local +} + +#[cfg_attr(not(debug_assertions), allow(unused))] +fn replace_resume_ty_local<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + local: Local, + context_mut_ref: Ty<'tcx>, +) { + let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref); + // We have to replace the `ResumeTy` that is used for type and borrow checking + // with `&mut Context<'_>` in MIR. + #[cfg(debug_assertions)] + { + if let ty::Adt(resume_ty_adt, _) = local_ty.kind() { + let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span)); + assert_eq!(*resume_ty_adt, expected_adt); + } else { + panic!("expected `ResumeTy`, found `{:?}`", local_ty); + }; + } +} + +/// Transforms the `body` of the coroutine applying the following transform: +/// +/// - Remove the `resume` argument. +/// +/// Ideally the async lowering would not add the `resume` argument. +/// +/// The async lowering step and the type / lifetime inference / checking are +/// still using the `resume` argument for the time being. After this transform, +/// the coroutine body doesn't have the `resume` argument. 
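// Illustrative sketch (toy types, not rustc's MIR): `eliminate_get_context_call` above
// turns a block ending in `_dest = get_context(move _arg) -> bb` into an ordinary
// assignment `_dest = _arg` followed by `goto -> bb`. The same rewrite on a toy CFG
// representation:
#[derive(Clone, Copy)]
enum ToyTerminator {
    Call { arg: usize, dest: usize, target: usize },
    Goto { target: usize },
}

struct ToyBlock {
    assignments: Vec<(usize, usize)>, // (destination local, source local)
    terminator: ToyTerminator,
}

fn eliminate_toy_call(block: &mut ToyBlock) {
    if let ToyTerminator::Call { arg, dest, target } = block.terminator {
        // Replace the call terminator with `dest = arg; goto target`.
        block.assignments.push((dest, arg));
        block.terminator = ToyTerminator::Goto { target };
    }
}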
+fn transform_gen_context<'tcx>(body: &mut Body<'tcx>) { + // This leaves the local representing the `resume` argument in place, + // but turns it into a regular local variable. This is cheaper than + // adjusting all local references in the body after removing it. + body.arg_count = 1; +} + +struct LivenessInfo { + /// Which locals are live across any suspension point. + saved_locals: CoroutineSavedLocals, + + /// The set of saved locals live at each suspension point. + live_locals_at_suspension_points: Vec<DenseBitSet<CoroutineSavedLocal>>, + + /// Parallel vec to the above with SourceInfo for each yield terminator. + source_info_at_suspension_points: Vec<SourceInfo>, + + /// For every saved local, the set of other saved locals that are + /// storage-live at the same time as this local. We cannot overlap locals in + /// the layout which have conflicting storage. + storage_conflicts: BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>, + + /// For every suspending block, the locals which are storage-live across + /// that suspension point. + storage_liveness: IndexVec<BasicBlock, Option<DenseBitSet<Local>>>, +} + +/// Computes which locals have to be stored in the state-machine for the +/// given coroutine. +/// +/// The basic idea is as follows: +/// - a local is live until we encounter a `StorageDead` statement. In +/// case none exist, the local is considered to be always live. +/// - a local has to be stored if it is either directly used after the +/// the suspend point, or if it is live and has been previously borrowed. +fn locals_live_across_suspend_points<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + always_live_locals: &DenseBitSet<Local>, + movable: bool, +) -> LivenessInfo { + // Calculate when MIR locals have live storage. This gives us an upper bound of their + // lifetimes. + let mut storage_live = MaybeStorageLive::new(std::borrow::Cow::Borrowed(always_live_locals)) + .iterate_to_fixpoint(tcx, body, None) + .into_results_cursor(body); + + // Calculate the MIR locals that have been previously borrowed (even if they are still active). + let borrowed_locals = MaybeBorrowedLocals.iterate_to_fixpoint(tcx, body, Some("coroutine")); + let mut borrowed_locals_analysis1 = borrowed_locals.analysis; + let mut borrowed_locals_analysis2 = borrowed_locals_analysis1.clone(); // trivial + let borrowed_locals_cursor1 = ResultsCursor::new_borrowing( + body, + &mut borrowed_locals_analysis1, + &borrowed_locals.results, + ); + let mut borrowed_locals_cursor2 = ResultsCursor::new_borrowing( + body, + &mut borrowed_locals_analysis2, + &borrowed_locals.results, + ); + + // Calculate the MIR locals that we need to keep storage around for. + let mut requires_storage = + MaybeRequiresStorage::new(borrowed_locals_cursor1).iterate_to_fixpoint(tcx, body, None); + let mut requires_storage_cursor = ResultsCursor::new_borrowing( + body, + &mut requires_storage.analysis, + &requires_storage.results, + ); + + // Calculate the liveness of MIR locals ignoring borrows. + let mut liveness = + MaybeLiveLocals.iterate_to_fixpoint(tcx, body, Some("coroutine")).into_results_cursor(body); + + let mut storage_liveness_map = IndexVec::from_elem(None, &body.basic_blocks); + let mut live_locals_at_suspension_points = Vec::new(); + let mut source_info_at_suspension_points = Vec::new(); + let mut live_locals_at_any_suspension_point = DenseBitSet::new_empty(body.local_decls.len()); + + for (block, data) in body.basic_blocks.iter_enumerated() { + if let TerminatorKind::Yield { .. 
} = data.terminator().kind { + let loc = Location { block, statement_index: data.statements.len() }; + + liveness.seek_to_block_end(block); + let mut live_locals = liveness.get().clone(); + + if !movable { + // The `liveness` variable contains the liveness of MIR locals ignoring borrows. + // This is correct for movable coroutines since borrows cannot live across + // suspension points. However for immovable coroutines we need to account for + // borrows, so we conservatively assume that all borrowed locals are live until + // we find a StorageDead statement referencing the locals. + // To do this we just union our `liveness` result with `borrowed_locals`, which + // contains all the locals which has been borrowed before this suspension point. + // If a borrow is converted to a raw reference, we must also assume that it lives + // forever. Note that the final liveness is still bounded by the storage liveness + // of the local, which happens using the `intersect` operation below. + borrowed_locals_cursor2.seek_before_primary_effect(loc); + live_locals.union(borrowed_locals_cursor2.get()); + } + + // Store the storage liveness for later use so we can restore the state + // after a suspension point + storage_live.seek_before_primary_effect(loc); + storage_liveness_map[block] = Some(storage_live.get().clone()); + + // Locals live are live at this point only if they are used across + // suspension points (the `liveness` variable) + // and their storage is required (the `storage_required` variable) + requires_storage_cursor.seek_before_primary_effect(loc); + live_locals.intersect(requires_storage_cursor.get()); + + // The coroutine argument is ignored. + live_locals.remove(SELF_ARG); + + debug!("loc = {:?}, live_locals = {:?}", loc, live_locals); + + // Add the locals live at this suspension point to the set of locals which live across + // any suspension points + live_locals_at_any_suspension_point.union(&live_locals); + + live_locals_at_suspension_points.push(live_locals); + source_info_at_suspension_points.push(data.terminator().source_info); + } + } + + debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point); + let saved_locals = CoroutineSavedLocals(live_locals_at_any_suspension_point); + + // Renumber our liveness_map bitsets to include only the locals we are + // saving. + let live_locals_at_suspension_points = live_locals_at_suspension_points + .iter() + .map(|live_here| saved_locals.renumber_bitset(live_here)) + .collect(); + + let storage_conflicts = compute_storage_conflicts( + body, + &saved_locals, + always_live_locals.clone(), + &mut requires_storage.analysis, + &requires_storage.results, + ); + + LivenessInfo { + saved_locals, + live_locals_at_suspension_points, + source_info_at_suspension_points, + storage_conflicts, + storage_liveness: storage_liveness_map, + } +} + +/// The set of `Local`s that must be saved across yield points. +/// +/// `CoroutineSavedLocal` is indexed in terms of the elements in this set; +/// i.e. `CoroutineSavedLocal::new(1)` corresponds to the second local +/// included in this set. +struct CoroutineSavedLocals(DenseBitSet<Local>); + +impl CoroutineSavedLocals { + /// Returns an iterator over each `CoroutineSavedLocal` along with the `Local` it corresponds + /// to. 
+ fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (CoroutineSavedLocal, Local)> { + self.iter().enumerate().map(|(i, l)| (CoroutineSavedLocal::from(i), l)) + } + + /// Transforms a `DenseBitSet<Local>` that contains only locals saved across yield points to the + /// equivalent `DenseBitSet<CoroutineSavedLocal>`. + fn renumber_bitset(&self, input: &DenseBitSet<Local>) -> DenseBitSet<CoroutineSavedLocal> { + assert!(self.superset(input), "{:?} not a superset of {:?}", self.0, input); + let mut out = DenseBitSet::new_empty(self.count()); + for (saved_local, local) in self.iter_enumerated() { + if input.contains(local) { + out.insert(saved_local); + } + } + out + } + + fn get(&self, local: Local) -> Option<CoroutineSavedLocal> { + if !self.contains(local) { + return None; + } + + let idx = self.iter().take_while(|&l| l < local).count(); + Some(CoroutineSavedLocal::new(idx)) + } +} + +impl ops::Deref for CoroutineSavedLocals { + type Target = DenseBitSet<Local>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +/// For every saved local, looks for which locals are StorageLive at the same +/// time. Generates a bitset for every local of all the other locals that may be +/// StorageLive simultaneously with that local. This is used in the layout +/// computation; see `CoroutineLayout` for more. +fn compute_storage_conflicts<'mir, 'tcx>( + body: &'mir Body<'tcx>, + saved_locals: &'mir CoroutineSavedLocals, + always_live_locals: DenseBitSet<Local>, + analysis: &mut MaybeRequiresStorage<'mir, 'tcx>, + results: &Results<DenseBitSet<Local>>, +) -> BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal> { + assert_eq!(body.local_decls.len(), saved_locals.domain_size()); + + debug!("compute_storage_conflicts({:?})", body.span); + debug!("always_live = {:?}", always_live_locals); + + // Locals that are always live or ones that need to be stored across + // suspension points are not eligible for overlap. + let mut ineligible_locals = always_live_locals; + ineligible_locals.intersect(&**saved_locals); + + // Compute the storage conflicts for all eligible locals. + let mut visitor = StorageConflictVisitor { + body, + saved_locals, + local_conflicts: BitMatrix::from_row_n(&ineligible_locals, body.local_decls.len()), + eligible_storage_live: DenseBitSet::new_empty(body.local_decls.len()), + }; + + visit_reachable_results(body, analysis, results, &mut visitor); + + let local_conflicts = visitor.local_conflicts; + + // Compress the matrix using only stored locals (Local -> CoroutineSavedLocal). + // + // NOTE: Today we store a full conflict bitset for every local. Technically + // this is twice as many bits as we need, since the relation is symmetric. + // However, in practice these bitsets are not usually large. The layout code + // also needs to keep track of how many conflicts each local has, so it's + // simpler to keep it this way for now. + let mut storage_conflicts = BitMatrix::new(saved_locals.count(), saved_locals.count()); + for (saved_local_a, local_a) in saved_locals.iter_enumerated() { + if ineligible_locals.contains(local_a) { + // Conflicts with everything. + storage_conflicts.insert_all_into_row(saved_local_a); + } else { + // Keep overlap information only for stored locals. 
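// Illustrative sketch (plain Vecs, not the real BitMatrix API): the conflict visitor
// (see `apply_state` below) unions, at every program point, the set of eligible
// storage-live locals into each of their conflict rows, so two saved locals end up
// conflicting whenever their storage is live at the same point; the loop here then
// re-indexes those conflicts by `CoroutineSavedLocal`. The core idea:
fn conflicts_from_live_sets(num_locals: usize, live_sets: &[Vec<usize>]) -> Vec<Vec<bool>> {
    let mut conflicts = vec![vec![false; num_locals]; num_locals];
    for set in live_sets {
        for &a in set {
            for &b in set {
                conflicts[a][b] = true;
            }
        }
    }
    conflicts
}
// If locals 0 and 1 are storage-live together at some point while local 2 is only
// ever live alone, then 0 and 1 may not overlap in the layout, but 2 may share a
// slot with either of them:
// let c = conflicts_from_live_sets(3, &[vec![0, 1], vec![2]]);
// assert!(c[0][1] && !c[0][2] && !c[1][2]);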
+ for (saved_local_b, local_b) in saved_locals.iter_enumerated() { + if local_conflicts.contains(local_a, local_b) { + storage_conflicts.insert(saved_local_a, saved_local_b); + } + } + } + } + storage_conflicts +} + +struct StorageConflictVisitor<'a, 'tcx> { + body: &'a Body<'tcx>, + saved_locals: &'a CoroutineSavedLocals, + // FIXME(tmandry): Consider using sparse bitsets here once we have good + // benchmarks for coroutines. + local_conflicts: BitMatrix<Local, Local>, + // We keep this bitset as a buffer to avoid reallocating memory. + eligible_storage_live: DenseBitSet<Local>, +} + +impl<'a, 'tcx> ResultsVisitor<'tcx, MaybeRequiresStorage<'a, 'tcx>> + for StorageConflictVisitor<'a, 'tcx> +{ + fn visit_after_early_statement_effect( + &mut self, + _analysis: &mut MaybeRequiresStorage<'a, 'tcx>, + state: &DenseBitSet<Local>, + _statement: &Statement<'tcx>, + loc: Location, + ) { + self.apply_state(state, loc); + } + + fn visit_after_early_terminator_effect( + &mut self, + _analysis: &mut MaybeRequiresStorage<'a, 'tcx>, + state: &DenseBitSet<Local>, + _terminator: &Terminator<'tcx>, + loc: Location, + ) { + self.apply_state(state, loc); + } +} + +impl StorageConflictVisitor<'_, '_> { + fn apply_state(&mut self, state: &DenseBitSet<Local>, loc: Location) { + // Ignore unreachable blocks. + if let TerminatorKind::Unreachable = self.body.basic_blocks[loc.block].terminator().kind { + return; + } + + self.eligible_storage_live.clone_from(state); + self.eligible_storage_live.intersect(&**self.saved_locals); + + for local in self.eligible_storage_live.iter() { + self.local_conflicts.union_row_with(&self.eligible_storage_live, local); + } + + if self.eligible_storage_live.count() > 1 { + trace!("at {:?}, eligible_storage_live={:?}", loc, self.eligible_storage_live); + } + } +} + +fn compute_layout<'tcx>( + liveness: LivenessInfo, + body: &Body<'tcx>, +) -> ( + IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>, + CoroutineLayout<'tcx>, + IndexVec<BasicBlock, Option<DenseBitSet<Local>>>, +) { + let LivenessInfo { + saved_locals, + live_locals_at_suspension_points, + source_info_at_suspension_points, + storage_conflicts, + storage_liveness, + } = liveness; + + // Gather live local types and their indices. + let mut locals = IndexVec::<CoroutineSavedLocal, _>::new(); + let mut tys = IndexVec::<CoroutineSavedLocal, _>::new(); + for (saved_local, local) in saved_locals.iter_enumerated() { + debug!("coroutine saved local {:?} => {:?}", saved_local, local); + + locals.push(local); + let decl = &body.local_decls[local]; + debug!(?decl); + + // Do not `unwrap_crate_local` here, as post-borrowck cleanup may have already cleared + // the information. This is alright, since `ignore_for_traits` is only relevant when + // this code runs on pre-cleanup MIR, and `ignore_for_traits = false` is the safer + // default. + let ignore_for_traits = match decl.local_info { + // Do not include raw pointers created from accessing `static` items, as those could + // well be re-created by another access to the same static. + ClearCrossCrate::Set(box LocalInfo::StaticRef { is_thread_local, .. }) => { + !is_thread_local + } + // Fake borrows are only read by fake reads, so do not have any reality in + // post-analysis MIR. + ClearCrossCrate::Set(box LocalInfo::FakeBorrow) => true, + _ => false, + }; + let decl = + CoroutineSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits }; + debug!(?decl); + + tys.push(decl); + } + + // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states. 
+ // In debuginfo, these will correspond to the beginning (UNRESUMED) or end + // (RETURNED, POISONED) of the function. + let body_span = body.source_scopes[OUTERMOST_SOURCE_SCOPE].span; + let mut variant_source_info: IndexVec<VariantIdx, SourceInfo> = [ + SourceInfo::outermost(body_span.shrink_to_lo()), + SourceInfo::outermost(body_span.shrink_to_hi()), + SourceInfo::outermost(body_span.shrink_to_hi()), + ] + .iter() + .copied() + .collect(); + + // Build the coroutine variant field list. + // Create a map from local indices to coroutine struct indices. + let mut variant_fields: IndexVec<VariantIdx, IndexVec<FieldIdx, CoroutineSavedLocal>> = + iter::repeat(IndexVec::new()).take(CoroutineArgs::RESERVED_VARIANTS).collect(); + let mut remap = IndexVec::from_elem_n(None, saved_locals.domain_size()); + for (suspension_point_idx, live_locals) in live_locals_at_suspension_points.iter().enumerate() { + let variant_index = + VariantIdx::from(CoroutineArgs::RESERVED_VARIANTS + suspension_point_idx); + let mut fields = IndexVec::new(); + for (idx, saved_local) in live_locals.iter().enumerate() { + fields.push(saved_local); + // Note that if a field is included in multiple variants, we will + // just use the first one here. That's fine; fields do not move + // around inside coroutines, so it doesn't matter which variant + // index we access them by. + let idx = FieldIdx::from_usize(idx); + remap[locals[saved_local]] = Some((tys[saved_local].ty, variant_index, idx)); + } + variant_fields.push(fields); + variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]); + } + debug!("coroutine variant_fields = {:?}", variant_fields); + debug!("coroutine storage_conflicts = {:#?}", storage_conflicts); + + let mut field_names = IndexVec::from_elem(None, &tys); + for var in &body.var_debug_info { + let VarDebugInfoContents::Place(place) = &var.value else { continue }; + let Some(local) = place.as_local() else { continue }; + let Some(&Some((_, variant, field))) = remap.get(local) else { + continue; + }; + + let saved_local = variant_fields[variant][field]; + field_names.get_or_insert_with(saved_local, || var.name); + } + + let layout = CoroutineLayout { + field_tys: tys, + field_names, + variant_fields, + variant_source_info, + storage_conflicts, + }; + debug!(?layout); + + (remap, layout, storage_liveness) +} + +/// Replaces the entry point of `body` with a block that switches on the coroutine discriminant and +/// dispatches to blocks according to `cases`. +/// +/// After this function, the former entry point of the function will be bb1. 
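// A schematic sketch of the dispatch block this function inserts (block and local numbers
// are made up; the real ones depend on the body being transformed). The first statement is
// the `assign` produced by `TransformVisitor::get_discr`, which reads the state discriminant
// out of the coroutine struct, and every pre-existing block index is shifted up by one. In
// the final resume function this typically ends up looking roughly like:
//
// ```
// bb0: {
//     _N = discriminant((*_1));
//     switchInt(move _N) -> [0: bb1, 1: bb_returned, 2: bb_poisoned, 3: bb_resume_0, otherwise: bb_unreachable];
// }
// ```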
+fn insert_switch<'tcx>( + body: &mut Body<'tcx>, + cases: Vec<(usize, BasicBlock)>, + transform: &TransformVisitor<'tcx>, + default_block: BasicBlock, +) { + let (assign, discr) = transform.get_discr(body); + let switch_targets = + SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block); + let switch = TerminatorKind::SwitchInt { discr: Operand::Move(discr), targets: switch_targets }; + + let source_info = SourceInfo::outermost(body.span); + body.basic_blocks_mut().raw.insert( + 0, + BasicBlockData { + statements: vec![assign], + terminator: Some(Terminator { source_info, kind: switch }), + is_cleanup: false, + }, + ); + + for b in body.basic_blocks_mut().iter_mut() { + b.terminator_mut().successors_mut(|target| *target += 1); + } +} + +fn insert_term_block<'tcx>(body: &mut Body<'tcx>, kind: TerminatorKind<'tcx>) -> BasicBlock { + let source_info = SourceInfo::outermost(body.span); + body.basic_blocks_mut().push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { source_info, kind }), + is_cleanup: false, + }) +} + +fn return_poll_ready_assign<'tcx>(tcx: TyCtxt<'tcx>, source_info: SourceInfo) -> Statement<'tcx> { + // Poll::Ready(()) + let poll_def_id = tcx.require_lang_item(LangItem::Poll, source_info.span); + let args = tcx.mk_args(&[tcx.types.unit.into()]); + let val = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::zero_sized(tcx.types.unit), + })); + let ready_val = Rvalue::Aggregate( + Box::new(AggregateKind::Adt(poll_def_id, VariantIdx::from_usize(0), args, None, None)), + IndexVec::from_raw(vec![val]), + ); + Statement { + kind: StatementKind::Assign(Box::new((Place::return_place(), ready_val))), + source_info, + } +} + +fn insert_poll_ready_block<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> BasicBlock { + let source_info = SourceInfo::outermost(body.span); + body.basic_blocks_mut().push(BasicBlockData { + statements: [return_poll_ready_assign(tcx, source_info)].to_vec(), + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }) +} + +fn insert_panic_block<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + message: AssertMessage<'tcx>, +) -> BasicBlock { + let assert_block = body.basic_blocks.next_index(); + let kind = TerminatorKind::Assert { + cond: Operand::Constant(Box::new(ConstOperand { + span: body.span, + user_ty: None, + const_: Const::from_bool(tcx, false), + })), + expected: true, + msg: Box::new(message), + target: assert_block, + unwind: UnwindAction::Continue, + }; + + insert_term_block(body, kind) +} + +fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, typing_env: ty::TypingEnv<'tcx>) -> bool { + // Returning from a function with an uninhabited return type is undefined behavior. + if body.return_ty().is_privately_uninhabited(tcx, typing_env) { + return false; + } + + // If there's a return terminator the function may return. + body.basic_blocks.iter().any(|block| matches!(block.terminator().kind, TerminatorKind::Return)) + // Otherwise the function can't return. +} + +fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool { + // Nothing can unwind when landing pads are off. + if tcx.sess.panic_strategy() == PanicStrategy::Abort { + return false; + } + + // Unwinds can only start at certain terminators. + for block in body.basic_blocks.iter() { + match block.terminator().kind { + // These never unwind. + TerminatorKind::Goto { .. } + | TerminatorKind::SwitchInt { .. 
} + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::CoroutineDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } => {} + + // Resume will *continue* unwinding, but if there's no other unwinding terminator it + // will never be reached. + TerminatorKind::UnwindResume => {} + + TerminatorKind::Yield { .. } => { + unreachable!("`can_unwind` called before coroutine transform") + } + + // These may unwind. + TerminatorKind::Drop { .. } + | TerminatorKind::Call { .. } + | TerminatorKind::InlineAsm { .. } + | TerminatorKind::Assert { .. } => return true, + + TerminatorKind::TailCall { .. } => { + unreachable!("tail calls can't be present in generators") + } + } + } + + // If we didn't find an unwinding terminator, the function cannot unwind. + false +} + +// Poison the coroutine when it unwinds +fn generate_poison_block_and_redirect_unwinds_there<'tcx>( + transform: &TransformVisitor<'tcx>, + body: &mut Body<'tcx>, +) { + let source_info = SourceInfo::outermost(body.span); + let poison_block = body.basic_blocks_mut().push(BasicBlockData { + statements: vec![ + transform.set_discr(VariantIdx::new(CoroutineArgs::POISONED), source_info), + ], + terminator: Some(Terminator { source_info, kind: TerminatorKind::UnwindResume }), + is_cleanup: true, + }); + + for (idx, block) in body.basic_blocks_mut().iter_enumerated_mut() { + let source_info = block.terminator().source_info; + + if let TerminatorKind::UnwindResume = block.terminator().kind { + // An existing `Resume` terminator is redirected to jump to our dedicated + // "poisoning block" above. + if idx != poison_block { + *block.terminator_mut() = + Terminator { source_info, kind: TerminatorKind::Goto { target: poison_block } }; + } + } else if !block.is_cleanup + // Any terminators that *can* unwind but don't have an unwind target set are also + // pointed at our poisoning block (unless they're part of the cleanup path). 
+ && let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut() + { + *unwind = UnwindAction::Cleanup(poison_block); + } + } +} + +fn create_coroutine_resume_function<'tcx>( + tcx: TyCtxt<'tcx>, + transform: TransformVisitor<'tcx>, + body: &mut Body<'tcx>, + can_return: bool, + can_unwind: bool, +) { + // Poison the coroutine when it unwinds + if can_unwind { + generate_poison_block_and_redirect_unwinds_there(&transform, body); + } + + let mut cases = create_cases(body, &transform, Operation::Resume); + + use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn}; + + // Jump to the entry point on the unresumed + cases.insert(0, (CoroutineArgs::UNRESUMED, START_BLOCK)); + + // Panic when resumed on the returned or poisoned state + if can_unwind { + cases.insert( + 1, + ( + CoroutineArgs::POISONED, + insert_panic_block(tcx, body, ResumedAfterPanic(transform.coroutine_kind)), + ), + ); + } + + if can_return { + let block = match transform.coroutine_kind { + CoroutineKind::Desugared(CoroutineDesugaring::Async, _) + | CoroutineKind::Coroutine(_) => { + // For `async_drop_in_place<T>::{closure}` we just keep return Poll::Ready, + // because async drop of such coroutine keeps polling original coroutine + if tcx.is_async_drop_in_place_coroutine(body.source.def_id()) { + insert_poll_ready_block(tcx, body) + } else { + insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind)) + } + } + CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) + | CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { + transform.insert_none_ret_block(body) + } + }; + cases.insert(1, (CoroutineArgs::RETURNED, block)); + } + + let default_block = insert_term_block(body, TerminatorKind::Unreachable); + insert_switch(body, cases, &transform, default_block); + + make_coroutine_state_argument_indirect(tcx, body); + + match transform.coroutine_kind { + CoroutineKind::Coroutine(_) + | CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) => + { + make_coroutine_state_argument_pinned(tcx, body); + } + // Iterator::next doesn't accept a pinned argument, + // unlike for all other coroutine kinds. + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {} + } + + // Make sure we remove dead blocks to remove + // unrelated code from the drop part of the function + simplify::remove_dead_blocks(body); + + pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None); + + dump_mir(tcx, false, "coroutine_resume", &0, body, |_, _| Ok(())); +} + +/// An operation that can be performed on a coroutine. 
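// A rough analogy for `create_coroutine_resume_function` above, in plain Rust rather than
// the generated MIR (the state and message names are made up): the resume function behaves
// like a hand-written state machine whose entry dispatches on the current state, and
// "resumed again after finishing or panicking" becomes one of the `ResumedAfterReturn` /
// `ResumedAfterPanic` assert blocks inserted above.
//
// ```
// enum State { Unresumed, Returned, Poisoned, Suspend0 }
//
// fn resume(state: &mut State) {
//     match state {
//         State::Unresumed => { /* run up to the first yield, then set *state = Suspend0 */ }
//         State::Suspend0 => { /* restore saved locals and continue after the yield */ }
//         State::Returned => panic!("coroutine resumed after completion"),
//         State::Poisoned => panic!("coroutine resumed after panicking"),
//     }
// }
// ```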
+#[derive(PartialEq, Copy, Clone)] +enum Operation { + Resume, + Drop, +} + +impl Operation { + fn target_block(self, point: &SuspensionPoint<'_>) -> Option<BasicBlock> { + match self { + Operation::Resume => Some(point.resume), + Operation::Drop => point.drop, + } + } +} + +fn create_cases<'tcx>( + body: &mut Body<'tcx>, + transform: &TransformVisitor<'tcx>, + operation: Operation, +) -> Vec<(usize, BasicBlock)> { + let source_info = SourceInfo::outermost(body.span); + + transform + .suspension_points + .iter() + .filter_map(|point| { + // Find the target for this suspension point, if applicable + operation.target_block(point).map(|target| { + let mut statements = Vec::new(); + + // Create StorageLive instructions for locals with live storage + for l in body.local_decls.indices() { + let needs_storage_live = point.storage_liveness.contains(l) + && !transform.remap.contains(l) + && !transform.always_live_locals.contains(l); + if needs_storage_live { + statements + .push(Statement { source_info, kind: StatementKind::StorageLive(l) }); + } + } + + if operation == Operation::Resume { + // Move the resume argument to the destination place of the `Yield` terminator + let resume_arg = CTX_ARG; + statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + point.resume_arg, + Rvalue::Use(Operand::Move(resume_arg.into())), + ))), + }); + } + + // Then jump to the real target + let block = body.basic_blocks_mut().push(BasicBlockData { + statements, + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::Goto { target }, + }), + is_cleanup: false, + }); + + (point.state, block) + }) + }) + .collect() +} + +#[instrument(level = "debug", skip(tcx), ret)] +pub(crate) fn mir_coroutine_witnesses<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: LocalDefId, +) -> Option<CoroutineLayout<'tcx>> { + let (body, _) = tcx.mir_promoted(def_id); + let body = body.borrow(); + let body = &*body; + + // The first argument is the coroutine type passed by value + let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + + let movable = match *coroutine_ty.kind() { + ty::Coroutine(def_id, _) => tcx.coroutine_movability(def_id) == hir::Movability::Movable, + ty::Error(_) => return None, + _ => span_bug!(body.span, "unexpected coroutine type {}", coroutine_ty), + }; + + // The witness simply contains all locals live across suspend points. + + let always_live_locals = always_storage_live_locals(body); + let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); + + // Extract locals which are live across suspension point into `layout` + // `remap` gives a mapping from local indices onto coroutine struct indices + // `storage_liveness` tells us which locals have live storage at suspension points + let (_, coroutine_layout, _) = compute_layout(liveness_info, body); + + check_suspend_tys(tcx, &coroutine_layout, body); + check_field_tys_sized(tcx, &coroutine_layout, def_id); + + Some(coroutine_layout) +} + +fn check_field_tys_sized<'tcx>( + tcx: TyCtxt<'tcx>, + coroutine_layout: &CoroutineLayout<'tcx>, + def_id: LocalDefId, +) { + // No need to check if unsized_fn_params is disabled, + // since we will error during typeck. + if !tcx.features().unsized_fn_params() { + return; + } + + // FIXME(#132279): @lcnr believes that we may want to support coroutines + // whose `Sized`-ness relies on the hidden types of opaques defined by the + // parent function. In this case we'd have to be able to reveal only these + // opaques here. 
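// A hypothetical illustration of the failure mode this check guards against (the function
// names are invented): with `#![feature(unsized_fn_params)]`, an unsized value that is still
// live after an `.await` would have to become an unsized coroutine field, and the `Sized`
// obligations registered below are what surface that as an error.
//
// ```
// #![feature(unsized_fn_params)]
// async fn f(x: dyn FnOnce()) {        // unsized parameter
//     some_other_async_fn().await;     // `x` is live across the suspension point...
//     x();                             // ...so it would need to be stored in the coroutine
// }
// ```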
+ let infcx = tcx.infer_ctxt().ignoring_regions().build(TypingMode::non_body_analysis()); + let param_env = tcx.param_env(def_id); + + let ocx = ObligationCtxt::new_with_diagnostics(&infcx); + for field_ty in &coroutine_layout.field_tys { + ocx.register_bound( + ObligationCause::new( + field_ty.source_info.span, + def_id, + ObligationCauseCode::SizedCoroutineInterior(def_id), + ), + param_env, + field_ty.ty, + tcx.require_lang_item(hir::LangItem::Sized, field_ty.source_info.span), + ); + } + + let errors = ocx.select_all_or_error(); + debug!(?errors); + if !errors.is_empty() { + infcx.err_ctxt().report_fulfillment_errors(errors); + } +} + +impl<'tcx> crate::MirPass<'tcx> for StateTransform { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let Some(old_yield_ty) = body.yield_ty() else { + // This only applies to coroutines + return; + }; + let old_ret_ty = body.return_ty(); + + assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none()); + + dump_mir(tcx, false, "coroutine_before", &0, body, |_, _| Ok(())); + + // The first argument is the coroutine type passed by value + let coroutine_ty = body.local_decls.raw[1].ty; + let coroutine_kind = body.coroutine_kind().unwrap(); + + // Get the discriminant type and args which typeck computed + let ty::Coroutine(_, args) = coroutine_ty.kind() else { + tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}")); + }; + let discr_ty = args.as_coroutine().discr_ty(tcx); + + let new_ret_ty = match coroutine_kind { + CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => { + // Compute Poll<return_ty> + let poll_did = tcx.require_lang_item(LangItem::Poll, body.span); + let poll_adt_ref = tcx.adt_def(poll_did); + let poll_args = tcx.mk_args(&[old_ret_ty.into()]); + Ty::new_adt(tcx, poll_adt_ref, poll_args) + } + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { + // Compute Option<yield_ty> + let option_did = tcx.require_lang_item(LangItem::Option, body.span); + let option_adt_ref = tcx.adt_def(option_did); + let option_args = tcx.mk_args(&[old_yield_ty.into()]); + Ty::new_adt(tcx, option_adt_ref, option_args) + } + CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => { + // The yield ty is already `Poll<Option<yield_ty>>` + old_yield_ty + } + CoroutineKind::Coroutine(_) => { + // Compute CoroutineState<yield_ty, return_ty> + let state_did = tcx.require_lang_item(LangItem::CoroutineState, body.span); + let state_adt_ref = tcx.adt_def(state_did); + let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]); + Ty::new_adt(tcx, state_adt_ref, state_args) + } + }; + + // We rename RETURN_PLACE which has type mir.return_ty to old_ret_local + // RETURN_PLACE then is a fresh unused local with type ret_ty. + let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx); + + // We need to insert clean drop for unresumed state and perform drop elaboration + // (finally in open_drop_for_tuple) before async drop expansion. + // Async drops, produced by this drop elaboration, will be expanded, + // and corresponding futures kept in layout. + let has_async_drops = matches!( + coroutine_kind, + CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) + ) && has_expandable_async_drops(tcx, body, coroutine_ty); + + // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies. 
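// For orientation, in ordinary stable Rust rather than compiler internals: after this
// replacement the resume argument of an async body plays the role of the `cx` parameter of
// `Future::poll`, which is also why the new return type computed above is `Poll<_>`. Below
// is a minimal hand-written `Future` impl showing the `poll` signature the transformed body
// is ultimately called through:
//
// ```
// use std::future::Future;
// use std::pin::Pin;
// use std::task::{Context, Poll};
//
// struct Ready;
// impl Future for Ready {
//     type Output = ();
//     fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
//         Poll::Ready(())
//     }
// }
// ```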
+ if matches!( + coroutine_kind, + CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) + ) { + let context_mut_ref = transform_async_context(tcx, body); + expand_async_drops(tcx, body, context_mut_ref, coroutine_kind, coroutine_ty); + dump_mir(tcx, false, "coroutine_async_drop_expand", &0, body, |_, _| Ok(())); + } else { + cleanup_async_drops(body); + } + + // We also replace the resume argument and insert an `Assign`. + // This is needed because the resume argument `_2` might be live across a `yield`, in which + // case there is no `Assign` to it that the transform can turn into a store to the coroutine + // state. After the yield the slot in the coroutine state would then be uninitialized. + let resume_local = CTX_ARG; + let resume_ty = body.local_decls[resume_local].ty; + let old_resume_local = replace_local(resume_local, resume_ty, body, tcx); + + // When first entering the coroutine, move the resume argument into its old local + // (which is now a generator interior). + let source_info = SourceInfo::outermost(body.span); + let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements; + stmts.insert( + 0, + Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + old_resume_local.into(), + Rvalue::Use(Operand::Move(resume_local.into())), + ))), + }, + ); + + let always_live_locals = always_storage_live_locals(body); + + let movable = coroutine_kind.movability() == hir::Movability::Movable; + let liveness_info = + locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); + + if tcx.sess.opts.unstable_opts.validate_mir { + let mut vis = EnsureCoroutineFieldAssignmentsNeverAlias { + assigned_local: None, + saved_locals: &liveness_info.saved_locals, + storage_conflicts: &liveness_info.storage_conflicts, + }; + + vis.visit_body(body); + } + + // Extract locals which are live across suspension point into `layout` + // `remap` gives a mapping from local indices onto coroutine struct indices + // `storage_liveness` tells us which locals have live storage at suspension points + let (remap, layout, storage_liveness) = compute_layout(liveness_info, body); + + let can_return = can_return(tcx, body, body.typing_env(tcx)); + + // Run the transformation which converts Places from Local to coroutine struct + // accesses for locals in `remap`. + // It also rewrites `return x` and `yield y` as writing a new coroutine state and returning + // either `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`, + // or `Poll::Ready(x)` and `Poll::Pending` respectively depending on the coroutine kind. + let mut transform = TransformVisitor { + tcx, + coroutine_kind, + remap, + storage_liveness, + always_live_locals, + suspension_points: Vec::new(), + old_ret_local, + discr_ty, + old_ret_ty, + old_yield_ty, + }; + transform.visit_body(body); + + // Update our MIR struct to reflect the changes we've made + body.arg_count = 2; // self, resume arg + body.spread_arg = None; + + // Remove the context argument within generator bodies. + if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) { + transform_gen_context(body); + } + + // The original arguments to the function are no longer arguments, mark them as such. + // Otherwise they'll conflict with our new arguments, which although they don't have + // argument_index set, will get emitted as unnamed arguments. 
+ for var in &mut body.var_debug_info { + var.argument_index = None; + } + + body.coroutine.as_mut().unwrap().yield_ty = None; + body.coroutine.as_mut().unwrap().resume_ty = None; + body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout); + + // FIXME: Drops, produced by insert_clean_drop + elaborate_coroutine_drops, + // are currently sync only. To allow async for them, we need to move those calls + // before expand_async_drops, and fix the related problems. + // + // Insert `drop(coroutine_struct)` which is used to drop upvars for coroutines in + // the unresumed state. + // This is expanded to a drop ladder in `elaborate_coroutine_drops`. + let drop_clean = insert_clean_drop(tcx, body, has_async_drops); + + dump_mir(tcx, false, "coroutine_pre-elab", &0, body, |_, _| Ok(())); + + // Expand `drop(coroutine_struct)` to a drop ladder which destroys upvars. + // If any upvars are moved out of, drop elaboration will handle upvar destruction. + // However we need to also elaborate the code generated by `insert_clean_drop`. + elaborate_coroutine_drops(tcx, body); + + dump_mir(tcx, false, "coroutine_post-transform", &0, body, |_, _| Ok(())); + + let can_unwind = can_unwind(tcx, body); + + // Create a copy of our MIR and use it to create the drop shim for the coroutine + if has_async_drops { + // If coroutine has async drops, generating async drop shim + let mut drop_shim = + create_coroutine_drop_shim_async(tcx, &transform, body, drop_clean, can_unwind); + // Run derefer to fix Derefs that are not in the first place + deref_finder(tcx, &mut drop_shim); + body.coroutine.as_mut().unwrap().coroutine_drop_async = Some(drop_shim); + } else { + // If coroutine has no async drops, generating sync drop shim + let mut drop_shim = + create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean); + // Run derefer to fix Derefs that are not in the first place + deref_finder(tcx, &mut drop_shim); + body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim); + + // For coroutine with sync drop, generating async proxy for `future_drop_poll` call + let mut proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body); + deref_finder(tcx, &mut proxy_shim); + body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim); + } + + // Create the Coroutine::resume / Future::poll function + create_coroutine_resume_function(tcx, transform, body, can_return, can_unwind); + + // Run derefer to fix Derefs that are not in the first place + deref_finder(tcx, body); + } + + fn is_required(&self) -> bool { + true + } +} + +/// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields +/// in the coroutine state machine but whose storage is not marked as conflicting +/// +/// Validation needs to happen immediately *before* `TransformVisitor` is invoked, not after. +/// +/// This condition would arise when the assignment is the last use of `_5` but the initial +/// definition of `_4` if we weren't extra careful to mark all locals used inside a statement as +/// conflicting. Non-conflicting coroutine saved locals may be stored at the same location within +/// the coroutine state machine, which would result in ill-formed MIR: the left-hand and right-hand +/// sides of an assignment may not alias. This caused a miscompilation in [#73137]. 
+/// +/// [#73137]: https://github.com/rust-lang/rust/issues/73137 +struct EnsureCoroutineFieldAssignmentsNeverAlias<'a> { + saved_locals: &'a CoroutineSavedLocals, + storage_conflicts: &'a BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>, + assigned_local: Option<CoroutineSavedLocal>, +} + +impl EnsureCoroutineFieldAssignmentsNeverAlias<'_> { + fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<CoroutineSavedLocal> { + if place.is_indirect() { + return None; + } + + self.saved_locals.get(place.local) + } + + fn check_assigned_place(&mut self, place: Place<'_>, f: impl FnOnce(&mut Self)) { + if let Some(assigned_local) = self.saved_local_for_direct_place(place) { + assert!(self.assigned_local.is_none(), "`check_assigned_place` must not recurse"); + + self.assigned_local = Some(assigned_local); + f(self); + self.assigned_local = None; + } + } +} + +impl<'tcx> Visitor<'tcx> for EnsureCoroutineFieldAssignmentsNeverAlias<'_> { + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) { + let Some(lhs) = self.assigned_local else { + // This visitor only invokes `visit_place` for the right-hand side of an assignment + // and only after setting `self.assigned_local`. However, the default impl of + // `Visitor::super_body` may call `visit_place` with a `NonUseContext` for places + // with debuginfo. Ignore them here. + assert!(!context.is_use()); + return; + }; + + let Some(rhs) = self.saved_local_for_direct_place(*place) else { return }; + + if !self.storage_conflicts.contains(lhs, rhs) { + bug!( + "Assignment between coroutine saved locals whose storage is not \ + marked as conflicting: {:?}: {:?} = {:?}", + location, + lhs, + rhs, + ); + } + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match &statement.kind { + StatementKind::Assign(box (lhs, rhs)) => { + self.check_assigned_place(*lhs, |this| this.visit_rvalue(rhs, location)); + } + + StatementKind::FakeRead(..) + | StatementKind::SetDiscriminant { .. } + | StatementKind::Deinit(..) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Retag(..) + | StatementKind::AscribeUserType(..) + | StatementKind::PlaceMention(..) + | StatementKind::Coverage(..) + | StatementKind::Intrinsic(..) + | StatementKind::ConstEvalCounter + | StatementKind::BackwardIncompatibleDropHint { .. } + | StatementKind::Nop => {} + } + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + // Checking for aliasing in terminators is probably overkill, but until we have actual + // semantics, we should be conservative here. + match &terminator.kind { + TerminatorKind::Call { + func, + args, + destination, + target: Some(_), + unwind: _, + call_source: _, + fn_span: _, + } => { + self.check_assigned_place(*destination, |this| { + this.visit_operand(func, location); + for arg in args { + this.visit_operand(&arg.node, location); + } + }); + } + + TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => { + self.check_assigned_place(*resume_arg, |this| this.visit_operand(value, location)); + } + + // FIXME: Does `asm!` have any aliasing requirements? + TerminatorKind::InlineAsm { .. } => {} + + TerminatorKind::Call { .. } + | TerminatorKind::Goto { .. } + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::TailCall { .. } + | TerminatorKind::Unreachable + | TerminatorKind::Drop { .. 
} + | TerminatorKind::Assert { .. } + | TerminatorKind::CoroutineDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } => {} + } + } +} + +fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &CoroutineLayout<'tcx>, body: &Body<'tcx>) { + let mut linted_tys = FxHashSet::default(); + + for (variant, yield_source_info) in + layout.variant_fields.iter().zip(&layout.variant_source_info) + { + debug!(?variant); + for &local in variant { + let decl = &layout.field_tys[local]; + debug!(?decl); + + if !decl.ignore_for_traits && linted_tys.insert(decl.ty) { + let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { + continue; + }; + + check_must_not_suspend_ty( + tcx, + decl.ty, + hir_id, + SuspendCheckData { + source_span: decl.source_info.span, + yield_span: yield_source_info.span, + plural_len: 1, + ..Default::default() + }, + ); + } + } + } +} + +#[derive(Default)] +struct SuspendCheckData<'a> { + source_span: Span, + yield_span: Span, + descr_pre: &'a str, + descr_post: &'a str, + plural_len: usize, +} + +// Returns whether it emitted a diagnostic or not +// Note that this fn and the proceeding one are based on the code +// for creating must_use diagnostics +// +// Note that this technique was chosen over things like a `Suspend` marker trait +// as it is simpler and has precedent in the compiler +fn check_must_not_suspend_ty<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + hir_id: hir::HirId, + data: SuspendCheckData<'_>, +) -> bool { + if ty.is_unit() { + return false; + } + + let plural_suffix = pluralize!(data.plural_len); + + debug!("Checking must_not_suspend for {}", ty); + + match *ty.kind() { + ty::Adt(_, args) if ty.is_box() => { + let boxed_ty = args.type_at(0); + let allocator_ty = args.type_at(1); + check_must_not_suspend_ty( + tcx, + boxed_ty, + hir_id, + SuspendCheckData { descr_pre: &format!("{}boxed ", data.descr_pre), ..data }, + ) || check_must_not_suspend_ty( + tcx, + allocator_ty, + hir_id, + SuspendCheckData { descr_pre: &format!("{}allocator ", data.descr_pre), ..data }, + ) + } + ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data), + // FIXME: support adding the attribute to TAITs + ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => { + let mut has_emitted = false; + for &(predicate, _) in tcx.explicit_item_bounds(def).skip_binder() { + // We only look at the `DefId`, so it is safe to skip the binder here. 
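// A worked example for the opaque-type case handled below (the function is hypothetical):
// for a return type like `impl Future<Output = ()> + Send`, the explicit item bounds carry a
// trait predicate for `Future` and one for `Send`, and each bound's trait `DefId` is checked
// for `#[must_not_suspend]` in turn.
//
// ```
// fn make() -> impl std::future::Future<Output = ()> + Send { async {} }
// // trait bounds inspected for `make`'s opaque type: `Future`, `Send`
// ```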
+ if let ty::ClauseKind::Trait(ref poly_trait_predicate) = + predicate.kind().skip_binder() + { + let def_id = poly_trait_predicate.trait_ref.def_id; + let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix); + if check_must_not_suspend_def( + tcx, + def_id, + hir_id, + SuspendCheckData { descr_pre, ..data }, + ) { + has_emitted = true; + break; + } + } + } + has_emitted + } + ty::Dynamic(binder, _, _) => { + let mut has_emitted = false; + for predicate in binder.iter() { + if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() { + let def_id = trait_ref.def_id; + let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post); + if check_must_not_suspend_def( + tcx, + def_id, + hir_id, + SuspendCheckData { descr_post, ..data }, + ) { + has_emitted = true; + break; + } + } + } + has_emitted + } + ty::Tuple(fields) => { + let mut has_emitted = false; + for (i, ty) in fields.iter().enumerate() { + let descr_post = &format!(" in tuple element {i}"); + if check_must_not_suspend_ty( + tcx, + ty, + hir_id, + SuspendCheckData { descr_post, ..data }, + ) { + has_emitted = true; + } + } + has_emitted + } + ty::Array(ty, len) => { + let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix); + check_must_not_suspend_ty( + tcx, + ty, + hir_id, + SuspendCheckData { + descr_pre, + // FIXME(must_not_suspend): This is wrong. We should handle printing unevaluated consts. + plural_len: len.try_to_target_usize(tcx).unwrap_or(0) as usize + 1, + ..data + }, + ) + } + // If drop tracking is enabled, we want to look through references, since the referent + // may not be considered live across the await point. + ty::Ref(_region, ty, _mutability) => { + let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix); + check_must_not_suspend_ty(tcx, ty, hir_id, SuspendCheckData { descr_pre, ..data }) + } + _ => false, + } +} + +fn check_must_not_suspend_def( + tcx: TyCtxt<'_>, + def_id: DefId, + hir_id: hir::HirId, + data: SuspendCheckData<'_>, +) -> bool { + if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) { + let reason = attr.value_str().map(|s| errors::MustNotSuspendReason { + span: data.source_span, + reason: s.as_str().to_string(), + }); + tcx.emit_node_span_lint( + rustc_session::lint::builtin::MUST_NOT_SUSPEND, + hir_id, + data.source_span, + errors::MustNotSupend { + tcx, + yield_sp: data.yield_span, + reason, + src_sp: data.source_span, + pre: data.descr_pre, + def_id, + post: data.descr_post, + }, + ); + + true + } else { + false + } +} diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs new file mode 100644 index 00000000000..0a839d91404 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -0,0 +1,368 @@ +//! This pass constructs a second coroutine body sufficient for return from +//! `FnOnce`/`AsyncFnOnce` implementations for coroutine-closures (e.g. async closures). +//! +//! Consider an async closure like: +//! ```rust +//! let x = vec![1, 2, 3]; +//! +//! let closure = async move || { +//! println!("{x:#?}"); +//! }; +//! ``` +//! +//! This desugars to something like: +//! ```rust,ignore (invalid-borrowck) +//! let x = vec![1, 2, 3]; +//! +//! let closure = move || { +//! async { +//! println!("{x:#?}"); +//! } +//! }; +//! ``` +//! +//! Important to note here is that while the outer closure *moves* `x: Vec<i32>` +//! 
into its upvars, the inner `async` coroutine simply captures a ref of `x`. +//! This is the "magic" of async closures -- the futures that they return are +//! allowed to borrow from their parent closure's upvars. +//! +//! However, what happens when we call `closure` with `AsyncFnOnce` (or `FnOnce`, +//! since all async closures implement that too)? Well, recall the signature: +//! ``` +//! use std::future::Future; +//! pub trait AsyncFnOnce<Args> +//! { +//! type CallOnceFuture: Future<Output = Self::Output>; +//! type Output; +//! fn async_call_once( +//! self, +//! args: Args +//! ) -> Self::CallOnceFuture; +//! } +//! ``` +//! +//! This signature *consumes* the async closure (`self`) and returns a `CallOnceFuture`. +//! How do we deal with the fact that the coroutine is supposed to take a reference +//! to the captured `x` from the parent closure, when that parent closure has been +//! destroyed? +//! +//! This is the second piece of magic of async closures. We can simply create a +//! *second* `async` coroutine body where that `x` that was previously captured +//! by reference is now captured by value. This means that we consume the outer +//! closure and return a new coroutine that will hold onto all of these captures, +//! and drop them when it is finished (i.e. after it has been `.await`ed). +//! +//! We do this with the analysis below, which detects the captures that come from +//! borrowing from the outer closure, and we simply peel off a `deref` projection +//! from them. This second body is stored alongside the first body, and optimized +//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`, +//! we use this "by-move" body instead. +//! +//! ## How does this work? +//! +//! This pass essentially remaps the body of the (child) closure of the coroutine-closure +//! to take the set of upvars of the parent closure by value. This at least requires +//! changing a by-ref upvar to be by-value in the case that the outer coroutine-closure +//! captures something by value; however, it may also require renumbering field indices +//! in case precise captures (edition 2021 closure capture rules) caused the inner coroutine +//! to split one field capture into two. + +use rustc_abi::{FieldIdx, VariantIdx}; +use rustc_data_structures::steal::Steal; +use rustc_data_structures::unord::UnordMap; +use rustc_hir as hir; +use rustc_hir::def::DefKind; +use rustc_hir::def_id::{DefId, LocalDefId}; +use rustc_hir::definitions::DisambiguatorState; +use rustc_middle::bug; +use rustc_middle::hir::place::{Projection, ProjectionKind}; +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::{self, dump_mir}; +use rustc_middle::ty::{self, InstanceKind, Ty, TyCtxt, TypeVisitableExt}; + +pub(crate) fn coroutine_by_move_body_def_id<'tcx>( + tcx: TyCtxt<'tcx>, + coroutine_def_id: LocalDefId, +) -> DefId { + let body = tcx.mir_built(coroutine_def_id).borrow(); + + // If the typeck results are tainted, no need to make a by-ref body. + if body.tainted_by_errors.is_some() { + return coroutine_def_id.to_def_id(); + } + + let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) = + tcx.coroutine_kind(coroutine_def_id) + else { + bug!("should only be invoked on coroutine-closures"); + }; + + // Also, let's skip processing any bodies with errors, since there's no guarantee + // the MIR body will be constructed well. 
+ let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + + let ty::Coroutine(_, args) = *coroutine_ty.kind() else { + bug!("tried to create by-move body of non-coroutine receiver"); + }; + let args = args.as_coroutine(); + + let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap(); + + let parent_def_id = tcx.local_parent(coroutine_def_id); + let ty::CoroutineClosure(_, parent_args) = + *tcx.type_of(parent_def_id).instantiate_identity().kind() + else { + bug!("coroutine's parent was not a coroutine-closure"); + }; + if parent_args.references_error() { + return coroutine_def_id.to_def_id(); + } + + let parent_closure_args = parent_args.as_coroutine_closure(); + let num_args = parent_closure_args + .coroutine_closure_sig() + .skip_binder() + .tupled_inputs_ty + .tuple_fields() + .len(); + + let field_remapping: UnordMap<_, _> = ty::analyze_coroutine_closure_captures( + tcx.closure_captures(parent_def_id).iter().copied(), + tcx.closure_captures(coroutine_def_id).iter().skip(num_args).copied(), + |(parent_field_idx, parent_capture), (child_field_idx, child_capture)| { + // Store this set of additional projections (fields and derefs). + // We need to re-apply them later. + let mut child_precise_captures = + child_capture.place.projections[parent_capture.place.projections.len()..].to_vec(); + + // If the parent capture is by-ref, then we need to apply an additional + // deref before applying any further projections to this place. + if parent_capture.is_by_ref() { + child_precise_captures.insert( + 0, + Projection { ty: parent_capture.place.ty(), kind: ProjectionKind::Deref }, + ); + } + // If the child capture is by-ref, then we need to apply a "ref" + // projection (i.e. `&`) at the end. But wait! We don't have that + // as a projection kind. So instead, we can apply its dual and + // *peel* a deref off of the place when it shows up in the MIR body. + // Luckily, by construction this is always possible. + let peel_deref = if child_capture.is_by_ref() { + assert!( + parent_capture.is_by_ref() || coroutine_kind != ty::ClosureKind::FnOnce, + "`FnOnce` coroutine-closures return coroutines that capture from \ + their body; it will always result in a borrowck error!" + ); + true + } else { + false + }; + + // Regarding the behavior above, you may think that it's redundant to both + // insert a deref and then peel a deref if the parent and child are both + // captured by-ref. This would be correct, except for the case where we have + // precise capturing projections, since the inserted deref is to the *beginning* + // and the peeled deref is at the *end*. I cannot seem to actually find a + // case where this happens, though, but let's keep this code flexible. + + // Finally, store the type of the parent's captured place. We need + // this when building the field projection in the MIR body later on. 
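// A concrete (made-up) example of the type computed just below: if the parent
// coroutine-closure captures `s: String` by mutable reference, the recorded field type is
// `&mut String` (with its region erased), while a by-value or by-use capture keeps plain
// `String`.
//
// ```
// capture of `s: String` by value or by use -> parent field type `String`
// capture of `s: String` by `&mut`          -> parent field type `&mut String`
// ```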
+ let mut parent_capture_ty = parent_capture.place.ty(); + parent_capture_ty = match parent_capture.info.capture_kind { + ty::UpvarCapture::ByValue | ty::UpvarCapture::ByUse => parent_capture_ty, + ty::UpvarCapture::ByRef(kind) => Ty::new_ref( + tcx, + tcx.lifetimes.re_erased, + parent_capture_ty, + kind.to_mutbl_lossy(), + ), + }; + + Some(( + FieldIdx::from_usize(child_field_idx + num_args), + ( + FieldIdx::from_usize(parent_field_idx + num_args), + parent_capture_ty, + peel_deref, + child_precise_captures, + ), + )) + }, + ) + .flatten() + .collect(); + + if coroutine_kind == ty::ClosureKind::FnOnce { + assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len()); + // The by-move body is just the body :) + return coroutine_def_id.to_def_id(); + } + + let by_move_coroutine_ty = tcx + .instantiate_bound_regions_with_erased(parent_closure_args.coroutine_closure_sig()) + .to_coroutine_given_kind_and_upvars( + tcx, + parent_closure_args.parent_args(), + coroutine_def_id.to_def_id(), + ty::ClosureKind::FnOnce, + tcx.lifetimes.re_erased, + parent_closure_args.tupled_upvars_ty(), + parent_closure_args.coroutine_captures_by_ref_ty(), + ); + + let mut by_move_body = body.clone(); + MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body); + + // This path is unique since we're in a query so we'll only be called once with `parent_def_id` + // and this is the only location creating `SyntheticCoroutineBody`. + let body_def = tcx.create_def( + parent_def_id, + None, + DefKind::SyntheticCoroutineBody, + None, + &mut DisambiguatorState::new(), + ); + by_move_body.source = + mir::MirSource::from_instance(InstanceKind::Item(body_def.def_id().to_def_id())); + dump_mir(tcx, false, "built", &"after", &by_move_body, |_, _| Ok(())); + + // Feed HIR because we try to access this body's attrs in the inliner. + body_def.feed_hir(); + // Inherited from the by-ref coroutine. + body_def.codegen_fn_attrs(tcx.codegen_fn_attrs(coroutine_def_id).clone()); + body_def.coverage_attr_on(tcx.coverage_attr_on(coroutine_def_id)); + body_def.constness(tcx.constness(coroutine_def_id)); + body_def.coroutine_kind(tcx.coroutine_kind(coroutine_def_id)); + body_def.def_ident_span(tcx.def_ident_span(coroutine_def_id)); + body_def.def_span(tcx.def_span(coroutine_def_id)); + body_def.explicit_predicates_of(tcx.explicit_predicates_of(coroutine_def_id)); + body_def.generics_of(tcx.generics_of(coroutine_def_id).clone()); + body_def.param_env(tcx.param_env(coroutine_def_id)); + body_def.predicates_of(tcx.predicates_of(coroutine_def_id)); + + // The type of the coroutine is the `by_move_coroutine_ty`. + body_def.type_of(ty::EarlyBinder::bind(by_move_coroutine_ty)); + + body_def.mir_built(tcx.arena.alloc(Steal::new(by_move_body))); + + body_def.def_id().to_def_id() +} + +struct MakeByMoveBody<'tcx> { + tcx: TyCtxt<'tcx>, + field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, Vec<Projection<'tcx>>)>, + by_move_coroutine_ty: Ty<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place( + &mut self, + place: &mut mir::Place<'tcx>, + context: mir::visit::PlaceContext, + location: mir::Location, + ) { + // Initializing an upvar local always starts with `CAPTURE_STRUCT_LOCAL` and a + // field projection. If this is in `field_remapping`, then it must not be an + // arg from calling the closure, but instead an upvar. 
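// A schematic before/after for the rewrite performed below (field indices and types are
// made up): the by-ref child body read its upvar through a reference-typed field, while the
// by-move body owns the capture, so the leading deref is peeled and the field projection is
// retargeted to the parent's field index and type.
//
// ```
// by-ref child body : _2 = ((*(_1.0: &Foo)).1: u32);
// by-move body      : _2 = ((_1.0: Foo).1: u32);
// ```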
+ if place.local == ty::CAPTURE_STRUCT_LOCAL + && let Some((&mir::ProjectionElem::Field(idx, _), projection)) = + place.projection.split_first() + && let Some(&(remapped_idx, remapped_ty, peel_deref, ref bridging_projections)) = + self.field_remapping.get(&idx) + { + // As noted before, if the parent closure captures a field by value, and + // the child captures a field by ref, then for the by-move body we're + // generating, we also are taking that field by value. Peel off a deref, + // since a layer of ref'ing has now become redundant. + let final_projections = if peel_deref { + let Some((mir::ProjectionElem::Deref, projection)) = projection.split_first() + else { + bug!( + "There should be at least a single deref for an upvar local initialization, found {projection:#?}" + ); + }; + // There may be more derefs, since we may also implicitly reborrow + // a captured mut pointer. + projection + } else { + projection + }; + + // These projections are applied in order to "bridge" the local that we are + // currently transforming *from* the old upvar that the by-ref coroutine used + // to capture *to* the upvar of the parent coroutine-closure. For example, if + // the parent captures `&s` but the child captures `&(s.field)`, then we will + // apply a field projection. + let bridging_projections = bridging_projections.iter().map(|elem| match elem.kind { + ProjectionKind::Deref => mir::ProjectionElem::Deref, + ProjectionKind::Field(idx, VariantIdx::ZERO) => { + mir::ProjectionElem::Field(idx, elem.ty) + } + _ => unreachable!("precise captures only through fields and derefs"), + }); + + // We start out with an adjusted field index (and ty), representing the + // upvar that we get from our parent closure. We apply any of the additional + // projections to make sure that to the rest of the body of the closure, the + // place looks the same, and then apply that final deref if necessary. + *place = mir::Place { + local: place.local, + projection: self.tcx.mk_place_elems_from_iter( + [mir::ProjectionElem::Field(remapped_idx, remapped_ty)] + .into_iter() + .chain(bridging_projections) + .chain(final_projections.iter().copied()), + ), + }; + } + self.super_place(place, context, location); + } + + fn visit_statement(&mut self, statement: &mut mir::Statement<'tcx>, location: mir::Location) { + // Remove fake borrows of closure captures if that capture has been + // replaced with a by-move version of that capture. + // + // For example, imagine we capture `Foo` in the parent and `&Foo` + // in the child. We will emit two fake borrows like: + // + // ``` + // _2 = &fake shallow (*(_1.0: &Foo)); + // _3 = &fake shallow (_1.0: &Foo); + // ``` + // + // However, since this transform is responsible for replacing + // `_1.0: &Foo` with `_1.0: Foo`, that makes the second fake borrow + // obsolete, and we should replace it with a nop. + // + // As a side-note, we don't actually even care about fake borrows + // here at all since they're fully a MIR borrowck artifact, and we + // don't need to borrowck by-move MIR bodies. 
But it's best to preserve + // as much as we can between these two bodies :) + if let mir::StatementKind::Assign(box (_, rvalue)) = &statement.kind + && let mir::Rvalue::Ref(_, mir::BorrowKind::Fake(mir::FakeBorrowKind::Shallow), place) = + rvalue + && let mir::PlaceRef { + local: ty::CAPTURE_STRUCT_LOCAL, + projection: [mir::ProjectionElem::Field(idx, _)], + } = place.as_ref() + && let Some(&(_, _, true, _)) = self.field_remapping.get(&idx) + { + statement.kind = mir::StatementKind::Nop; + } + + self.super_statement(statement, location); + } + + fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) { + // Replace the type of the self arg. + if local == ty::CAPTURE_STRUCT_LOCAL { + local_decl.ty = self.by_move_coroutine_ty; + } + self.super_local_decl(local, local_decl); + } +} diff --git a/compiler/rustc_mir_transform/src/coroutine/drop.rs b/compiler/rustc_mir_transform/src/coroutine/drop.rs new file mode 100644 index 00000000000..dc68629ec0d --- /dev/null +++ b/compiler/rustc_mir_transform/src/coroutine/drop.rs @@ -0,0 +1,759 @@ +//! Drops and async drops related logic for coroutine transformation pass + +use super::*; + +// Fix return Poll<Rv>::Pending statement into Poll<()>::Pending for async drop function +struct FixReturnPendingVisitor<'tcx> { + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for FixReturnPendingVisitor<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_assign( + &mut self, + place: &mut Place<'tcx>, + rvalue: &mut Rvalue<'tcx>, + _location: Location, + ) { + if place.local != RETURN_PLACE { + return; + } + + // Converting `_0 = Poll::<Rv>::Pending` to `_0 = Poll::<()>::Pending` + if let Rvalue::Aggregate(kind, _) = rvalue { + if let AggregateKind::Adt(_, _, ref mut args, _, _) = **kind { + *args = self.tcx.mk_args(&[self.tcx.types.unit.into()]); + } + } + } +} + +// rv = call fut.poll() +fn build_poll_call<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + poll_unit_place: &Place<'tcx>, + switch_block: BasicBlock, + fut_pin_place: &Place<'tcx>, + fut_ty: Ty<'tcx>, + context_ref_place: &Place<'tcx>, + unwind: UnwindAction, +) -> BasicBlock { + let poll_fn = tcx.require_lang_item(LangItem::FuturePoll, DUMMY_SP); + let poll_fn = Ty::new_fn_def(tcx, poll_fn, [fut_ty]); + let poll_fn = Operand::Constant(Box::new(ConstOperand { + span: DUMMY_SP, + user_ty: None, + const_: Const::zero_sized(poll_fn), + })); + let call = TerminatorKind::Call { + func: poll_fn.clone(), + args: [ + dummy_spanned(Operand::Move(*fut_pin_place)), + dummy_spanned(Operand::Move(*context_ref_place)), + ] + .into(), + destination: *poll_unit_place, + target: Some(switch_block), + unwind, + call_source: CallSource::Misc, + fn_span: DUMMY_SP, + }; + insert_term_block(body, call) +} + +// pin_fut = Pin::new_unchecked(&mut fut) +fn build_pin_fut<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + fut_place: Place<'tcx>, + unwind: UnwindAction, +) -> (BasicBlock, Place<'tcx>) { + let span = body.span; + let source_info = SourceInfo::outermost(span); + let fut_ty = fut_place.ty(&body.local_decls, tcx).ty; + let fut_ref_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, fut_ty); + let fut_ref_place = Place::from(body.local_decls.push(LocalDecl::new(fut_ref_ty, span))); + let pin_fut_new_unchecked_fn = + Ty::new_fn_def(tcx, tcx.require_lang_item(LangItem::PinNewUnchecked, span), [fut_ref_ty]); + let fut_pin_ty = pin_fut_new_unchecked_fn.fn_sig(tcx).output().skip_binder(); + let fut_pin_place = 
Place::from(body.local_decls.push(LocalDecl::new(fut_pin_ty, span))); + let pin_fut_new_unchecked_fn = Operand::Constant(Box::new(ConstOperand { + span, + user_ty: None, + const_: Const::zero_sized(pin_fut_new_unchecked_fn), + })); + + let storage_live = + Statement { source_info, kind: StatementKind::StorageLive(fut_pin_place.local) }; + + let fut_ref_assign = Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + fut_ref_place, + Rvalue::Ref( + tcx.lifetimes.re_erased, + BorrowKind::Mut { kind: MutBorrowKind::Default }, + fut_place, + ), + ))), + }; + + // call Pin<FutTy>::new_unchecked(&mut fut) + let pin_fut_bb = body.basic_blocks_mut().push(BasicBlockData { + statements: [storage_live, fut_ref_assign].to_vec(), + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::Call { + func: pin_fut_new_unchecked_fn, + args: [dummy_spanned(Operand::Move(fut_ref_place))].into(), + destination: fut_pin_place, + target: None, // will be fixed later + unwind, + call_source: CallSource::Misc, + fn_span: span, + }, + }), + is_cleanup: false, + }); + (pin_fut_bb, fut_pin_place) +} + +// Build Poll switch for async drop +// match rv { +// Ready() => ready_block +// Pending => yield_block +//} +fn build_poll_switch<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + poll_enum: Ty<'tcx>, + poll_unit_place: &Place<'tcx>, + fut_pin_place: &Place<'tcx>, + ready_block: BasicBlock, + yield_block: BasicBlock, +) -> BasicBlock { + let poll_enum_adt = poll_enum.ty_adt_def().unwrap(); + + let Discr { val: poll_ready_discr, ty: poll_discr_ty } = poll_enum + .discriminant_for_variant( + tcx, + poll_enum_adt + .variant_index_with_id(tcx.require_lang_item(LangItem::PollReady, DUMMY_SP)), + ) + .unwrap(); + let poll_pending_discr = poll_enum + .discriminant_for_variant( + tcx, + poll_enum_adt + .variant_index_with_id(tcx.require_lang_item(LangItem::PollPending, DUMMY_SP)), + ) + .unwrap() + .val; + let source_info = SourceInfo::outermost(body.span); + let poll_discr_place = + Place::from(body.local_decls.push(LocalDecl::new(poll_discr_ty, source_info.span))); + let discr_assign = Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + poll_discr_place, + Rvalue::Discriminant(*poll_unit_place), + ))), + }; + let storage_dead = + Statement { source_info, kind: StatementKind::StorageDead(fut_pin_place.local) }; + let unreachable_block = insert_term_block(body, TerminatorKind::Unreachable); + body.basic_blocks_mut().push(BasicBlockData { + statements: [storage_dead, discr_assign].to_vec(), + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::SwitchInt { + discr: Operand::Move(poll_discr_place), + targets: SwitchTargets::new( + [(poll_ready_discr, ready_block), (poll_pending_discr, yield_block)] + .into_iter(), + unreachable_block, + ), + }, + }), + is_cleanup: false, + }) +} + +// Gather blocks, reachable through 'drop' targets of Yield and Drop terminators (chained) +fn gather_dropline_blocks<'tcx>(body: &mut Body<'tcx>) -> DenseBitSet<BasicBlock> { + let mut dropline: DenseBitSet<BasicBlock> = DenseBitSet::new_empty(body.basic_blocks.len()); + for (bb, data) in traversal::reverse_postorder(body) { + if dropline.contains(bb) { + data.terminator().successors().for_each(|v| { + dropline.insert(v); + }); + } else { + match data.terminator().kind { + TerminatorKind::Yield { drop: Some(v), .. } => { + dropline.insert(v); + } + TerminatorKind::Drop { drop: Some(v), .. 
} => { + dropline.insert(v); + } + _ => (), + } + } + } + dropline +} + +/// Cleanup all async drops (reset to sync) +pub(super) fn cleanup_async_drops<'tcx>(body: &mut Body<'tcx>) { + for block in body.basic_blocks_mut() { + if let TerminatorKind::Drop { + place: _, + target: _, + unwind: _, + replace: _, + ref mut drop, + ref mut async_fut, + } = block.terminator_mut().kind + { + if drop.is_some() || async_fut.is_some() { + *drop = None; + *async_fut = None; + } + } + } +} + +pub(super) fn has_expandable_async_drops<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + coroutine_ty: Ty<'tcx>, +) -> bool { + for bb in START_BLOCK..body.basic_blocks.next_index() { + // Drops in unwind path (cleanup blocks) are not expanded to async drops, only sync drops in unwind path + if body[bb].is_cleanup { + continue; + } + let TerminatorKind::Drop { place, target: _, unwind: _, replace: _, drop: _, async_fut } = + body[bb].terminator().kind + else { + continue; + }; + let place_ty = place.ty(&body.local_decls, tcx).ty; + if place_ty == coroutine_ty { + continue; + } + if async_fut.is_none() { + continue; + } + return true; + } + return false; +} + +/// Expand Drop terminator for async drops into mainline poll-switch and dropline poll-switch +pub(super) fn expand_async_drops<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + context_mut_ref: Ty<'tcx>, + coroutine_kind: hir::CoroutineKind, + coroutine_ty: Ty<'tcx>, +) { + let dropline = gather_dropline_blocks(body); + // Clean drop and async_fut fields if potentially async drop is not expanded (stays sync) + let remove_asyncness = |block: &mut BasicBlockData<'tcx>| { + if let TerminatorKind::Drop { + place: _, + target: _, + unwind: _, + replace: _, + ref mut drop, + ref mut async_fut, + } = block.terminator_mut().kind + { + *drop = None; + *async_fut = None; + } + }; + for bb in START_BLOCK..body.basic_blocks.next_index() { + // Drops in unwind path (cleanup blocks) are not expanded to async drops, only sync drops in unwind path + if body[bb].is_cleanup { + remove_asyncness(&mut body[bb]); + continue; + } + let TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut } = + body[bb].terminator().kind + else { + continue; + }; + + let place_ty = place.ty(&body.local_decls, tcx).ty; + if place_ty == coroutine_ty { + remove_asyncness(&mut body[bb]); + continue; + } + + let Some(fut_local) = async_fut else { + remove_asyncness(&mut body[bb]); + continue; + }; + + let is_dropline_bb = dropline.contains(bb); + + if !is_dropline_bb && drop.is_none() { + remove_asyncness(&mut body[bb]); + continue; + } + + let fut_place = Place::from(fut_local); + let fut_ty = fut_place.ty(&body.local_decls, tcx).ty; + + // poll-code: + // state_call_drop: + // #bb_pin: fut_pin = Pin<FutT>::new_unchecked(&mut fut) + // #bb_call: rv = call fut.poll() (or future_drop_poll(fut) for internal future drops) + // #bb_check: match (rv) + // pending => return rv (yield) + // ready => *continue_bb|drop_bb* + + let source_info = body[bb].terminator.as_ref().unwrap().source_info; + + // Compute Poll<> (aka Poll with void return) + let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, source_info.span)); + let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()])); + let poll_decl = LocalDecl::new(poll_enum, source_info.span); + let poll_unit_place = Place::from(body.local_decls.push(poll_decl)); + + // First state-loop yield for mainline + let context_ref_place = + 
Place::from(body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span))); + let arg = Rvalue::Use(Operand::Move(Place::from(CTX_ARG))); + body[bb].statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((context_ref_place, arg))), + }); + let yield_block = insert_term_block(body, TerminatorKind::Unreachable); // `kind` replaced later to yield + let (pin_bb, fut_pin_place) = + build_pin_fut(tcx, body, fut_place.clone(), UnwindAction::Continue); + let switch_block = build_poll_switch( + tcx, + body, + poll_enum, + &poll_unit_place, + &fut_pin_place, + target, + yield_block, + ); + let call_bb = build_poll_call( + tcx, + body, + &poll_unit_place, + switch_block, + &fut_pin_place, + fut_ty, + &context_ref_place, + unwind, + ); + + // Second state-loop yield for transition to dropline (when coroutine async drop started) + let mut dropline_transition_bb: Option<BasicBlock> = None; + let mut dropline_yield_bb: Option<BasicBlock> = None; + let mut dropline_context_ref: Option<Place<'_>> = None; + let mut dropline_call_bb: Option<BasicBlock> = None; + if !is_dropline_bb { + let context_ref_place2: Place<'_> = Place::from( + body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)), + ); + let drop_yield_block = insert_term_block(body, TerminatorKind::Unreachable); // `kind` replaced later to yield + let (pin_bb2, fut_pin_place2) = + build_pin_fut(tcx, body, fut_place, UnwindAction::Continue); + let drop_switch_block = build_poll_switch( + tcx, + body, + poll_enum, + &poll_unit_place, + &fut_pin_place2, + drop.unwrap(), + drop_yield_block, + ); + let drop_call_bb = build_poll_call( + tcx, + body, + &poll_unit_place, + drop_switch_block, + &fut_pin_place2, + fut_ty, + &context_ref_place2, + unwind, + ); + dropline_transition_bb = Some(pin_bb2); + dropline_yield_bb = Some(drop_yield_block); + dropline_context_ref = Some(context_ref_place2); + dropline_call_bb = Some(drop_call_bb); + } + + let value = + if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)) + { + // For AsyncGen we need `yield Poll<OptRet>::Pending` + let full_yield_ty = body.yield_ty().unwrap(); + let ty::Adt(_poll_adt, args) = *full_yield_ty.kind() else { bug!() }; + let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() }; + let yield_ty = args.type_at(0); + Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + const_: Const::Unevaluated( + UnevaluatedConst::new( + tcx.require_lang_item(LangItem::AsyncGenPending, source_info.span), + tcx.mk_args(&[yield_ty.into()]), + ), + full_yield_ty, + ), + user_ty: None, + })) + } else { + // value needed only for return-yields or gen-coroutines, so just const here + Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::from_bool(tcx, false), + })) + }; + + use rustc_middle::mir::AssertKind::ResumedAfterDrop; + let panic_bb = insert_panic_block(tcx, body, ResumedAfterDrop(coroutine_kind)); + + if is_dropline_bb { + body[yield_block].terminator_mut().kind = TerminatorKind::Yield { + value: value.clone(), + resume: panic_bb, + resume_arg: context_ref_place, + drop: Some(pin_bb), + }; + } else { + body[yield_block].terminator_mut().kind = TerminatorKind::Yield { + value: value.clone(), + resume: pin_bb, + resume_arg: context_ref_place, + drop: dropline_transition_bb, + }; + body[dropline_yield_bb.unwrap()].terminator_mut().kind = TerminatorKind::Yield { + value, + resume: panic_bb, + resume_arg: dropline_context_ref.unwrap(), + drop: 
dropline_transition_bb, + }; + } + + if let TerminatorKind::Call { ref mut target, .. } = body[pin_bb].terminator_mut().kind { + *target = Some(call_bb); + } else { + bug!() + } + if !is_dropline_bb { + if let TerminatorKind::Call { ref mut target, .. } = + body[dropline_transition_bb.unwrap()].terminator_mut().kind + { + *target = dropline_call_bb; + } else { + bug!() + } + } + + body[bb].terminator_mut().kind = TerminatorKind::Goto { target: pin_bb }; + } +} + +pub(super) fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + use crate::elaborate_drop::{Unwind, elaborate_drop}; + use crate::patch::MirPatch; + use crate::shim::DropShimElaborator; + + // Note that `elaborate_drops` only drops the upvars of a coroutine, and + // this is ok because `open_drop` can only be reached within that own + // coroutine's resume function. + let typing_env = body.typing_env(tcx); + + let mut elaborator = DropShimElaborator { + body, + patch: MirPatch::new(body), + tcx, + typing_env, + produce_async_drops: false, + }; + + for (block, block_data) in body.basic_blocks.iter_enumerated() { + let (target, unwind, source_info, dropline) = match block_data.terminator() { + Terminator { + source_info, + kind: TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut: _ }, + } => { + if let Some(local) = place.as_local() + && local == SELF_ARG + { + (target, unwind, source_info, *drop) + } else { + continue; + } + } + _ => continue, + }; + let unwind = if block_data.is_cleanup { + Unwind::InCleanup + } else { + Unwind::To(match *unwind { + UnwindAction::Cleanup(tgt) => tgt, + UnwindAction::Continue => elaborator.patch.resume_block(), + UnwindAction::Unreachable => elaborator.patch.unreachable_cleanup_block(), + UnwindAction::Terminate(reason) => elaborator.patch.terminate_block(reason), + }) + }; + elaborate_drop( + &mut elaborator, + *source_info, + Place::from(SELF_ARG), + (), + *target, + unwind, + block, + dropline, + ); + } + elaborator.patch.apply(body); +} + +pub(super) fn insert_clean_drop<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + has_async_drops: bool, +) -> BasicBlock { + let source_info = SourceInfo::outermost(body.span); + let return_block = if has_async_drops { + insert_poll_ready_block(tcx, body) + } else { + insert_term_block(body, TerminatorKind::Return) + }; + + // FIXME: When move insert_clean_drop + elaborate_coroutine_drops before async drops expand, + // also set dropline here: + // let dropline = if has_async_drops { Some(return_block) } else { None }; + let dropline = None; + + let term = TerminatorKind::Drop { + place: Place::from(SELF_ARG), + target: return_block, + unwind: UnwindAction::Continue, + replace: false, + drop: dropline, + async_fut: None, + }; + + // Create a block to destroy an unresumed coroutines. This can only destroy upvars. + body.basic_blocks_mut().push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { source_info, kind: term }), + is_cleanup: false, + }) +} + +pub(super) fn create_coroutine_drop_shim<'tcx>( + tcx: TyCtxt<'tcx>, + transform: &TransformVisitor<'tcx>, + coroutine_ty: Ty<'tcx>, + body: &Body<'tcx>, + drop_clean: BasicBlock, +) -> Body<'tcx> { + let mut body = body.clone(); + // Take the coroutine info out of the body, since the drop shim is + // not a coroutine body itself; it just has its drop built out of it. + let _ = body.coroutine.take(); + // Make sure the resume argument is not included here, since we're + // building a body for `drop_in_place`. 
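    // Illustrative note (added, not part of the original source): `drop_in_place`
    // has the signature `unsafe fn drop_in_place<T: ?Sized>(to_drop: *mut T)`,
    // i.e. it takes exactly one pointer argument. The shim therefore keeps only
    // the coroutine pointer and discards the resume/context argument.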
+ body.arg_count = 1; + + let source_info = SourceInfo::outermost(body.span); + + let mut cases = create_cases(&mut body, transform, Operation::Drop); + + cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean)); + + // The returned state and the poisoned state fall through to the default + // case which is just to return + + let default_block = insert_term_block(&mut body, TerminatorKind::Return); + insert_switch(&mut body, cases, transform, default_block); + + for block in body.basic_blocks_mut() { + let kind = &mut block.terminator_mut().kind; + if let TerminatorKind::CoroutineDrop = *kind { + *kind = TerminatorKind::Return; + } + } + + // Replace the return variable + body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(tcx.types.unit, source_info); + + make_coroutine_state_argument_indirect(tcx, &mut body); + + // Change the coroutine argument from &mut to *mut + body.local_decls[SELF_ARG] = + LocalDecl::with_source_info(Ty::new_mut_ptr(tcx, coroutine_ty), source_info); + + // Make sure we remove dead blocks to remove + // unrelated code from the resume part of the function + simplify::remove_dead_blocks(&mut body); + + // Update the body's def to become the drop glue. + let coroutine_instance = body.source.instance; + let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, body.span); + let drop_instance = InstanceKind::DropGlue(drop_in_place, Some(coroutine_ty)); + + // Temporary change MirSource to coroutine's instance so that dump_mir produces more sensible + // filename. + body.source.instance = coroutine_instance; + dump_mir(tcx, false, "coroutine_drop", &0, &body, |_, _| Ok(())); + body.source.instance = drop_instance; + + // Creating a coroutine drop shim happens on `Analysis(PostCleanup) -> Runtime(Initial)` + // but the pass manager doesn't update the phase of the coroutine drop shim. Update the + // phase of the drop shim so that later on when we run the pass manager on the shim, in + // the `mir_shims` query, we don't ICE on the intra-pass validation before we've updated + // the phase of the body from analysis. + body.phase = MirPhase::Runtime(RuntimePhase::Initial); + + body +} + +// Create async drop shim function to drop coroutine itself +pub(super) fn create_coroutine_drop_shim_async<'tcx>( + tcx: TyCtxt<'tcx>, + transform: &TransformVisitor<'tcx>, + body: &Body<'tcx>, + drop_clean: BasicBlock, + can_unwind: bool, +) -> Body<'tcx> { + let mut body = body.clone(); + // Take the coroutine info out of the body, since the drop shim is + // not a coroutine body itself; it just has its drop built out of it. + let _ = body.coroutine.take(); + + FixReturnPendingVisitor { tcx }.visit_body(&mut body); + + // Poison the coroutine when it unwinds + if can_unwind { + generate_poison_block_and_redirect_unwinds_there(transform, &mut body); + } + + let source_info = SourceInfo::outermost(body.span); + + let mut cases = create_cases(&mut body, transform, Operation::Drop); + + cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean)); + + use rustc_middle::mir::AssertKind::ResumedAfterPanic; + // Panic when resumed on the returned or poisoned state + if can_unwind { + cases.insert( + 1, + ( + CoroutineArgs::POISONED, + insert_panic_block(tcx, &mut body, ResumedAfterPanic(transform.coroutine_kind)), + ), + ); + } + + // RETURNED state also goes to default_block with `return Ready<()>`. + // For fully-polled coroutine, async drop has nothing to do. 
+ let default_block = insert_poll_ready_block(tcx, &mut body); + insert_switch(&mut body, cases, transform, default_block); + + for block in body.basic_blocks_mut() { + let kind = &mut block.terminator_mut().kind; + if let TerminatorKind::CoroutineDrop = *kind { + *kind = TerminatorKind::Return; + block.statements.push(return_poll_ready_assign(tcx, source_info)); + } + } + + // Replace the return variable: Poll<RetT> to Poll<()> + let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, body.span)); + let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()])); + body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info); + + make_coroutine_state_argument_indirect(tcx, &mut body); + + match transform.coroutine_kind { + // Iterator::next doesn't accept a pinned argument, + // unlike for all other coroutine kinds. + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {} + _ => { + make_coroutine_state_argument_pinned(tcx, &mut body); + } + } + + // Make sure we remove dead blocks to remove + // unrelated code from the resume part of the function + simplify::remove_dead_blocks(&mut body); + + pm::run_passes_no_validate( + tcx, + &mut body, + &[&abort_unwinding_calls::AbortUnwindingCalls], + None, + ); + + dump_mir(tcx, false, "coroutine_drop_async", &0, &body, |_, _| Ok(())); + + body +} + +// Create async drop shim proxy function for future_drop_poll +// It is just { call coroutine_drop(); return Poll::Ready(); } +pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, +) -> Body<'tcx> { + let mut body = body.clone(); + // Take the coroutine info out of the body, since the drop shim is + // not a coroutine body itself; it just has its drop built out of it. 
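    // Rough shape of the proxy shim assembled below (illustrative sketch, not
    // actual dump_mir output):
    //
    //     bb0: drop(_1) -> [return: bb1, unwind: continue]
    //     bb1: _0 = Poll::<()>::Ready(()); return
    //
    // i.e. it forwards to the coroutine's own drop logic and then reports
    // `Poll::Ready(())` to the caller of `future_drop_poll`.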
+ let _ = body.coroutine.take(); + let basic_blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>> = IndexVec::new(); + body.basic_blocks = BasicBlocks::new(basic_blocks); + body.var_debug_info.clear(); + + // Keeping return value and args + body.local_decls.truncate(1 + body.arg_count); + + let source_info = SourceInfo::outermost(body.span); + + // Replace the return variable: Poll<RetT> to Poll<()> + let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, body.span)); + let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()])); + body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info); + + // call coroutine_drop() + let call_bb = body.basic_blocks_mut().push(BasicBlockData { + statements: Vec::new(), + terminator: None, + is_cleanup: false, + }); + + // return Poll::Ready() + let ret_bb = insert_poll_ready_block(tcx, &mut body); + + let kind = TerminatorKind::Drop { + place: Place::from(SELF_ARG), + target: ret_bb, + unwind: UnwindAction::Continue, + replace: false, + drop: None, + async_fut: None, + }; + body.basic_blocks_mut()[call_bb].terminator = Some(Terminator { source_info, kind }); + + dump_mir(tcx, false, "coroutine_drop_proxy_async", &0, &body, |_, _| Ok(())); + + body +} diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs new file mode 100644 index 00000000000..00a8293966b --- /dev/null +++ b/compiler/rustc_mir_transform/src/cost_checker.rs @@ -0,0 +1,200 @@ +use rustc_middle::bug; +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; + +const INSTR_COST: usize = 5; +const CALL_PENALTY: usize = 25; +const LANDINGPAD_PENALTY: usize = 50; +const RESUME_PENALTY: usize = 45; +const LARGE_SWITCH_PENALTY: usize = 20; +const CONST_SWITCH_BONUS: usize = 10; + +/// Verify that the callee body is compatible with the caller. +#[derive(Clone)] +pub(super) struct CostChecker<'b, 'tcx> { + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + penalty: usize, + bonus: usize, + callee_body: &'b Body<'tcx>, + instance: Option<ty::Instance<'tcx>>, +} + +impl<'b, 'tcx> CostChecker<'b, 'tcx> { + pub(super) fn new( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + instance: Option<ty::Instance<'tcx>>, + callee_body: &'b Body<'tcx>, + ) -> CostChecker<'b, 'tcx> { + CostChecker { tcx, typing_env, callee_body, instance, penalty: 0, bonus: 0 } + } + + /// Add function-level costs not well-represented by the block-level costs. + /// + /// Needed because the `CostChecker` is used sometimes for just blocks, + /// and even the full `Inline` doesn't call `visit_body`, so there's nowhere + /// to put this logic in the visitor. + pub(super) fn add_function_level_costs(&mut self) { + // If the only has one Call (or similar), inlining isn't increasing the total + // number of calls, so give extra encouragement to inlining that. 
+ if self.callee_body.basic_blocks.iter().filter(|bbd| is_call_like(bbd.terminator())).count() + == 1 + { + self.bonus += CALL_PENALTY; + } + } + + pub(super) fn cost(&self) -> usize { + usize::saturating_sub(self.penalty, self.bonus) + } + + fn instantiate_ty(&self, v: Ty<'tcx>) -> Ty<'tcx> { + if let Some(instance) = self.instance { + instance.instantiate_mir(self.tcx, ty::EarlyBinder::bind(&v)) + } else { + v + } + } +} + +impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + // Most costs are in rvalues and terminators, not in statements. + match statement.kind { + StatementKind::Intrinsic(ref ndi) => { + self.penalty += match **ndi { + NonDivergingIntrinsic::Assume(..) => INSTR_COST, + NonDivergingIntrinsic::CopyNonOverlapping(..) => CALL_PENALTY, + }; + } + _ => self.super_statement(statement, location), + } + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) { + match rvalue { + Rvalue::NullaryOp(NullOp::UbChecks, ..) + if !self + .tcx + .sess + .opts + .unstable_opts + .inline_mir_preserve_debug + .unwrap_or(self.tcx.sess.ub_checks()) => + { + // If this is in optimized MIR it's because it's used later, + // so if we don't need UB checks this session, give a bonus + // here to offset the cost of the call later. + self.bonus += CALL_PENALTY; + } + // These are essentially constants that didn't end up in an Operand, + // so treat them as also being free. + Rvalue::NullaryOp(..) => {} + _ => self.penalty += INSTR_COST, + } + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) { + match &terminator.kind { + TerminatorKind::Drop { place, unwind, .. } => { + // If the place doesn't actually need dropping, treat it like a regular goto. + let ty = self.instantiate_ty(place.ty(self.callee_body, self.tcx).ty); + if ty.needs_drop(self.tcx, self.typing_env) { + self.penalty += CALL_PENALTY; + if let UnwindAction::Cleanup(_) = unwind { + self.penalty += LANDINGPAD_PENALTY; + } + } + } + TerminatorKind::Call { func, unwind, .. } => { + self.penalty += if let Some((def_id, ..)) = func.const_fn_def() + && self.tcx.intrinsic(def_id).is_some() + { + // Don't give intrinsics the extra penalty for calls + INSTR_COST + } else { + CALL_PENALTY + }; + if let UnwindAction::Cleanup(_) = unwind { + self.penalty += LANDINGPAD_PENALTY; + } + } + TerminatorKind::TailCall { .. } => { + self.penalty += CALL_PENALTY; + } + TerminatorKind::SwitchInt { discr, targets } => { + if discr.constant().is_some() { + // Not only will this become a `Goto`, but likely other + // things will be removable as unreachable. + self.bonus += CONST_SWITCH_BONUS; + } else if targets.all_targets().len() > 3 { + // More than false/true/unreachable gets extra cost. + self.penalty += LARGE_SWITCH_PENALTY; + } else { + self.penalty += INSTR_COST; + } + } + TerminatorKind::Assert { unwind, msg, .. } => { + self.penalty += if msg.is_optional_overflow_check() + && !self + .tcx + .sess + .opts + .unstable_opts + .inline_mir_preserve_debug + .unwrap_or(self.tcx.sess.overflow_checks()) + { + INSTR_COST + } else { + CALL_PENALTY + }; + if let UnwindAction::Cleanup(_) = unwind { + self.penalty += LANDINGPAD_PENALTY; + } + } + TerminatorKind::UnwindResume => self.penalty += RESUME_PENALTY, + TerminatorKind::InlineAsm { unwind, .. 
} => { + self.penalty += INSTR_COST; + if let UnwindAction::Cleanup(_) = unwind { + self.penalty += LANDINGPAD_PENALTY; + } + } + TerminatorKind::Unreachable => { + self.bonus += INSTR_COST; + } + TerminatorKind::Goto { .. } | TerminatorKind::Return => {} + TerminatorKind::UnwindTerminate(..) => {} + kind @ (TerminatorKind::FalseUnwind { .. } + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::Yield { .. } + | TerminatorKind::CoroutineDrop) => { + bug!("{kind:?} should not be in runtime MIR"); + } + } + } +} + +/// A terminator that's more call-like (might do a bunch of work, might panic, etc) +/// than it is goto-/return-like (no side effects, etc). +/// +/// Used to treat multi-call functions (which could inline exponentially) +/// different from those that only do one or none of these "complex" things. +pub(super) fn is_call_like(terminator: &Terminator<'_>) -> bool { + use TerminatorKind::*; + match terminator.kind { + Call { .. } | TailCall { .. } | Drop { .. } | Assert { .. } | InlineAsm { .. } => true, + + Goto { .. } + | SwitchInt { .. } + | UnwindResume + | UnwindTerminate(_) + | Return + | Unreachable => false, + + Yield { .. } | CoroutineDrop | FalseEdge { .. } | FalseUnwind { .. } => { + unreachable!() + } + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs new file mode 100644 index 00000000000..5568d42ab8f --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -0,0 +1,200 @@ +use std::cmp::Ordering; + +use either::Either; +use itertools::Itertools; +use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; +use rustc_data_structures::graph::DirectedGraph; +use rustc_index::IndexVec; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::mir::coverage::{CounterId, CovTerm, Expression, ExpressionId, Op}; + +use crate::coverage::counters::balanced_flow::BalancedFlowGraph; +use crate::coverage::counters::node_flow::{ + CounterTerm, NodeCounters, NodeFlowData, node_flow_data_for_balanced_graph, +}; +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph}; + +mod balanced_flow; +pub(crate) mod node_flow; +mod union_find; + +/// Struct containing the results of [`prepare_bcb_counters_data`]. +pub(crate) struct BcbCountersData { + pub(crate) node_flow_data: NodeFlowData<BasicCoverageBlock>, + pub(crate) priority_list: Vec<BasicCoverageBlock>, +} + +/// Analyzes the coverage graph to create intermediate data structures that +/// will later be used (during codegen) to create physical counters or counter +/// expressions for each BCB node that needs one. +pub(crate) fn prepare_bcb_counters_data(graph: &CoverageGraph) -> BcbCountersData { + // Create the derived graphs that are necessary for subsequent steps. + let balanced_graph = BalancedFlowGraph::for_graph(graph, |n| !graph[n].is_out_summable); + let node_flow_data = node_flow_data_for_balanced_graph(&balanced_graph); + + // Also create a "priority list" of coverage graph nodes, to help determine + // which ones get physical counters or counter expressions. This needs to + // be done now, because the later parts of the counter-creation process + // won't have access to the original coverage graph. + let priority_list = make_node_flow_priority_list(graph, balanced_graph); + + BcbCountersData { node_flow_data, priority_list } +} + +/// Arranges the nodes in `balanced_graph` into a list, such that earlier nodes +/// take priority in being given a counter expression instead of a physical counter. 
+fn make_node_flow_priority_list( + graph: &CoverageGraph, + balanced_graph: BalancedFlowGraph<&CoverageGraph>, +) -> Vec<BasicCoverageBlock> { + // A "reloop" node has exactly one out-edge, which jumps back to the top + // of an enclosing loop. Reloop nodes are typically visited more times + // than loop-exit nodes, so try to avoid giving them physical counters. + let is_reloop_node = IndexVec::<BasicCoverageBlock, _>::from_fn_n( + |node| match graph.successors[node].as_slice() { + &[succ] => graph.dominates(succ, node), + _ => false, + }, + graph.num_nodes(), + ); + + let mut nodes = balanced_graph.iter_nodes().rev().collect::<Vec<_>>(); + // The first node is the sink, which must not get a physical counter. + assert_eq!(nodes[0], balanced_graph.sink); + // Sort the real nodes, such that earlier (lesser) nodes take priority + // in being given a counter expression instead of a physical counter. + nodes[1..].sort_by(|&a, &b| { + // Start with a dummy `Equal` to make the actual tests line up nicely. + Ordering::Equal + // Prefer a physical counter for return/yield nodes. + .then_with(|| Ord::cmp(&graph[a].is_out_summable, &graph[b].is_out_summable)) + // Prefer an expression for reloop nodes (see definition above). + .then_with(|| Ord::cmp(&is_reloop_node[a], &is_reloop_node[b]).reverse()) + // Otherwise, prefer a physical counter for dominating nodes. + .then_with(|| graph.cmp_in_dominator_order(a, b).reverse()) + }); + nodes +} + +// Converts node counters into a form suitable for embedding into MIR. +pub(crate) fn transcribe_counters( + old: &NodeCounters<BasicCoverageBlock>, + bcb_needs_counter: &DenseBitSet<BasicCoverageBlock>, + bcbs_seen: &DenseBitSet<BasicCoverageBlock>, +) -> CoverageCounters { + let mut new = CoverageCounters::with_num_bcbs(bcb_needs_counter.domain_size()); + + for bcb in bcb_needs_counter.iter() { + if !bcbs_seen.contains(bcb) { + // This BCB's code was removed by MIR opts, so its counter is always zero. + new.set_node_counter(bcb, CovTerm::Zero); + continue; + } + + // Our counter-creation algorithm doesn't guarantee that a node's list + // of terms starts or ends with a positive term, so partition the + // counters into "positive" and "negative" lists for easier handling. + let (mut pos, mut neg): (Vec<_>, Vec<_>) = old.counter_terms[bcb] + .iter() + // Filter out any BCBs that were removed by MIR opts; + // this treats them as having an execution count of 0. + .filter(|term| bcbs_seen.contains(term.node)) + .partition_map(|&CounterTerm { node, op }| match op { + Op::Add => Either::Left(node), + Op::Subtract => Either::Right(node), + }); + + // These intermediate sorts are not strictly necessary, but were helpful + // in reducing churn when switching to the current counter-creation scheme. + // They also help to slightly decrease the overall size of the expression + // table, due to more subexpressions being shared. + pos.sort(); + neg.sort(); + + let mut new_counters_for_sites = |sites: Vec<BasicCoverageBlock>| { + sites.into_iter().map(|node| new.ensure_phys_counter(node)).collect::<Vec<_>>() + }; + let pos = new_counters_for_sites(pos); + let neg = new_counters_for_sites(neg); + + let pos_counter = new.make_sum(&pos).unwrap_or(CovTerm::Zero); + let new_counter = new.make_subtracted_sum(pos_counter, &neg); + new.set_node_counter(bcb, new_counter); + } + + new +} + +/// Generates and stores coverage counter and coverage expression information +/// associated with nodes in the coverage graph. 
+pub(super) struct CoverageCounters { + /// List of places where a counter-increment statement should be injected + /// into MIR, each with its corresponding counter ID. + pub(crate) phys_counter_for_node: FxIndexMap<BasicCoverageBlock, CounterId>, + pub(crate) next_counter_id: CounterId, + + /// Coverage counters/expressions that are associated with individual BCBs. + pub(crate) node_counters: IndexVec<BasicCoverageBlock, Option<CovTerm>>, + + /// Table of expression data, associating each expression ID with its + /// corresponding operator (+ or -) and its LHS/RHS operands. + pub(crate) expressions: IndexVec<ExpressionId, Expression>, + /// Remember expressions that have already been created (or simplified), + /// so that we don't create unnecessary duplicates. + expressions_memo: FxHashMap<Expression, CovTerm>, +} + +impl CoverageCounters { + fn with_num_bcbs(num_bcbs: usize) -> Self { + Self { + phys_counter_for_node: FxIndexMap::default(), + next_counter_id: CounterId::ZERO, + node_counters: IndexVec::from_elem_n(None, num_bcbs), + expressions: IndexVec::new(), + expressions_memo: FxHashMap::default(), + } + } + + /// Returns the physical counter for the given node, creating it if necessary. + fn ensure_phys_counter(&mut self, bcb: BasicCoverageBlock) -> CovTerm { + let id = *self.phys_counter_for_node.entry(bcb).or_insert_with(|| { + let id = self.next_counter_id; + self.next_counter_id = id + 1; + id + }); + CovTerm::Counter(id) + } + + fn make_expression(&mut self, lhs: CovTerm, op: Op, rhs: CovTerm) -> CovTerm { + let new_expr = Expression { lhs, op, rhs }; + *self.expressions_memo.entry(new_expr.clone()).or_insert_with(|| { + let id = self.expressions.push(new_expr); + CovTerm::Expression(id) + }) + } + + /// Creates a counter that is the sum of the given counters. + /// + /// Returns `None` if the given list of counters was empty. + fn make_sum(&mut self, counters: &[CovTerm]) -> Option<CovTerm> { + counters + .iter() + .copied() + .reduce(|accum, counter| self.make_expression(accum, Op::Add, counter)) + } + + /// Creates a counter whose value is `lhs - SUM(rhs)`. + fn make_subtracted_sum(&mut self, lhs: CovTerm, rhs: &[CovTerm]) -> CovTerm { + let Some(rhs_sum) = self.make_sum(rhs) else { return lhs }; + self.make_expression(lhs, Op::Subtract, rhs_sum) + } + + fn set_node_counter(&mut self, bcb: BasicCoverageBlock, counter: CovTerm) -> CovTerm { + let existing = self.node_counters[bcb].replace(counter); + assert!( + existing.is_none(), + "node {bcb:?} already has a counter: {existing:?} => {counter:?}" + ); + counter + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/balanced_flow.rs b/compiler/rustc_mir_transform/src/coverage/counters/balanced_flow.rs new file mode 100644 index 00000000000..4c20722a043 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/balanced_flow.rs @@ -0,0 +1,131 @@ +//! A control-flow graph can be said to have “balanced flow” if the flow +//! (execution count) of each node is equal to the sum of its in-edge flows, +//! and also equal to the sum of its out-edge flows. +//! +//! Control-flow graphs typically have one or more nodes that don't satisfy the +//! balanced-flow property, e.g.: +//! - The start node has out-edges, but no in-edges. +//! - Return nodes have in-edges, but no out-edges. +//! - `Yield` nodes can have an out-flow that is less than their in-flow. +//! - Inescapable loops cause the in-flow/out-flow relationship to break down. +//! +//! 
Balanced-flow graphs are nevertheless useful for analysis, so this module +//! provides a wrapper type ([`BalancedFlowGraph`]) that imposes balanced flow +//! on an underlying graph. This is done by non-destructively adding synthetic +//! nodes and edges as necessary. + +use rustc_data_structures::graph; +use rustc_data_structures::graph::iterate::DepthFirstSearch; +use rustc_data_structures::graph::reversed::ReversedGraph; +use rustc_index::Idx; +use rustc_index::bit_set::DenseBitSet; + +/// A view of an underlying graph that has been augmented to have “balanced flow”. +/// This means that the flow (execution count) of each node is equal to the +/// sum of its in-edge flows, and also equal to the sum of its out-edge flows. +/// +/// To achieve this, a synthetic "sink" node is non-destructively added to the +/// graph, with synthetic in-edges from these nodes: +/// - Any node that has no out-edges. +/// - Any node that explicitly requires a sink edge, as indicated by a +/// caller-supplied `force_sink_edge` function. +/// - Any node that would otherwise be unable to reach the sink, because it is +/// part of an inescapable loop. +/// +/// To make the graph fully balanced, there is also a synthetic edge from the +/// sink node back to the start node. +/// +/// --- +/// The benefit of having a balanced-flow graph is that it can be subsequently +/// transformed in ways that are guaranteed to preserve balanced flow +/// (e.g. merging nodes together), which is useful for discovering relationships +/// between the node flows of different nodes in the graph. +pub(crate) struct BalancedFlowGraph<G: graph::DirectedGraph> { + graph: G, + sink_edge_nodes: DenseBitSet<G::Node>, + pub(crate) sink: G::Node, +} + +impl<G: graph::DirectedGraph> BalancedFlowGraph<G> { + /// Creates a balanced view of an underlying graph, by adding a synthetic + /// sink node that has in-edges from nodes that need or request such an edge, + /// and a single out-edge to the start node. + /// + /// Assumes that all nodes in the underlying graph are reachable from the + /// start node. + pub(crate) fn for_graph(graph: G, force_sink_edge: impl Fn(G::Node) -> bool) -> Self + where + G: graph::ControlFlowGraph, + { + let mut sink_edge_nodes = DenseBitSet::new_empty(graph.num_nodes()); + let mut dfs = DepthFirstSearch::new(ReversedGraph::new(&graph)); + + // First, determine the set of nodes that explicitly request or require + // an out-edge to the sink. + for node in graph.iter_nodes() { + if force_sink_edge(node) || graph.successors(node).next().is_none() { + sink_edge_nodes.insert(node); + dfs.push_start_node(node); + } + } + + // Next, find all nodes that are currently not reverse-reachable from + // `sink_edge_nodes`, and add them to the set as well. + dfs.complete_search(); + sink_edge_nodes.union_not(dfs.visited_set()); + + // The sink node is 1 higher than the highest real node. + let sink = G::Node::new(graph.num_nodes()); + + BalancedFlowGraph { graph, sink_edge_nodes, sink } + } +} + +impl<G> graph::DirectedGraph for BalancedFlowGraph<G> +where + G: graph::DirectedGraph, +{ + type Node = G::Node; + + /// Returns the number of nodes in this balanced-flow graph, which is 1 + /// more than the number of nodes in the underlying graph, to account for + /// the synthetic sink node. + fn num_nodes(&self) -> usize { + // The sink node's index is already the size of the underlying graph, + // so just add 1 to that instead. 
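        // Illustrative example (added): an underlying graph with nodes 0..=3
        // gets the synthetic sink at index 4, so this balanced view reports
        // 5 nodes in total.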
+ self.sink.index() + 1 + } +} + +impl<G> graph::StartNode for BalancedFlowGraph<G> +where + G: graph::StartNode, +{ + fn start_node(&self) -> Self::Node { + self.graph.start_node() + } +} + +impl<G> graph::Successors for BalancedFlowGraph<G> +where + G: graph::StartNode + graph::Successors, +{ + fn successors(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> { + let real_edges; + let sink_edge; + + if node == self.sink { + // The sink node has no real out-edges, and one synthetic out-edge + // to the start node. + real_edges = None; + sink_edge = Some(self.graph.start_node()); + } else { + // Real nodes have their real out-edges, and possibly one synthetic + // out-edge to the sink node. + real_edges = Some(self.graph.successors(node)); + sink_edge = self.sink_edge_nodes.contains(node).then_some(self.sink); + } + + real_edges.into_iter().flatten().chain(sink_edge) + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/node_flow.rs b/compiler/rustc_mir_transform/src/coverage/counters/node_flow.rs new file mode 100644 index 00000000000..91ed54b8b59 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/node_flow.rs @@ -0,0 +1,274 @@ +//! For each node in a control-flow graph, determines whether that node should +//! have a physical counter, or a counter expression that is derived from the +//! physical counters of other nodes. +//! +//! Based on the algorithm given in +//! "Optimal measurement points for program frequency counts" +//! (Knuth & Stevenson, 1973). + +use rustc_data_structures::graph; +use rustc_index::bit_set::DenseBitSet; +use rustc_index::{Idx, IndexSlice, IndexVec}; +pub(crate) use rustc_middle::mir::coverage::NodeFlowData; +use rustc_middle::mir::coverage::Op; + +use crate::coverage::counters::union_find::UnionFind; + +#[cfg(test)] +mod tests; + +/// Creates a "merged" view of an underlying graph. +/// +/// The given graph is assumed to have [“balanced flow”](balanced-flow), +/// though it does not necessarily have to be a `BalancedFlowGraph`. +/// +/// [balanced-flow]: `crate::coverage::counters::balanced_flow::BalancedFlowGraph`. +pub(crate) fn node_flow_data_for_balanced_graph<G>(graph: G) -> NodeFlowData<G::Node> +where + G: graph::Successors, +{ + let mut supernodes = UnionFind::<G::Node>::new(graph.num_nodes()); + + // For each node, merge its successors into a single supernode, and + // arbitrarily choose one of those successors to represent all of them. + let successors = graph + .iter_nodes() + .map(|node| { + graph + .successors(node) + .reduce(|a, b| supernodes.unify(a, b)) + .expect("each node in a balanced graph must have at least one out-edge") + }) + .collect::<IndexVec<G::Node, G::Node>>(); + + // Now that unification is complete, take a snapshot of the supernode forest, + // and resolve each arbitrarily-chosen successor to its canonical root. + // (This avoids having to explicitly resolve them later.) + let supernodes = supernodes.snapshot(); + let succ_supernodes = successors.into_iter().map(|succ| supernodes[succ]).collect(); + + NodeFlowData { supernodes, succ_supernodes } +} + +/// Uses the graph information in `node_flow_data`, together with a given +/// permutation of all nodes in the graph, to create physical counters and +/// counter expressions for each node in the underlying graph. +/// +/// The given list must contain exactly one copy of each node in the +/// underlying balanced-flow graph. 
The order of nodes is used as a hint to +/// influence counter allocation: +/// - Earlier nodes are more likely to receive counter expressions. +/// - Later nodes are more likely to receive physical counters. +pub(crate) fn make_node_counters<Node: Idx>( + node_flow_data: &NodeFlowData<Node>, + priority_list: &[Node], +) -> NodeCounters<Node> { + let mut builder = SpantreeBuilder::new(node_flow_data); + + for &node in priority_list { + builder.visit_node(node); + } + + NodeCounters { counter_terms: builder.finish() } +} + +/// End result of allocating physical counters and counter expressions for the +/// nodes of a graph. +#[derive(Debug)] +pub(crate) struct NodeCounters<Node: Idx> { + /// For the given node, returns the finished list of terms that represent + /// its physical counter or counter expression. Always non-empty. + /// + /// If a node was given a physical counter, the term list will contain + /// that counter as its sole element. + pub(crate) counter_terms: IndexVec<Node, Vec<CounterTerm<Node>>>, +} + +#[derive(Debug)] +struct SpantreeEdge<Node> { + /// If true, this edge in the spantree has been reversed an odd number of + /// times, so all physical counters added to its node's counter expression + /// need to be negated. + is_reversed: bool, + /// Each spantree edge is "claimed" by the (regular) node that caused it to + /// be created. When a node with a physical counter traverses this edge, + /// that counter is added to the claiming node's counter expression. + claiming_node: Node, + /// Supernode at the other end of this spantree edge. Transitively points + /// to the "root" of this supernode's spantree component. + span_parent: Node, +} + +/// Part of a node's counter expression, which is a sum of counter terms. +#[derive(Debug)] +pub(crate) struct CounterTerm<Node> { + /// Whether to add or subtract the value of the node's physical counter. + pub(crate) op: Op, + /// The node whose physical counter is represented by this term. + pub(crate) node: Node, +} + +#[derive(Debug)] +struct SpantreeBuilder<'a, Node: Idx> { + supernodes: &'a IndexSlice<Node, Node>, + succ_supernodes: &'a IndexSlice<Node, Node>, + + is_unvisited: DenseBitSet<Node>, + /// Links supernodes to each other, gradually forming a spanning tree of + /// the merged-flow graph. + /// + /// A supernode without a span edge is the root of its component of the + /// spantree. Nodes that aren't supernodes cannot have a spantree edge. + span_edges: IndexVec<Node, Option<SpantreeEdge<Node>>>, + /// Shared path buffer recycled by all calls to `yank_to_spantree_root`. + yank_buffer: Vec<Node>, + /// An in-progress counter expression for each node. Each expression is + /// initially empty, and will be filled in as relevant nodes are visited. + counter_terms: IndexVec<Node, Vec<CounterTerm<Node>>>, +} + +impl<'a, Node: Idx> SpantreeBuilder<'a, Node> { + fn new(node_flow_data: &'a NodeFlowData<Node>) -> Self { + let NodeFlowData { supernodes, succ_supernodes } = node_flow_data; + let num_nodes = supernodes.len(); + Self { + supernodes, + succ_supernodes, + is_unvisited: DenseBitSet::new_filled(num_nodes), + span_edges: IndexVec::from_fn_n(|_| None, num_nodes), + yank_buffer: vec![], + counter_terms: IndexVec::from_fn_n(|_| vec![], num_nodes), + } + } + + fn is_supernode(&self, node: Node) -> bool { + self.supernodes[node] == node + } + + /// Given a supernode, finds the supernode that is the "root" of its + /// spantree component. Two nodes that have the same spantree root are + /// connected in the spantree. 
+ fn spantree_root(&self, this: Node) -> Node { + debug_assert!(self.is_supernode(this)); + + match self.span_edges[this] { + None => this, + Some(SpantreeEdge { span_parent, .. }) => self.spantree_root(span_parent), + } + } + + /// Rotates edges in the spantree so that `this` is the root of its + /// spantree component. + fn yank_to_spantree_root(&mut self, this: Node) { + debug_assert!(self.is_supernode(this)); + + // The rotation is done iteratively, by first traversing from `this` to + // its root and storing the path in a buffer, and then traversing the + // path buffer backwards to reverse all the edges. + + // Recycle the same path buffer for all calls to this method. + let path_buf = &mut self.yank_buffer; + path_buf.clear(); + path_buf.push(this); + + // Traverse the spantree until we reach a supernode that has no + // span-parent, which must be the root. + let mut curr = this; + while let &Some(SpantreeEdge { span_parent, .. }) = &self.span_edges[curr] { + path_buf.push(span_parent); + curr = span_parent; + } + + // For each spantree edge `a -> b` in the path that was just traversed, + // reverse it to become `a <- b`, while preserving `claiming_node`. + for &[a, b] in path_buf.array_windows::<2>().rev() { + let SpantreeEdge { is_reversed, claiming_node, span_parent } = self.span_edges[a] + .take() + .expect("all nodes in the path (except the last) have a `span_parent`"); + debug_assert_eq!(span_parent, b); + debug_assert!(self.span_edges[b].is_none()); + self.span_edges[b] = + Some(SpantreeEdge { is_reversed: !is_reversed, claiming_node, span_parent: a }); + } + + // The result of the rotation is that `this` is now a spantree root. + debug_assert!(self.span_edges[this].is_none()); + } + + /// Must be called exactly once for each node in the balanced-flow graph. + fn visit_node(&mut self, this: Node) { + // Assert that this node was unvisited, and mark it visited. + assert!(self.is_unvisited.remove(this), "node has already been visited: {this:?}"); + + // Get the supernode containing `this`, and make it the root of its + // component of the spantree. + let this_supernode = self.supernodes[this]; + self.yank_to_spantree_root(this_supernode); + + // Get the supernode containing all of this's successors. + let succ_supernode = self.succ_supernodes[this]; + debug_assert!(self.is_supernode(succ_supernode)); + + // If two supernodes are already connected in the spantree, they will + // have the same spantree root. (Each supernode is connected to itself.) + if this_supernode != self.spantree_root(succ_supernode) { + // Adding this node's flow edge to the spantree would cause two + // previously-disconnected supernodes to become connected, so add + // it. That spantree-edge is now "claimed" by this node. + // + // Claiming a spantree-edge means that this node will get a counter + // expression instead of a physical counter. That expression is + // currently empty, but will be built incrementally as the other + // nodes are visited. + self.span_edges[this_supernode] = Some(SpantreeEdge { + is_reversed: false, + claiming_node: this, + span_parent: succ_supernode, + }); + } else { + // This node's flow edge would join two supernodes that are already + // connected in the spantree (or are the same supernode). That would + // create a cycle in the spantree, so don't add an edge. + // + // Instead, create a physical counter for this node, and add that + // counter to all expressions on the path from `succ_supernode` to + // `this_supernode`. 
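            // Worked example (a reading of the `example_driver` test in
            // `node_flow/tests.rs` below, added here for illustration): node 1
            // is the only node that claims a spantree edge, so it ends up with
            // the counter expression `+c0 +c2 -c4`, while nodes 0, 2, 3 and 4
            // receive physical counters.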
+ + // Instead of setting `this.measure = true` as in the original paper, + // we just add the node's ID to its own list of terms. + self.counter_terms[this].push(CounterTerm { node: this, op: Op::Add }); + + // Walk the spantree from `this.successor` back to `this`. For each + // spantree edge along the way, add this node's physical counter to + // the counter expression of the node that claimed the spantree edge. + let mut curr = succ_supernode; + while curr != this_supernode { + let &SpantreeEdge { is_reversed, claiming_node, span_parent } = + self.span_edges[curr].as_ref().unwrap(); + let op = if is_reversed { Op::Subtract } else { Op::Add }; + self.counter_terms[claiming_node].push(CounterTerm { node: this, op }); + + curr = span_parent; + } + } + } + + /// Asserts that all nodes have been visited, and returns the computed + /// counter expressions (made up of physical counters) for each node. + fn finish(self) -> IndexVec<Node, Vec<CounterTerm<Node>>> { + let Self { ref span_edges, ref is_unvisited, ref counter_terms, .. } = self; + assert!(is_unvisited.is_empty(), "some nodes were never visited: {is_unvisited:?}"); + debug_assert!( + span_edges + .iter_enumerated() + .all(|(node, span_edge)| { span_edge.is_some() <= self.is_supernode(node) }), + "only supernodes can have a span edge", + ); + debug_assert!( + counter_terms.iter().all(|terms| !terms.is_empty()), + "after visiting all nodes, every node should have at least one term", + ); + + self.counter_terms + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/node_flow/tests.rs b/compiler/rustc_mir_transform/src/coverage/counters/node_flow/tests.rs new file mode 100644 index 00000000000..46c46c743c2 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/node_flow/tests.rs @@ -0,0 +1,62 @@ +use itertools::Itertools; +use rustc_data_structures::graph; +use rustc_data_structures::graph::vec_graph::VecGraph; +use rustc_index::Idx; +use rustc_middle::mir::coverage::Op; + +use crate::coverage::counters::node_flow::{ + CounterTerm, NodeCounters, NodeFlowData, make_node_counters, node_flow_data_for_balanced_graph, +}; + +fn node_flow_data<G: graph::Successors>(graph: G) -> NodeFlowData<G::Node> { + node_flow_data_for_balanced_graph(graph) +} + +fn make_graph<Node: Idx + Ord>(num_nodes: usize, edge_pairs: Vec<(Node, Node)>) -> VecGraph<Node> { + VecGraph::new(num_nodes, edge_pairs) +} + +/// Example used in "Optimal Measurement Points for Program Frequency Counts" +/// (Knuth & Stevenson, 1973), but with 0-based node IDs. 
+#[test] +fn example_driver() { + let graph = make_graph::<u32>( + 5, + vec![(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), (2, 4), (3, 3), (3, 4), (4, 0)], + ); + + let node_flow_data = node_flow_data(&graph); + let counters = make_node_counters(&node_flow_data, &[3, 1, 2, 0, 4]); + + assert_eq!( + format_counter_expressions(&counters), + &[ + // (comment to force vertical formatting for clarity) + "[0]: +c0", + "[1]: +c0 +c2 -c4", + "[2]: +c2", + "[3]: +c3", + "[4]: +c4", + ] + ); +} + +fn format_counter_expressions<Node: Idx>(counters: &NodeCounters<Node>) -> Vec<String> { + let format_item = |&CounterTerm { node, op }| { + let op = match op { + Op::Subtract => '-', + Op::Add => '+', + }; + format!("{op}c{node:?}") + }; + + counters + .counter_terms + .indices() + .map(|node| { + let mut terms = counters.counter_terms[node].iter().collect::<Vec<_>>(); + terms.sort_by_key(|item| item.node.index()); + format!("[{node:?}]: {}", terms.into_iter().map(format_item).join(" ")) + }) + .collect() +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/union_find.rs b/compiler/rustc_mir_transform/src/coverage/counters/union_find.rs new file mode 100644 index 00000000000..a826a953fa6 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/union_find.rs @@ -0,0 +1,96 @@ +use std::cmp::Ordering; +use std::mem; + +use rustc_index::{Idx, IndexVec}; + +#[cfg(test)] +mod tests; + +/// Simple implementation of a union-find data structure, i.e. a disjoint-set +/// forest. +#[derive(Debug)] +pub(crate) struct UnionFind<Key: Idx> { + table: IndexVec<Key, UnionFindEntry<Key>>, +} + +#[derive(Debug)] +struct UnionFindEntry<Key> { + /// Transitively points towards the "root" of the set containing this key. + /// + /// Invariant: A root key is its own parent. + parent: Key, + /// When merging two "root" keys, their ranks determine which key becomes + /// the new root, to prevent the parent tree from becoming unnecessarily + /// tall. See [`UnionFind::unify`] for details. + rank: u32, +} + +impl<Key: Idx> UnionFind<Key> { + /// Creates a new disjoint-set forest containing the keys `0..num_keys`. + /// Initially, every key is part of its own one-element set. + pub(crate) fn new(num_keys: usize) -> Self { + // Initially, every key is the root of its own set, so its parent is itself. + Self { table: IndexVec::from_fn_n(|key| UnionFindEntry { parent: key, rank: 0 }, num_keys) } + } + + /// Returns the "root" key of the disjoint-set containing the given key. + /// If two keys have the same root, they belong to the same set. + /// + /// Also updates internal data structures to make subsequent `find` + /// operations faster. + pub(crate) fn find(&mut self, key: Key) -> Key { + // Loop until we find a key that is its own parent. + let mut curr = key; + while let parent = self.table[curr].parent + && curr != parent + { + // Perform "path compression" by peeking one layer ahead, and + // setting the current key's parent to that value. + // (This works even when `parent` is the root of its set, because + // of the invariant that a root is its own parent.) + let parent_parent = self.table[parent].parent; + self.table[curr].parent = parent_parent; + + // Advance by one step and continue. + curr = parent; + } + curr + } + + /// Merges the set containing `a` and the set containing `b` into one set. + /// + /// Returns the common root of both keys, after the merge. 
+ pub(crate) fn unify(&mut self, a: Key, b: Key) -> Key { + let mut a = self.find(a); + let mut b = self.find(b); + + // If both keys have the same root, they're already in the same set, + // so there's nothing more to do. + if a == b { + return a; + }; + + // Ensure that `a` has strictly greater rank, swapping if necessary. + // If both keys have the same rank, increment the rank of `a` so that + // future unifications will also prefer `a`, leading to flatter trees. + match Ord::cmp(&self.table[a].rank, &self.table[b].rank) { + Ordering::Less => mem::swap(&mut a, &mut b), + Ordering::Equal => self.table[a].rank += 1, + Ordering::Greater => {} + } + + debug_assert!(self.table[a].rank > self.table[b].rank); + debug_assert_eq!(self.table[b].parent, b); + + // Make `a` the parent of `b`. + self.table[b].parent = a; + + a + } + + /// Takes a "snapshot" of the current state of this disjoint-set forest, in + /// the form of a vector that directly maps each key to its current root. + pub(crate) fn snapshot(&mut self) -> IndexVec<Key, Key> { + self.table.indices().map(|key| self.find(key)).collect() + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters/union_find/tests.rs b/compiler/rustc_mir_transform/src/coverage/counters/union_find/tests.rs new file mode 100644 index 00000000000..34a4e4f8e6e --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters/union_find/tests.rs @@ -0,0 +1,32 @@ +use super::UnionFind; + +#[test] +fn empty() { + let mut sets = UnionFind::<u32>::new(10); + + for i in 1..10 { + assert_eq!(sets.find(i), i); + } +} + +#[test] +fn transitive() { + let mut sets = UnionFind::<u32>::new(10); + + sets.unify(3, 7); + sets.unify(4, 2); + + assert_eq!(sets.find(7), sets.find(3)); + assert_eq!(sets.find(2), sets.find(4)); + assert_ne!(sets.find(3), sets.find(4)); + + sets.unify(7, 4); + + assert_eq!(sets.find(7), sets.find(3)); + assert_eq!(sets.find(2), sets.find(4)); + assert_eq!(sets.find(3), sets.find(4)); + + for i in [0, 1, 5, 6, 8, 9] { + assert_eq!(sets.find(i), i); + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/graph.rs b/compiler/rustc_mir_transform/src/coverage/graph.rs new file mode 100644 index 00000000000..dcc7c5b91d7 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/graph.rs @@ -0,0 +1,433 @@ +use std::cmp::Ordering; +use std::ops::{Index, IndexMut}; +use std::{mem, slice}; + +use rustc_data_structures::fx::FxHashSet; +use rustc_data_structures::graph::dominators::Dominators; +use rustc_data_structures::graph::{self, DirectedGraph, StartNode}; +use rustc_index::IndexVec; +use rustc_index::bit_set::DenseBitSet; +pub(crate) use rustc_middle::mir::coverage::{BasicCoverageBlock, START_BCB}; +use rustc_middle::mir::{self, BasicBlock, Terminator, TerminatorKind}; +use tracing::debug; + +/// A coverage-specific simplification of the MIR control flow graph (CFG). The `CoverageGraph`s +/// nodes are `BasicCoverageBlock`s, which encompass one or more MIR `BasicBlock`s. +#[derive(Debug)] +pub(crate) struct CoverageGraph { + bcbs: IndexVec<BasicCoverageBlock, BasicCoverageBlockData>, + bb_to_bcb: IndexVec<BasicBlock, Option<BasicCoverageBlock>>, + pub(crate) successors: IndexVec<BasicCoverageBlock, Vec<BasicCoverageBlock>>, + pub(crate) predecessors: IndexVec<BasicCoverageBlock, Vec<BasicCoverageBlock>>, + + dominators: Option<Dominators<BasicCoverageBlock>>, + /// Allows nodes to be compared in some total order such that _if_ + /// `a` dominates `b`, then `a < b`. 
If neither node dominates the other, + /// their relative order is consistent but arbitrary. + dominator_order_rank: IndexVec<BasicCoverageBlock, u32>, + /// A loop header is a node that dominates one or more of its predecessors. + is_loop_header: DenseBitSet<BasicCoverageBlock>, + /// For each node, the loop header node of its nearest enclosing loop. + /// This forms a linked list that can be traversed to find all enclosing loops. + enclosing_loop_header: IndexVec<BasicCoverageBlock, Option<BasicCoverageBlock>>, +} + +impl CoverageGraph { + pub(crate) fn from_mir(mir_body: &mir::Body<'_>) -> Self { + let (bcbs, bb_to_bcb) = Self::compute_basic_coverage_blocks(mir_body); + + // Pre-transform MIR `BasicBlock` successors and predecessors into the BasicCoverageBlock + // equivalents. Note that since the BasicCoverageBlock graph has been fully simplified, the + // each predecessor of a BCB leader_bb should be in a unique BCB. It is possible for a + // `SwitchInt` to have multiple targets to the same destination `BasicBlock`, so + // de-duplication is required. This is done without reordering the successors. + + let successors = IndexVec::<BasicCoverageBlock, _>::from_fn_n( + |bcb| { + let mut seen_bcbs = FxHashSet::default(); + let terminator = mir_body[bcbs[bcb].last_bb()].terminator(); + bcb_filtered_successors(terminator) + .into_iter() + .filter_map(|successor_bb| bb_to_bcb[successor_bb]) + // Remove duplicate successor BCBs, keeping only the first. + .filter(|&successor_bcb| seen_bcbs.insert(successor_bcb)) + .collect::<Vec<_>>() + }, + bcbs.len(), + ); + + let mut predecessors = IndexVec::from_elem(Vec::new(), &bcbs); + for (bcb, bcb_successors) in successors.iter_enumerated() { + for &successor in bcb_successors { + predecessors[successor].push(bcb); + } + } + + let num_nodes = bcbs.len(); + let mut this = Self { + bcbs, + bb_to_bcb, + successors, + predecessors, + dominators: None, + dominator_order_rank: IndexVec::from_elem_n(0, num_nodes), + is_loop_header: DenseBitSet::new_empty(num_nodes), + enclosing_loop_header: IndexVec::from_elem_n(None, num_nodes), + }; + assert_eq!(num_nodes, this.num_nodes()); + + // Set the dominators first, because later init steps rely on them. + this.dominators = Some(graph::dominators::dominators(&this)); + + // Iterate over all nodes, such that dominating nodes are visited before + // the nodes they dominate. Either preorder or reverse postorder is fine. + let dominator_order = graph::iterate::reverse_post_order(&this, this.start_node()); + // The coverage graph is created by traversal, so all nodes are reachable. + assert_eq!(dominator_order.len(), this.num_nodes()); + for (rank, bcb) in (0u32..).zip(dominator_order) { + // The dominator rank of each node is its index in a dominator-order traversal. + this.dominator_order_rank[bcb] = rank; + + // A node is a loop header if it dominates any of its predecessors. + if this.reloop_predecessors(bcb).next().is_some() { + this.is_loop_header.insert(bcb); + } + + // If the immediate dominator is a loop header, that's our enclosing loop. + // Otherwise, inherit the immediate dominator's enclosing loop. + // (Dominator order ensures that we already processed the dominator.) + if let Some(dom) = this.dominators().immediate_dominator(bcb) { + this.enclosing_loop_header[bcb] = this + .is_loop_header + .contains(dom) + .then_some(dom) + .or_else(|| this.enclosing_loop_header[dom]); + } + } + + // The coverage graph's entry-point node (bcb0) always starts with bb0, + // which never has predecessors. 
Any other blocks merged into bcb0 can't + // have multiple (coverage-relevant) predecessors, so bcb0 always has + // zero in-edges. + assert!(this[START_BCB].leader_bb() == mir::START_BLOCK); + assert!(this.predecessors[START_BCB].is_empty()); + + this + } + + fn compute_basic_coverage_blocks( + mir_body: &mir::Body<'_>, + ) -> ( + IndexVec<BasicCoverageBlock, BasicCoverageBlockData>, + IndexVec<BasicBlock, Option<BasicCoverageBlock>>, + ) { + let num_basic_blocks = mir_body.basic_blocks.len(); + let mut bcbs = IndexVec::<BasicCoverageBlock, _>::with_capacity(num_basic_blocks); + let mut bb_to_bcb = IndexVec::from_elem_n(None, num_basic_blocks); + + let mut flush_chain_into_new_bcb = |current_chain: &mut Vec<BasicBlock>| { + // Take the accumulated list of blocks, leaving the vector empty + // to be used by subsequent BCBs. + let basic_blocks = mem::take(current_chain); + + let bcb = bcbs.next_index(); + for &bb in basic_blocks.iter() { + bb_to_bcb[bb] = Some(bcb); + } + + let is_out_summable = basic_blocks.last().is_some_and(|&bb| { + bcb_filtered_successors(mir_body[bb].terminator()).is_out_summable() + }); + let bcb_data = BasicCoverageBlockData { basic_blocks, is_out_summable }; + debug!("adding {bcb:?}: {bcb_data:?}"); + bcbs.push(bcb_data); + }; + + // Traverse the MIR control-flow graph, accumulating chains of blocks + // that can be combined into a single node in the coverage graph. + // A depth-first search ensures that if two nodes can be chained + // together, they will be adjacent in the traversal order. + + // Accumulates a chain of blocks that will be combined into one BCB. + let mut current_chain = vec![]; + + let subgraph = CoverageRelevantSubgraph::new(&mir_body.basic_blocks); + for bb in graph::depth_first_search(subgraph, mir::START_BLOCK) + .filter(|&bb| mir_body[bb].terminator().kind != TerminatorKind::Unreachable) + { + if let Some(&prev) = current_chain.last() { + // Adding a block to a non-empty chain is allowed if the + // previous block permits chaining, and the current block has + // `prev` as its sole predecessor. + let can_chain = subgraph.coverage_successors(prev).is_out_chainable() + && mir_body.basic_blocks.predecessors()[bb].as_slice() == &[prev]; + if !can_chain { + // The current block can't be added to the existing chain, so + // flush that chain into a new BCB, and start a new chain. 
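+                    // (This happens, for example, when `bb` has more than one
+                    // predecessor, or when `prev` ends in a multi-target
+                    // `SwitchInt` and is therefore not out-chainable.)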
+ flush_chain_into_new_bcb(&mut current_chain); + } + } + + current_chain.push(bb); + } + + if !current_chain.is_empty() { + debug!("flushing accumulated blocks into one last BCB"); + flush_chain_into_new_bcb(&mut current_chain); + } + + (bcbs, bb_to_bcb) + } + + #[inline(always)] + pub(crate) fn iter_enumerated( + &self, + ) -> impl Iterator<Item = (BasicCoverageBlock, &BasicCoverageBlockData)> { + self.bcbs.iter_enumerated() + } + + #[inline(always)] + pub(crate) fn bcb_from_bb(&self, bb: BasicBlock) -> Option<BasicCoverageBlock> { + if bb.index() < self.bb_to_bcb.len() { self.bb_to_bcb[bb] } else { None } + } + + #[inline(always)] + fn dominators(&self) -> &Dominators<BasicCoverageBlock> { + self.dominators.as_ref().unwrap() + } + + #[inline(always)] + pub(crate) fn dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool { + self.dominators().dominates(dom, node) + } + + #[inline(always)] + pub(crate) fn cmp_in_dominator_order( + &self, + a: BasicCoverageBlock, + b: BasicCoverageBlock, + ) -> Ordering { + self.dominator_order_rank[a].cmp(&self.dominator_order_rank[b]) + } + + /// For the given node, yields the subset of its predecessor nodes that + /// it dominates. If that subset is non-empty, the node is a "loop header", + /// and each of those predecessors represents an in-edge that jumps back to + /// the top of its loop. + pub(crate) fn reloop_predecessors( + &self, + to_bcb: BasicCoverageBlock, + ) -> impl Iterator<Item = BasicCoverageBlock> { + self.predecessors[to_bcb].iter().copied().filter(move |&pred| self.dominates(to_bcb, pred)) + } +} + +impl Index<BasicCoverageBlock> for CoverageGraph { + type Output = BasicCoverageBlockData; + + #[inline] + fn index(&self, index: BasicCoverageBlock) -> &BasicCoverageBlockData { + &self.bcbs[index] + } +} + +impl IndexMut<BasicCoverageBlock> for CoverageGraph { + #[inline] + fn index_mut(&mut self, index: BasicCoverageBlock) -> &mut BasicCoverageBlockData { + &mut self.bcbs[index] + } +} + +impl graph::DirectedGraph for CoverageGraph { + type Node = BasicCoverageBlock; + + #[inline] + fn num_nodes(&self) -> usize { + self.bcbs.len() + } +} + +impl graph::StartNode for CoverageGraph { + #[inline] + fn start_node(&self) -> Self::Node { + self.bcb_from_bb(mir::START_BLOCK) + .expect("mir::START_BLOCK should be in a BasicCoverageBlock") + } +} + +impl graph::Successors for CoverageGraph { + #[inline] + fn successors(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> { + self.successors[node].iter().copied() + } +} + +impl graph::Predecessors for CoverageGraph { + #[inline] + fn predecessors(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> { + self.predecessors[node].iter().copied() + } +} + +/// `BasicCoverageBlockData` holds the data indexed by a `BasicCoverageBlock`. +/// +/// A `BasicCoverageBlock` (BCB) represents the maximal-length sequence of MIR `BasicBlock`s without +/// conditional branches, and form a new, simplified, coverage-specific Control Flow Graph, without +/// altering the original MIR CFG. +/// +/// Note that running the MIR `SimplifyCfg` transform is not sufficient (and therefore not +/// necessary). The BCB-based CFG is a more aggressive simplification. For example: +/// +/// * The BCB CFG ignores (trims) branches not relevant to coverage, such as unwind-related code, +/// that is injected by the Rust compiler but has no physical source code to count. 
This also +/// means a BasicBlock with a `Call` terminator can be merged into its primary successor target +/// block, in the same BCB. (But, note: Issue #78544: "MIR InstrumentCoverage: Improve coverage +/// of `#[should_panic]` tests and `catch_unwind()` handlers") +/// * Some BasicBlock terminators support Rust-specific concerns--like borrow-checking--that are +/// not relevant to coverage analysis. `FalseUnwind`, for example, can be treated the same as +/// a `Goto`, and merged with its successor into the same BCB. +/// +/// Each BCB with at least one computed coverage span will have no more than one `Counter`. +/// In some cases, a BCB's execution count can be computed by `Expression`. Additional +/// disjoint coverage spans in a BCB can also be counted by `Expression` (by adding `ZERO` +/// to the BCB's primary counter or expression). +/// +/// The BCB CFG is critical to simplifying the coverage analysis by ensuring graph path-based +/// queries (`dominates()`, `predecessors`, `successors`, etc.) have branch (control flow) +/// significance. +#[derive(Debug, Clone)] +pub(crate) struct BasicCoverageBlockData { + pub(crate) basic_blocks: Vec<BasicBlock>, + + /// If true, this node's execution count can be assumed to be the sum of the + /// execution counts of all of its **out-edges** (assuming no panics). + /// + /// Notably, this is false for a node ending with [`TerminatorKind::Yield`], + /// because the yielding coroutine might not be resumed. + pub(crate) is_out_summable: bool, +} + +impl BasicCoverageBlockData { + #[inline(always)] + pub(crate) fn leader_bb(&self) -> BasicBlock { + self.basic_blocks[0] + } + + #[inline(always)] + pub(crate) fn last_bb(&self) -> BasicBlock { + *self.basic_blocks.last().unwrap() + } +} + +/// Holds the coverage-relevant successors of a basic block's terminator, and +/// indicates whether that block can potentially be combined into the same BCB +/// as its sole successor. +#[derive(Clone, Copy, Debug)] +struct CoverageSuccessors<'a> { + /// Coverage-relevant successors of the corresponding terminator. + /// There might be 0, 1, or multiple targets. + targets: &'a [BasicBlock], + /// `Yield` terminators are not chainable, because their sole out-edge is + /// only followed if/when the generator is resumed after the yield. + is_yield: bool, +} + +impl CoverageSuccessors<'_> { + /// If `false`, this terminator cannot be chained into another block when + /// building the coverage graph. + fn is_out_chainable(&self) -> bool { + // If a terminator is out-summable and has exactly one out-edge, then + // it is eligible to be chained into its successor block. + self.is_out_summable() && self.targets.len() == 1 + } + + /// Returns true if the terminator itself is assumed to have the same + /// execution count as the sum of its out-edges (assuming no panics). + fn is_out_summable(&self) -> bool { + !self.is_yield && !self.targets.is_empty() + } +} + +impl IntoIterator for CoverageSuccessors<'_> { + type Item = BasicBlock; + type IntoIter = impl DoubleEndedIterator<Item = Self::Item>; + + fn into_iter(self) -> Self::IntoIter { + self.targets.iter().copied() + } +} + +// Returns the subset of a block's successors that are relevant to the coverage +// graph, i.e. those that do not represent unwinds or false edges. +// FIXME(#78544): MIR InstrumentCoverage: Improve coverage of `#[should_panic]` tests and +// `catch_unwind()` handlers. 
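+// For example, the unwind edge of a `Call` or `Drop` terminator is never
+// treated as a coverage successor; only the normal-return target (if any)
+// is kept.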
+fn bcb_filtered_successors<'a, 'tcx>(terminator: &'a Terminator<'tcx>) -> CoverageSuccessors<'a> { + use TerminatorKind::*; + let mut is_yield = false; + let targets = match &terminator.kind { + // A switch terminator can have many coverage-relevant successors. + SwitchInt { targets, .. } => targets.all_targets(), + + // A yield terminator has exactly 1 successor, but should not be chained, + // because its resume edge has a different execution count. + Yield { resume, .. } => { + is_yield = true; + slice::from_ref(resume) + } + + // These terminators have exactly one coverage-relevant successor, + // and can be chained into it. + Assert { target, .. } + | Drop { target, .. } + | FalseEdge { real_target: target, .. } + | FalseUnwind { real_target: target, .. } + | Goto { target } => slice::from_ref(target), + + // A call terminator can normally be chained, except when it has no + // successor because it is known to diverge. + Call { target: maybe_target, .. } => maybe_target.as_slice(), + + // An inline asm terminator can normally be chained, except when it + // diverges or uses asm goto. + InlineAsm { targets, .. } => &targets, + + // These terminators have no coverage-relevant successors. + CoroutineDrop + | Return + | TailCall { .. } + | Unreachable + | UnwindResume + | UnwindTerminate(_) => &[], + }; + + CoverageSuccessors { targets, is_yield } +} + +/// Wrapper around a [`mir::BasicBlocks`] graph that restricts each node's +/// successors to only the ones considered "relevant" when building a coverage +/// graph. +#[derive(Clone, Copy)] +struct CoverageRelevantSubgraph<'a, 'tcx> { + basic_blocks: &'a mir::BasicBlocks<'tcx>, +} +impl<'a, 'tcx> CoverageRelevantSubgraph<'a, 'tcx> { + fn new(basic_blocks: &'a mir::BasicBlocks<'tcx>) -> Self { + Self { basic_blocks } + } + + fn coverage_successors(&self, bb: BasicBlock) -> CoverageSuccessors<'_> { + bcb_filtered_successors(self.basic_blocks[bb].terminator()) + } +} +impl<'a, 'tcx> graph::DirectedGraph for CoverageRelevantSubgraph<'a, 'tcx> { + type Node = BasicBlock; + + fn num_nodes(&self) -> usize { + self.basic_blocks.num_nodes() + } +} +impl<'a, 'tcx> graph::Successors for CoverageRelevantSubgraph<'a, 'tcx> { + fn successors(&self, bb: Self::Node) -> impl Iterator<Item = Self::Node> { + self.coverage_successors(bb).into_iter() + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/mappings.rs b/compiler/rustc_mir_transform/src/coverage/mappings.rs new file mode 100644 index 00000000000..b4b4d0416fb --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/mappings.rs @@ -0,0 +1,350 @@ +use std::collections::BTreeSet; + +use rustc_data_structures::fx::FxIndexMap; +use rustc_index::IndexVec; +use rustc_middle::mir::coverage::{ + BlockMarkerId, BranchSpan, ConditionId, ConditionInfo, CoverageInfoHi, CoverageKind, +}; +use rustc_middle::mir::{self, BasicBlock, StatementKind}; +use rustc_middle::ty::TyCtxt; +use rustc_span::Span; + +use crate::coverage::ExtractedHirInfo; +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, START_BCB}; +use crate::coverage::spans::extract_refined_covspans; +use crate::coverage::unexpand::unexpand_into_body_span; +use crate::errors::MCDCExceedsTestVectorLimit; + +/// Associates an ordinary executable code span with its corresponding BCB. 
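+/// Each of these eventually becomes an ordinary `MappingKind::Code` region
+/// in the function's coverage mappings.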
+#[derive(Debug)] +pub(super) struct CodeMapping { + pub(super) span: Span, + pub(super) bcb: BasicCoverageBlock, +} + +/// This is separate from [`MCDCBranch`] to help prepare for larger changes +/// that will be needed for improved branch coverage in the future. +/// (See <https://github.com/rust-lang/rust/pull/124217>.) +#[derive(Debug)] +pub(super) struct BranchPair { + pub(super) span: Span, + pub(super) true_bcb: BasicCoverageBlock, + pub(super) false_bcb: BasicCoverageBlock, +} + +/// Associates an MC/DC branch span with condition info besides fields for normal branch. +#[derive(Debug)] +pub(super) struct MCDCBranch { + pub(super) span: Span, + pub(super) true_bcb: BasicCoverageBlock, + pub(super) false_bcb: BasicCoverageBlock, + pub(super) condition_info: ConditionInfo, + // Offset added to test vector idx if this branch is evaluated to true. + pub(super) true_index: usize, + // Offset added to test vector idx if this branch is evaluated to false. + pub(super) false_index: usize, +} + +/// Associates an MC/DC decision with its join BCBs. +#[derive(Debug)] +pub(super) struct MCDCDecision { + pub(super) span: Span, + pub(super) end_bcbs: BTreeSet<BasicCoverageBlock>, + pub(super) bitmap_idx: usize, + pub(super) num_test_vectors: usize, + pub(super) decision_depth: u16, +} + +// LLVM uses `i32` to index the bitmap. Thus `i32::MAX` is the hard limit for number of all test vectors +// in a function. +const MCDC_MAX_BITMAP_SIZE: usize = i32::MAX as usize; + +#[derive(Default)] +pub(super) struct ExtractedMappings { + pub(super) code_mappings: Vec<CodeMapping>, + pub(super) branch_pairs: Vec<BranchPair>, + pub(super) mcdc_bitmap_bits: usize, + pub(super) mcdc_degraded_branches: Vec<MCDCBranch>, + pub(super) mcdc_mappings: Vec<(MCDCDecision, Vec<MCDCBranch>)>, +} + +/// Extracts coverage-relevant spans from MIR, and associates them with +/// their corresponding BCBs. +pub(super) fn extract_all_mapping_info_from_mir<'tcx>( + tcx: TyCtxt<'tcx>, + mir_body: &mir::Body<'tcx>, + hir_info: &ExtractedHirInfo, + graph: &CoverageGraph, +) -> ExtractedMappings { + let mut code_mappings = vec![]; + let mut branch_pairs = vec![]; + let mut mcdc_bitmap_bits = 0; + let mut mcdc_degraded_branches = vec![]; + let mut mcdc_mappings = vec![]; + + if hir_info.is_async_fn || tcx.sess.coverage_no_mir_spans() { + // An async function desugars into a function that returns a future, + // with the user code wrapped in a closure. Any spans in the desugared + // outer function will be unhelpful, so just keep the signature span + // and ignore all of the spans in the MIR body. + // + // When debugging flag `-Zcoverage-options=no-mir-spans` is set, we need + // to give the same treatment to _all_ functions, because `llvm-cov` + // seems to ignore functions that don't have any ordinary code spans. + if let Some(span) = hir_info.fn_sig_span { + code_mappings.push(CodeMapping { span, bcb: START_BCB }); + } + } else { + // Extract coverage spans from MIR statements/terminators as normal. 
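+        // (Refinement sorts the raw spans, deduplicates them, discards spans
+        // that overlap holes, and merges eligible neighbours; see
+        // `spans::extract_refined_covspans` for details.)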
+ extract_refined_covspans(tcx, mir_body, hir_info, graph, &mut code_mappings); + } + + branch_pairs.extend(extract_branch_pairs(mir_body, hir_info, graph)); + + extract_mcdc_mappings( + mir_body, + tcx, + hir_info.body_span, + graph, + &mut mcdc_bitmap_bits, + &mut mcdc_degraded_branches, + &mut mcdc_mappings, + ); + + ExtractedMappings { + code_mappings, + branch_pairs, + mcdc_bitmap_bits, + mcdc_degraded_branches, + mcdc_mappings, + } +} + +fn resolve_block_markers( + coverage_info_hi: &CoverageInfoHi, + mir_body: &mir::Body<'_>, +) -> IndexVec<BlockMarkerId, Option<BasicBlock>> { + let mut block_markers = IndexVec::<BlockMarkerId, Option<BasicBlock>>::from_elem_n( + None, + coverage_info_hi.num_block_markers, + ); + + // Fill out the mapping from block marker IDs to their enclosing blocks. + for (bb, data) in mir_body.basic_blocks.iter_enumerated() { + for statement in &data.statements { + if let StatementKind::Coverage(CoverageKind::BlockMarker { id }) = statement.kind { + block_markers[id] = Some(bb); + } + } + } + + block_markers +} + +// FIXME: There is currently a lot of redundancy between +// `extract_branch_pairs` and `extract_mcdc_mappings`. This is needed so +// that they can each be modified without interfering with the other, but in +// the long term we should try to bring them together again when branch coverage +// and MC/DC coverage support are more mature. + +pub(super) fn extract_branch_pairs( + mir_body: &mir::Body<'_>, + hir_info: &ExtractedHirInfo, + graph: &CoverageGraph, +) -> Vec<BranchPair> { + let Some(coverage_info_hi) = mir_body.coverage_info_hi.as_deref() else { return vec![] }; + + let block_markers = resolve_block_markers(coverage_info_hi, mir_body); + + coverage_info_hi + .branch_spans + .iter() + .filter_map(|&BranchSpan { span: raw_span, true_marker, false_marker }| { + // For now, ignore any branch span that was introduced by + // expansion. This makes things like assert macros less noisy. + if !raw_span.ctxt().outer_expn_data().is_root() { + return None; + } + let span = unexpand_into_body_span(raw_span, hir_info.body_span)?; + + let bcb_from_marker = |marker: BlockMarkerId| graph.bcb_from_bb(block_markers[marker]?); + + let true_bcb = bcb_from_marker(true_marker)?; + let false_bcb = bcb_from_marker(false_marker)?; + + Some(BranchPair { span, true_bcb, false_bcb }) + }) + .collect::<Vec<_>>() +} + +pub(super) fn extract_mcdc_mappings( + mir_body: &mir::Body<'_>, + tcx: TyCtxt<'_>, + body_span: Span, + graph: &CoverageGraph, + mcdc_bitmap_bits: &mut usize, + mcdc_degraded_branches: &mut impl Extend<MCDCBranch>, + mcdc_mappings: &mut impl Extend<(MCDCDecision, Vec<MCDCBranch>)>, +) { + let Some(coverage_info_hi) = mir_body.coverage_info_hi.as_deref() else { return }; + + let block_markers = resolve_block_markers(coverage_info_hi, mir_body); + + let bcb_from_marker = |marker: BlockMarkerId| graph.bcb_from_bb(block_markers[marker]?); + + let check_branch_bcb = + |raw_span: Span, true_marker: BlockMarkerId, false_marker: BlockMarkerId| { + // For now, ignore any branch span that was introduced by + // expansion. This makes things like assert macros less noisy. 
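+            // (A non-root syntax context means the span was produced by a
+            // macro expansion or desugaring rather than written directly in
+            // the source.)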
+ if !raw_span.ctxt().outer_expn_data().is_root() { + return None; + } + let span = unexpand_into_body_span(raw_span, body_span)?; + + let true_bcb = bcb_from_marker(true_marker)?; + let false_bcb = bcb_from_marker(false_marker)?; + Some((span, true_bcb, false_bcb)) + }; + + let to_mcdc_branch = |&mir::coverage::MCDCBranchSpan { + span: raw_span, + condition_info, + true_marker, + false_marker, + }| { + let (span, true_bcb, false_bcb) = check_branch_bcb(raw_span, true_marker, false_marker)?; + Some(MCDCBranch { + span, + true_bcb, + false_bcb, + condition_info, + true_index: usize::MAX, + false_index: usize::MAX, + }) + }; + + let mut get_bitmap_idx = |num_test_vectors: usize| -> Option<usize> { + let bitmap_idx = *mcdc_bitmap_bits; + let next_bitmap_bits = bitmap_idx.saturating_add(num_test_vectors); + (next_bitmap_bits <= MCDC_MAX_BITMAP_SIZE).then(|| { + *mcdc_bitmap_bits = next_bitmap_bits; + bitmap_idx + }) + }; + mcdc_degraded_branches + .extend(coverage_info_hi.mcdc_degraded_branch_spans.iter().filter_map(to_mcdc_branch)); + + mcdc_mappings.extend(coverage_info_hi.mcdc_spans.iter().filter_map(|(decision, branches)| { + if branches.len() == 0 { + return None; + } + let decision_span = unexpand_into_body_span(decision.span, body_span)?; + + let end_bcbs = decision + .end_markers + .iter() + .map(|&marker| bcb_from_marker(marker)) + .collect::<Option<_>>()?; + let mut branch_mappings: Vec<_> = branches.into_iter().filter_map(to_mcdc_branch).collect(); + if branch_mappings.len() != branches.len() { + mcdc_degraded_branches.extend(branch_mappings); + return None; + } + let num_test_vectors = calc_test_vectors_index(&mut branch_mappings); + let Some(bitmap_idx) = get_bitmap_idx(num_test_vectors) else { + tcx.dcx().emit_warn(MCDCExceedsTestVectorLimit { + span: decision_span, + max_num_test_vectors: MCDC_MAX_BITMAP_SIZE, + }); + mcdc_degraded_branches.extend(branch_mappings); + return None; + }; + // LLVM requires span of the decision contains all spans of its conditions. + // Usually the decision span meets the requirement well but in cases like macros it may not. + let span = branch_mappings + .iter() + .map(|branch| branch.span) + .reduce(|lhs, rhs| lhs.to(rhs)) + .map( + |joint_span| { + if decision_span.contains(joint_span) { decision_span } else { joint_span } + }, + ) + .expect("branch mappings are ensured to be non-empty as checked above"); + Some(( + MCDCDecision { + span, + end_bcbs, + bitmap_idx, + num_test_vectors, + decision_depth: decision.decision_depth, + }, + branch_mappings, + )) + })); +} + +// LLVM checks the executed test vector by accumulating indices of tested branches. +// We calculate number of all possible test vectors of the decision and assign indices +// to branches here. +// See [the rfc](https://discourse.llvm.org/t/rfc-coverage-new-algorithm-and-file-format-for-mc-dc/76798/) +// for more details about the algorithm. +// This function is mostly like [`TVIdxBuilder::TvIdxBuilder`](https://github.com/llvm/llvm-project/blob/d594d9f7f4dc6eb748b3261917db689fdc348b96/llvm/lib/ProfileData/Coverage/CoverageMapping.cpp#L226) +fn calc_test_vectors_index(conditions: &mut Vec<MCDCBranch>) -> usize { + let mut indegree_stats = IndexVec::<ConditionId, usize>::from_elem_n(0, conditions.len()); + // `num_paths` is `width` described at the llvm rfc, which indicates how many paths reaching the condition node. 
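+    // For example, for the decision `a && b` (with `a` as the start condition),
+    // traversal assigns `a.true_index = 0`, and the end nodes
+    // (`a` false), (`b` false), (`b` true) receive the indices 0, 1 and 2,
+    // giving 3 test vectors: {a=F} -> 0, {a=T, b=F} -> 1, {a=T, b=T} -> 2.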
+ let mut num_paths_stats = IndexVec::<ConditionId, usize>::from_elem_n(0, conditions.len()); + let mut next_conditions = conditions + .iter_mut() + .map(|branch| { + let ConditionInfo { condition_id, true_next_id, false_next_id } = branch.condition_info; + [true_next_id, false_next_id] + .into_iter() + .flatten() + .for_each(|next_id| indegree_stats[next_id] += 1); + (condition_id, branch) + }) + .collect::<FxIndexMap<_, _>>(); + + let mut queue = + std::collections::VecDeque::from_iter(next_conditions.swap_remove(&ConditionId::START)); + num_paths_stats[ConditionId::START] = 1; + let mut decision_end_nodes = Vec::new(); + while let Some(branch) = queue.pop_front() { + let ConditionInfo { condition_id, true_next_id, false_next_id } = branch.condition_info; + let (false_index, true_index) = (&mut branch.false_index, &mut branch.true_index); + let this_paths_count = num_paths_stats[condition_id]; + // Note. First check the false next to ensure conditions are touched in same order with llvm-cov. + for (next, index) in [(false_next_id, false_index), (true_next_id, true_index)] { + if let Some(next_id) = next { + let next_paths_count = &mut num_paths_stats[next_id]; + *index = *next_paths_count; + *next_paths_count = next_paths_count.saturating_add(this_paths_count); + let next_indegree = &mut indegree_stats[next_id]; + *next_indegree -= 1; + if *next_indegree == 0 { + queue.push_back(next_conditions.swap_remove(&next_id).expect( + "conditions with non-zero indegree before must be in next_conditions", + )); + } + } else { + decision_end_nodes.push((this_paths_count, condition_id, index)); + } + } + } + assert!(next_conditions.is_empty(), "the decision tree has untouched nodes"); + let mut cur_idx = 0; + // LLVM hopes the end nodes are sorted in descending order by `num_paths` so that it can + // optimize bitmap size for decisions in tree form such as `a && b && c && d && ...`. + decision_end_nodes.sort_by_key(|(num_paths, _, _)| usize::MAX - *num_paths); + for (num_paths, condition_id, index) in decision_end_nodes { + assert_eq!( + num_paths, num_paths_stats[condition_id], + "end nodes should not be updated since they were visited" + ); + assert_eq!(*index, usize::MAX, "end nodes should not be assigned index before"); + *index = cur_idx; + cur_idx += num_paths; + } + cur_idx +} diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs new file mode 100644 index 00000000000..702c62eddc7 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -0,0 +1,383 @@ +mod counters; +mod graph; +mod mappings; +pub(super) mod query; +mod spans; +#[cfg(test)] +mod tests; +mod unexpand; + +use rustc_hir as hir; +use rustc_hir::intravisit::{Visitor, walk_expr}; +use rustc_middle::hir::nested_filter; +use rustc_middle::mir::coverage::{ + CoverageKind, DecisionInfo, FunctionCoverageInfo, Mapping, MappingKind, +}; +use rustc_middle::mir::{self, BasicBlock, Statement, StatementKind, TerminatorKind}; +use rustc_middle::ty::TyCtxt; +use rustc_span::Span; +use rustc_span::def_id::LocalDefId; +use tracing::{debug, debug_span, trace}; + +use crate::coverage::counters::BcbCountersData; +use crate::coverage::graph::CoverageGraph; +use crate::coverage::mappings::ExtractedMappings; + +/// Inserts `StatementKind::Coverage` statements that either instrument the binary with injected +/// counters, via intrinsic `llvm.instrprof.increment`, and/or inject metadata used during codegen +/// to construct the coverage map. 
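+///
+/// The statements injected by this pass are `CoverageKind::VirtualCounter`
+/// markers (one per node of the coverage graph), plus condition/test-vector
+/// bitmap updates when MC/DC instrumentation is enabled. Deciding which nodes
+/// get physical counters and which get counter expressions is deferred to the
+/// `coverage_ids_info` query.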
+pub(super) struct InstrumentCoverage; + +impl<'tcx> crate::MirPass<'tcx> for InstrumentCoverage { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.instrument_coverage() + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, mir_body: &mut mir::Body<'tcx>) { + let mir_source = mir_body.source; + + // This pass runs after MIR promotion, but before promoted MIR starts to + // be transformed, so it should never see promoted MIR. + assert!(mir_source.promoted.is_none()); + + let def_id = mir_source.def_id().expect_local(); + + if !tcx.is_eligible_for_coverage(def_id) { + trace!("InstrumentCoverage skipped for {def_id:?} (not eligible)"); + return; + } + + // An otherwise-eligible function is still skipped if its start block + // is known to be unreachable. + match mir_body.basic_blocks[mir::START_BLOCK].terminator().kind { + TerminatorKind::Unreachable => { + trace!("InstrumentCoverage skipped for unreachable `START_BLOCK`"); + return; + } + _ => {} + } + + instrument_function_for_coverage(tcx, mir_body); + } + + fn is_required(&self) -> bool { + false + } +} + +fn instrument_function_for_coverage<'tcx>(tcx: TyCtxt<'tcx>, mir_body: &mut mir::Body<'tcx>) { + let def_id = mir_body.source.def_id(); + let _span = debug_span!("instrument_function_for_coverage", ?def_id).entered(); + + let hir_info = extract_hir_info(tcx, def_id.expect_local()); + + // Build the coverage graph, which is a simplified view of the MIR control-flow + // graph that ignores some details not relevant to coverage instrumentation. + let graph = CoverageGraph::from_mir(mir_body); + + //////////////////////////////////////////////////// + // Extract coverage spans and other mapping info from MIR. + let extracted_mappings = + mappings::extract_all_mapping_info_from_mir(tcx, mir_body, &hir_info, &graph); + + let mappings = create_mappings(&extracted_mappings); + if mappings.is_empty() { + // No spans could be converted into valid mappings, so skip this function. + debug!("no spans could be converted into valid mappings; skipping"); + return; + } + + // Use the coverage graph to prepare intermediate data that will eventually + // be used to assign physical counters and counter expressions to points in + // the control-flow graph. + let BcbCountersData { node_flow_data, priority_list } = + counters::prepare_bcb_counters_data(&graph); + + // Inject coverage statements into MIR. + inject_coverage_statements(mir_body, &graph); + inject_mcdc_statements(mir_body, &graph, &extracted_mappings); + + let mcdc_num_condition_bitmaps = extracted_mappings + .mcdc_mappings + .iter() + .map(|&(mappings::MCDCDecision { decision_depth, .. }, _)| decision_depth) + .max() + .map_or(0, |max| usize::from(max) + 1); + + mir_body.function_coverage_info = Some(Box::new(FunctionCoverageInfo { + function_source_hash: hir_info.function_source_hash, + + node_flow_data, + priority_list, + + mappings, + + mcdc_bitmap_bits: extracted_mappings.mcdc_bitmap_bits, + mcdc_num_condition_bitmaps, + })); +} + +/// For each coverage span extracted from MIR, create a corresponding mapping. +/// +/// FIXME(Zalathar): This used to be where BCBs in the extracted mappings were +/// resolved to a `CovTerm`. But that is now handled elsewhere, so this +/// function can potentially be simplified even further. +fn create_mappings(extracted_mappings: &ExtractedMappings) -> Vec<Mapping> { + // Fully destructure the mappings struct to make sure we don't miss any kinds. 
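+    // (If a new field is added to `ExtractedMappings`, this pattern will fail
+    // to compile until the new field is handled here.)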
+ let ExtractedMappings { + code_mappings, + branch_pairs, + mcdc_bitmap_bits: _, + mcdc_degraded_branches, + mcdc_mappings, + } = extracted_mappings; + let mut mappings = Vec::new(); + + mappings.extend(code_mappings.iter().map( + // Ordinary code mappings are the simplest kind. + |&mappings::CodeMapping { span, bcb }| { + let kind = MappingKind::Code { bcb }; + Mapping { kind, span } + }, + )); + + mappings.extend(branch_pairs.iter().map( + |&mappings::BranchPair { span, true_bcb, false_bcb }| { + let kind = MappingKind::Branch { true_bcb, false_bcb }; + Mapping { kind, span } + }, + )); + + // MCDC branch mappings are appended with their decisions in case decisions were ignored. + mappings.extend(mcdc_degraded_branches.iter().map( + |&mappings::MCDCBranch { + span, + true_bcb, + false_bcb, + condition_info: _, + true_index: _, + false_index: _, + }| { Mapping { kind: MappingKind::Branch { true_bcb, false_bcb }, span } }, + )); + + for (decision, branches) in mcdc_mappings { + // FIXME(#134497): Previously it was possible for some of these branch + // conversions to fail, in which case the remaining branches in the + // decision would be degraded to plain `MappingKind::Branch`. + // The changes in #134497 made that failure impossible, because the + // fallible step was deferred to codegen. But the corresponding code + // in codegen wasn't updated to detect the need for a degrade step. + let conditions = branches + .into_iter() + .map( + |&mappings::MCDCBranch { + span, + true_bcb, + false_bcb, + condition_info, + true_index: _, + false_index: _, + }| { + Mapping { + kind: MappingKind::MCDCBranch { + true_bcb, + false_bcb, + mcdc_params: condition_info, + }, + span, + } + }, + ) + .collect::<Vec<_>>(); + + // LLVM requires end index for counter mapping regions. + let kind = MappingKind::MCDCDecision(DecisionInfo { + bitmap_idx: (decision.bitmap_idx + decision.num_test_vectors) as u32, + num_conditions: u16::try_from(conditions.len()).unwrap(), + }); + let span = decision.span; + mappings.extend(std::iter::once(Mapping { kind, span }).chain(conditions.into_iter())); + } + + mappings +} + +/// Inject any necessary coverage statements into MIR, so that they influence codegen. +fn inject_coverage_statements<'tcx>(mir_body: &mut mir::Body<'tcx>, graph: &CoverageGraph) { + for (bcb, data) in graph.iter_enumerated() { + let target_bb = data.leader_bb(); + inject_statement(mir_body, CoverageKind::VirtualCounter { bcb }, target_bb); + } +} + +/// For each conditions inject statements to update condition bitmap after it has been evaluated. +/// For each decision inject statements to update test vector bitmap after it has been evaluated. +fn inject_mcdc_statements<'tcx>( + mir_body: &mut mir::Body<'tcx>, + graph: &CoverageGraph, + extracted_mappings: &ExtractedMappings, +) { + for (decision, conditions) in &extracted_mappings.mcdc_mappings { + // Inject test vector update first because `inject_statement` always insert new statement at head. 
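+        // (Statements injected later end up *before* statements injected
+        // earlier in the same block, so the test-vector bitmap update will
+        // execute after any condition bitmap updates injected into that block.)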
+ for &end in &decision.end_bcbs { + let end_bb = graph[end].leader_bb(); + inject_statement( + mir_body, + CoverageKind::TestVectorBitmapUpdate { + bitmap_idx: decision.bitmap_idx as u32, + decision_depth: decision.decision_depth, + }, + end_bb, + ); + } + + for &mappings::MCDCBranch { + span: _, + true_bcb, + false_bcb, + condition_info: _, + true_index, + false_index, + } in conditions + { + for (index, bcb) in [(false_index, false_bcb), (true_index, true_bcb)] { + let bb = graph[bcb].leader_bb(); + inject_statement( + mir_body, + CoverageKind::CondBitmapUpdate { + index: index as u32, + decision_depth: decision.decision_depth, + }, + bb, + ); + } + } + } +} + +fn inject_statement(mir_body: &mut mir::Body<'_>, counter_kind: CoverageKind, bb: BasicBlock) { + debug!(" injecting statement {counter_kind:?} for {bb:?}"); + let data = &mut mir_body[bb]; + let source_info = data.terminator().source_info; + let statement = Statement { source_info, kind: StatementKind::Coverage(counter_kind) }; + data.statements.insert(0, statement); +} + +/// Function information extracted from HIR by the coverage instrumentor. +#[derive(Debug)] +struct ExtractedHirInfo { + function_source_hash: u64, + is_async_fn: bool, + /// The span of the function's signature, if available. + /// Must have the same context and filename as the body span. + fn_sig_span: Option<Span>, + body_span: Span, + /// "Holes" are regions within the function body (or its expansions) that + /// should not be included in coverage spans for this function + /// (e.g. closures and nested items). + hole_spans: Vec<Span>, +} + +fn extract_hir_info<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> ExtractedHirInfo { + // FIXME(#79625): Consider improving MIR to provide the information needed, to avoid going back + // to HIR for it. + + // HACK: For synthetic MIR bodies (async closures), use the def id of the HIR body. + if tcx.is_synthetic_mir(def_id) { + return extract_hir_info(tcx, tcx.local_parent(def_id)); + } + + let hir_node = tcx.hir_node_by_def_id(def_id); + let fn_body_id = hir_node.body_id().expect("HIR node is a function with body"); + let hir_body = tcx.hir_body(fn_body_id); + + let maybe_fn_sig = hir_node.fn_sig(); + let is_async_fn = maybe_fn_sig.is_some_and(|fn_sig| fn_sig.header.is_async()); + + let mut body_span = hir_body.value.span; + + use hir::{Closure, Expr, ExprKind, Node}; + // Unexpand a closure's body span back to the context of its declaration. + // This helps with closure bodies that consist of just a single bang-macro, + // and also with closure bodies produced by async desugaring. + if let Node::Expr(&Expr { kind: ExprKind::Closure(&Closure { fn_decl_span, .. }), .. }) = + hir_node + { + body_span = body_span.find_ancestor_in_same_ctxt(fn_decl_span).unwrap_or(body_span); + } + + // The actual signature span is only used if it has the same context and + // filename as the body, and precedes the body. 
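+    // If any of those checks fail, `fn_sig_span` is `None`, and downstream
+    // code falls back to a zero-width span at the start of the body
+    // (see `extract_refined_covspans`).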
+ let fn_sig_span = maybe_fn_sig.map(|fn_sig| fn_sig.span).filter(|&fn_sig_span| { + let source_map = tcx.sess.source_map(); + let file_idx = |span: Span| source_map.lookup_source_file_idx(span.lo()); + + fn_sig_span.eq_ctxt(body_span) + && fn_sig_span.hi() <= body_span.lo() + && file_idx(fn_sig_span) == file_idx(body_span) + }); + + let function_source_hash = hash_mir_source(tcx, hir_body); + + let hole_spans = extract_hole_spans_from_hir(tcx, hir_body); + + ExtractedHirInfo { function_source_hash, is_async_fn, fn_sig_span, body_span, hole_spans } +} + +fn hash_mir_source<'tcx>(tcx: TyCtxt<'tcx>, hir_body: &'tcx hir::Body<'tcx>) -> u64 { + // FIXME(cjgillot) Stop hashing HIR manually here. + let owner = hir_body.id().hir_id.owner; + tcx.hir_owner_nodes(owner).opt_hash_including_bodies.unwrap().to_smaller_hash().as_u64() +} + +fn extract_hole_spans_from_hir<'tcx>(tcx: TyCtxt<'tcx>, hir_body: &hir::Body<'tcx>) -> Vec<Span> { + struct HolesVisitor<'tcx> { + tcx: TyCtxt<'tcx>, + hole_spans: Vec<Span>, + } + + impl<'tcx> Visitor<'tcx> for HolesVisitor<'tcx> { + /// We have special handling for nested items, but we still want to + /// traverse into nested bodies of things that are not considered items, + /// such as "anon consts" (e.g. array lengths). + type NestedFilter = nested_filter::OnlyBodies; + + fn maybe_tcx(&mut self) -> TyCtxt<'tcx> { + self.tcx + } + + /// We override `visit_nested_item` instead of `visit_item` because we + /// only need the item's span, not the item itself. + fn visit_nested_item(&mut self, id: hir::ItemId) -> Self::Result { + let span = self.tcx.def_span(id.owner_id.def_id); + self.visit_hole_span(span); + // Having visited this item, we don't care about its children, + // so don't call `walk_item`. + } + + // We override `visit_expr` instead of the more specific expression + // visitors, so that we have direct access to the expression span. + fn visit_expr(&mut self, expr: &'tcx hir::Expr<'tcx>) { + match expr.kind { + hir::ExprKind::Closure(_) | hir::ExprKind::ConstBlock(_) => { + self.visit_hole_span(expr.span); + // Having visited this expression, we don't care about its + // children, so don't call `walk_expr`. + } + + // For other expressions, recursively visit as normal. + _ => walk_expr(self, expr), + } + } + } + impl HolesVisitor<'_> { + fn visit_hole_span(&mut self, hole_span: Span) { + self.hole_spans.push(hole_span); + } + } + + let mut visitor = HolesVisitor { tcx, hole_spans: vec![] }; + + visitor.visit_body(hir_body); + visitor.hole_spans +} diff --git a/compiler/rustc_mir_transform/src/coverage/query.rs b/compiler/rustc_mir_transform/src/coverage/query.rs new file mode 100644 index 00000000000..ccf76dc7108 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/query.rs @@ -0,0 +1,167 @@ +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; +use rustc_middle::mir::coverage::{BasicCoverageBlock, CoverageIdsInfo, CoverageKind, MappingKind}; +use rustc_middle::mir::{Body, Statement, StatementKind}; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_middle::util::Providers; +use rustc_span::def_id::LocalDefId; +use rustc_span::sym; +use tracing::trace; + +use crate::coverage::counters::node_flow::make_node_counters; +use crate::coverage::counters::{CoverageCounters, transcribe_counters}; + +/// Registers query/hook implementations related to coverage. 
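+///
+/// (`coverage_attr_on` and `coverage_ids_info` are queries;
+/// `is_eligible_for_coverage` is a hook.)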
+pub(crate) fn provide(providers: &mut Providers) { + providers.hooks.is_eligible_for_coverage = is_eligible_for_coverage; + providers.queries.coverage_attr_on = coverage_attr_on; + providers.queries.coverage_ids_info = coverage_ids_info; +} + +/// Hook implementation for [`TyCtxt::is_eligible_for_coverage`]. +fn is_eligible_for_coverage(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { + // Only instrument functions, methods, and closures (not constants since they are evaluated + // at compile time by Miri). + // FIXME(#73156): Handle source code coverage in const eval, but note, if and when const + // expressions get coverage spans, we will probably have to "carve out" space for const + // expressions from coverage spans in enclosing MIR's, like we do for closures. (That might + // be tricky if const expressions have no corresponding statements in the enclosing MIR. + // Closures are carved out by their initial `Assign` statement.) + if !tcx.def_kind(def_id).is_fn_like() { + trace!("InstrumentCoverage skipped for {def_id:?} (not an fn-like)"); + return false; + } + + // Don't instrument functions with `#[automatically_derived]` on their + // enclosing impl block, on the assumption that most users won't care about + // coverage for derived impls. + if let Some(impl_of) = tcx.impl_of_method(def_id.to_def_id()) + && tcx.is_automatically_derived(impl_of) + { + trace!("InstrumentCoverage skipped for {def_id:?} (automatically derived)"); + return false; + } + + if tcx.codegen_fn_attrs(def_id).flags.contains(CodegenFnAttrFlags::NAKED) { + trace!("InstrumentCoverage skipped for {def_id:?} (`#[naked]`)"); + return false; + } + + if !tcx.coverage_attr_on(def_id) { + trace!("InstrumentCoverage skipped for {def_id:?} (`#[coverage(off)]`)"); + return false; + } + + true +} + +/// Query implementation for `coverage_attr_on`. +fn coverage_attr_on(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { + // Check for annotations directly on this def. + if let Some(attr) = tcx.get_attr(def_id, sym::coverage) { + match attr.meta_item_list().as_deref() { + Some([item]) if item.has_name(sym::off) => return false, + Some([item]) if item.has_name(sym::on) => return true, + Some(_) | None => { + // Other possibilities should have been rejected by `rustc_parse::validate_attr`. + // Use `span_delayed_bug` to avoid an ICE in failing builds (#127880). + tcx.dcx().span_delayed_bug(attr.span(), "unexpected value of coverage attribute"); + } + } + } + + match tcx.opt_local_parent(def_id) { + // Check the parent def (and so on recursively) until we find an + // enclosing attribute or reach the crate root. + Some(parent) => tcx.coverage_attr_on(parent), + // We reached the crate root without seeing a coverage attribute, so + // allow coverage instrumentation by default. + None => true, + } +} + +/// Query implementation for `coverage_ids_info`. +fn coverage_ids_info<'tcx>( + tcx: TyCtxt<'tcx>, + instance_def: ty::InstanceKind<'tcx>, +) -> Option<CoverageIdsInfo> { + let mir_body = tcx.instance_mir(instance_def); + let fn_cov_info = mir_body.function_coverage_info.as_deref()?; + + // Scan through the final MIR to see which BCBs survived MIR opts. + // Any BCB not in this set was optimized away. + let mut bcbs_seen = DenseBitSet::new_empty(fn_cov_info.priority_list.len()); + for kind in all_coverage_in_mir_body(mir_body) { + match *kind { + CoverageKind::VirtualCounter { bcb } => { + bcbs_seen.insert(bcb); + } + _ => {} + } + } + + // Determine the set of BCBs that are referred to by mappings, and therefore + // need a counter. 
Any node not in this set will only get a counter if it + // is part of the counter expression for a node that is in the set. + let mut bcb_needs_counter = + DenseBitSet::<BasicCoverageBlock>::new_empty(fn_cov_info.priority_list.len()); + for mapping in &fn_cov_info.mappings { + match mapping.kind { + MappingKind::Code { bcb } => { + bcb_needs_counter.insert(bcb); + } + MappingKind::Branch { true_bcb, false_bcb } => { + bcb_needs_counter.insert(true_bcb); + bcb_needs_counter.insert(false_bcb); + } + MappingKind::MCDCBranch { true_bcb, false_bcb, mcdc_params: _ } => { + bcb_needs_counter.insert(true_bcb); + bcb_needs_counter.insert(false_bcb); + } + MappingKind::MCDCDecision(_) => {} + } + } + + // Clone the priority list so that we can re-sort it. + let mut priority_list = fn_cov_info.priority_list.clone(); + // The first ID in the priority list represents the synthetic "sink" node, + // and must remain first so that it _never_ gets a physical counter. + debug_assert_eq!(priority_list[0], priority_list.iter().copied().max().unwrap()); + assert!(!bcbs_seen.contains(priority_list[0])); + // Partition the priority list, so that unreachable nodes (removed by MIR opts) + // are sorted later and therefore are _more_ likely to get a physical counter. + // This is counter-intuitive, but it means that `transcribe_counters` can + // easily skip those unused physical counters and replace them with zero. + // (The original ordering remains in effect within both partitions.) + priority_list[1..].sort_by_key(|&bcb| !bcbs_seen.contains(bcb)); + + let node_counters = make_node_counters(&fn_cov_info.node_flow_data, &priority_list); + let coverage_counters = transcribe_counters(&node_counters, &bcb_needs_counter, &bcbs_seen); + + let CoverageCounters { + phys_counter_for_node, next_counter_id, node_counters, expressions, .. + } = coverage_counters; + + Some(CoverageIdsInfo { + num_counters: next_counter_id.as_u32(), + phys_counter_for_node, + term_for_bcb: node_counters, + expressions, + }) +} + +fn all_coverage_in_mir_body<'a, 'tcx>( + body: &'a Body<'tcx>, +) -> impl Iterator<Item = &'a CoverageKind> { + body.basic_blocks.iter().flat_map(|bb_data| &bb_data.statements).filter_map(|statement| { + match statement.kind { + StatementKind::Coverage(ref kind) if !is_inlined(body, statement) => Some(kind), + _ => None, + } + }) +} + +fn is_inlined(body: &Body<'_>, statement: &Statement<'_>) -> bool { + let scope_data = &body.source_scopes[statement.source_info.scope]; + scope_data.inlined.is_some() || scope_data.inlined_parent_scope.is_some() +} diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs new file mode 100644 index 00000000000..ec76076020e --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/spans.rs @@ -0,0 +1,232 @@ +use rustc_data_structures::fx::FxHashSet; +use rustc_middle::mir; +use rustc_middle::ty::TyCtxt; +use rustc_span::{DesugaringKind, ExpnKind, MacroKind, Span}; +use tracing::instrument; + +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph}; +use crate::coverage::spans::from_mir::{Hole, RawSpanFromMir, SpanFromMir}; +use crate::coverage::{ExtractedHirInfo, mappings, unexpand}; + +mod from_mir; + +pub(super) fn extract_refined_covspans<'tcx>( + tcx: TyCtxt<'tcx>, + mir_body: &mir::Body<'tcx>, + hir_info: &ExtractedHirInfo, + graph: &CoverageGraph, + code_mappings: &mut impl Extend<mappings::CodeMapping>, +) { + let &ExtractedHirInfo { body_span, .. 
} = hir_info; + + let raw_spans = from_mir::extract_raw_spans_from_mir(mir_body, graph); + let mut covspans = raw_spans + .into_iter() + .filter_map(|RawSpanFromMir { raw_span, bcb }| try { + let (span, expn_kind) = + unexpand::unexpand_into_body_span_with_expn_kind(raw_span, body_span)?; + // Discard any spans that fill the entire body, because they tend + // to represent compiler-inserted code, e.g. implicitly returning `()`. + if span.source_equal(body_span) { + return None; + }; + SpanFromMir { span, expn_kind, bcb } + }) + .collect::<Vec<_>>(); + + // Only proceed if we found at least one usable span. + if covspans.is_empty() { + return; + } + + // Also add the function signature span, if available. + // Otherwise, add a fake span at the start of the body, to avoid an ugly + // gap between the start of the body and the first real span. + // FIXME: Find a more principled way to solve this problem. + covspans.push(SpanFromMir::for_fn_sig( + hir_info.fn_sig_span.unwrap_or_else(|| body_span.shrink_to_lo()), + )); + + // First, perform the passes that need macro information. + covspans.sort_by(|a, b| graph.cmp_in_dominator_order(a.bcb, b.bcb)); + remove_unwanted_expansion_spans(&mut covspans); + shrink_visible_macro_spans(tcx, &mut covspans); + + // We no longer need the extra information in `SpanFromMir`, so convert to `Covspan`. + let mut covspans = covspans.into_iter().map(SpanFromMir::into_covspan).collect::<Vec<_>>(); + + let compare_covspans = |a: &Covspan, b: &Covspan| { + compare_spans(a.span, b.span) + // After deduplication, we want to keep only the most-dominated BCB. + .then_with(|| graph.cmp_in_dominator_order(a.bcb, b.bcb).reverse()) + }; + covspans.sort_by(compare_covspans); + + // Among covspans with the same span, keep only one, + // preferring the one with the most-dominated BCB. + // (Ideally we should try to preserve _all_ non-dominating BCBs, but that + // requires a lot more complexity in the span refiner, for little benefit.) + covspans.dedup_by(|b, a| a.span.source_equal(b.span)); + + // Sort the holes, and merge overlapping/adjacent holes. + let mut holes = hir_info + .hole_spans + .iter() + .copied() + // Discard any holes that aren't directly visible within the body span. + .filter(|&hole_span| body_span.contains(hole_span) && body_span.eq_ctxt(hole_span)) + .map(|span| Hole { span }) + .collect::<Vec<_>>(); + holes.sort_by(|a, b| compare_spans(a.span, b.span)); + holes.dedup_by(|b, a| a.merge_if_overlapping_or_adjacent(b)); + + // Discard any span that overlaps with a hole. + discard_spans_overlapping_holes(&mut covspans, &holes); + + // Perform more refinement steps after holes have been dealt with. + let mut covspans = remove_unwanted_overlapping_spans(covspans); + covspans.dedup_by(|b, a| a.merge_if_eligible(b)); + + code_mappings.extend(covspans.into_iter().map(|Covspan { span, bcb }| { + // Each span produced by the refiner represents an ordinary code region. + mappings::CodeMapping { span, bcb } + })); +} + +/// Macros that expand into branches (e.g. `assert!`, `trace!`) tend to generate +/// multiple condition/consequent blocks that have the span of the whole macro +/// invocation, which is unhelpful. Keeping only the first such span seems to +/// give better mappings, so remove the others. +/// +/// Similarly, `await` expands to a branch on the discriminant of `Poll`, which +/// leads to incorrect coverage if the `Future` is immediately ready (#98712). 
+/// +/// (The input spans should be sorted in BCB dominator order, so that the +/// retained "first" span is likely to dominate the others.) +fn remove_unwanted_expansion_spans(covspans: &mut Vec<SpanFromMir>) { + let mut deduplicated_spans = FxHashSet::default(); + + covspans.retain(|covspan| { + match covspan.expn_kind { + // Retain only the first await-related or macro-expanded covspan with this span. + Some(ExpnKind::Desugaring(DesugaringKind::Await)) => { + deduplicated_spans.insert(covspan.span) + } + Some(ExpnKind::Macro(MacroKind::Bang, _)) => deduplicated_spans.insert(covspan.span), + // Ignore (retain) other spans. + _ => true, + } + }); +} + +/// When a span corresponds to a macro invocation that is visible from the +/// function body, truncate it to just the macro name plus `!`. +/// This seems to give better results for code that uses macros. +fn shrink_visible_macro_spans(tcx: TyCtxt<'_>, covspans: &mut Vec<SpanFromMir>) { + let source_map = tcx.sess.source_map(); + + for covspan in covspans { + if matches!(covspan.expn_kind, Some(ExpnKind::Macro(MacroKind::Bang, _))) { + covspan.span = source_map.span_through_char(covspan.span, '!'); + } + } +} + +/// Discard all covspans that overlap a hole. +/// +/// The lists of covspans and holes must be sorted, and any holes that overlap +/// with each other must have already been merged. +fn discard_spans_overlapping_holes(covspans: &mut Vec<Covspan>, holes: &[Hole]) { + debug_assert!(covspans.is_sorted_by(|a, b| compare_spans(a.span, b.span).is_le())); + debug_assert!(holes.is_sorted_by(|a, b| compare_spans(a.span, b.span).is_le())); + debug_assert!(holes.array_windows().all(|[a, b]| !a.span.overlaps_or_adjacent(b.span))); + + let mut curr_hole = 0usize; + let mut overlaps_hole = |covspan: &Covspan| -> bool { + while let Some(hole) = holes.get(curr_hole) { + // Both lists are sorted, so we can permanently skip any holes that + // end before the start of the current span. + if hole.span.hi() <= covspan.span.lo() { + curr_hole += 1; + continue; + } + + return hole.span.overlaps(covspan.span); + } + + // No holes left, so this covspan doesn't overlap with any holes. + false + }; + + covspans.retain(|covspan| !overlaps_hole(covspan)); +} + +/// Takes a list of sorted spans extracted from MIR, and "refines" +/// those spans by removing spans that overlap in unwanted ways. +#[instrument(level = "debug")] +fn remove_unwanted_overlapping_spans(sorted_spans: Vec<Covspan>) -> Vec<Covspan> { + debug_assert!(sorted_spans.is_sorted_by(|a, b| compare_spans(a.span, b.span).is_le())); + + // Holds spans that have been read from the input vector, but haven't yet + // been committed to the output vector. + let mut pending = vec![]; + let mut refined = vec![]; + + for curr in sorted_spans { + pending.retain(|prev: &Covspan| { + if prev.span.hi() <= curr.span.lo() { + // There's no overlap between the previous/current covspans, + // so move the previous one into the refined list. + refined.push(prev.clone()); + false + } else { + // Otherwise, retain the previous covspan only if it has the + // same BCB. This tends to discard long outer spans that enclose + // smaller inner spans with different control flow. + prev.bcb == curr.bcb + } + }); + pending.push(curr); + } + + // Drain the rest of the pending list into the refined list. 
+ refined.extend(pending); + refined +} + +#[derive(Clone, Debug)] +struct Covspan { + span: Span, + bcb: BasicCoverageBlock, +} + +impl Covspan { + /// If `self` and `other` can be merged, mutates `self.span` to also + /// include `other.span` and returns true. + /// + /// Two covspans can be merged if they have the same BCB, and they are + /// overlapping or adjacent. + fn merge_if_eligible(&mut self, other: &Self) -> bool { + let eligible_for_merge = + |a: &Self, b: &Self| (a.bcb == b.bcb) && a.span.overlaps_or_adjacent(b.span); + + if eligible_for_merge(self, other) { + self.span = self.span.to(other.span); + true + } else { + false + } + } +} + +/// Compares two spans in (lo ascending, hi descending) order. +fn compare_spans(a: Span, b: Span) -> std::cmp::Ordering { + // First sort by span start. + Ord::cmp(&a.lo(), &b.lo()) + // If span starts are the same, sort by span end in reverse order. + // This ensures that if spans A and B are adjacent in the list, + // and they overlap but are not equal, then either: + // - Span A extends further left, or + // - Both have the same start and span A extends further right + .then_with(|| Ord::cmp(&a.hi(), &b.hi()).reverse()) +} diff --git a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs new file mode 100644 index 00000000000..804cd8ab3f7 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs @@ -0,0 +1,195 @@ +use std::iter; + +use rustc_middle::bug; +use rustc_middle::mir::coverage::CoverageKind; +use rustc_middle::mir::{ + self, FakeReadCause, Statement, StatementKind, Terminator, TerminatorKind, +}; +use rustc_span::{ExpnKind, Span}; + +use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, START_BCB}; +use crate::coverage::spans::Covspan; + +#[derive(Debug)] +pub(crate) struct RawSpanFromMir { + /// A span that has been extracted from a MIR statement/terminator, but + /// hasn't been "unexpanded", so it might not lie within the function body + /// span and might be part of an expansion with a different context. + pub(crate) raw_span: Span, + pub(crate) bcb: BasicCoverageBlock, +} + +/// Generates an initial set of coverage spans from the statements and +/// terminators in the function's MIR body, each associated with its +/// corresponding node in the coverage graph. +/// +/// This is necessarily an inexact process, because MIR isn't designed to +/// capture source spans at the level of detail we would want for coverage, +/// but it's good enough to be better than nothing. +pub(crate) fn extract_raw_spans_from_mir<'tcx>( + mir_body: &mir::Body<'tcx>, + graph: &CoverageGraph, +) -> Vec<RawSpanFromMir> { + let mut raw_spans = vec![]; + + // We only care about blocks that are part of the coverage graph. + for (bcb, bcb_data) in graph.iter_enumerated() { + let make_raw_span = |raw_span: Span| RawSpanFromMir { raw_span, bcb }; + + // A coverage graph node can consist of multiple basic blocks. + for &bb in &bcb_data.basic_blocks { + let bb_data = &mir_body[bb]; + + let statements = bb_data.statements.iter(); + raw_spans.extend(statements.filter_map(filtered_statement_span).map(make_raw_span)); + + // There's only one terminator, but wrap it in an iterator to + // mirror the handling of statements. 
+ let terminator = iter::once(bb_data.terminator()); + raw_spans.extend(terminator.filter_map(filtered_terminator_span).map(make_raw_span)); + } + } + + raw_spans +} + +/// If the MIR `Statement` has a span contributive to computing coverage spans, +/// return it; otherwise return `None`. +fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> { + match statement.kind { + // These statements have spans that are often outside the scope of the executed source code + // for their parent `BasicBlock`. + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::ConstEvalCounter + | StatementKind::BackwardIncompatibleDropHint { .. } + | StatementKind::Nop => None, + + // FIXME(#78546): MIR InstrumentCoverage - Can the source_info.span for `FakeRead` + // statements be more consistent? + // + // FakeReadCause::ForGuardBinding, in this example: + // match somenum { + // x if x < 1 => { ... } + // }... + // The BasicBlock within the match arm code included one of these statements, but the span + // for it covered the `1` in this source. The actual statements have nothing to do with that + // source span: + // FakeRead(ForGuardBinding, _4); + // where `_4` is: + // _4 = &_1; (at the span for the first `x`) + // and `_1` is the `Place` for `somenum`. + // + // If and when the Issue is resolved, remove this special case match pattern: + StatementKind::FakeRead(box (FakeReadCause::ForGuardBinding, _)) => None, + + // Retain spans from most other statements. + StatementKind::FakeRead(_) + | StatementKind::Intrinsic(..) + | StatementKind::Coverage( + // The purpose of `SpanMarker` is to be matched and accepted here. + CoverageKind::SpanMarker, + ) + | StatementKind::Assign(_) + | StatementKind::SetDiscriminant { .. } + | StatementKind::Deinit(..) + | StatementKind::Retag(_, _) + | StatementKind::PlaceMention(..) + | StatementKind::AscribeUserType(_, _) => Some(statement.source_info.span), + + // Block markers are used for branch coverage, so ignore them here. + StatementKind::Coverage(CoverageKind::BlockMarker { .. }) => None, + + // These coverage statements should not exist prior to coverage instrumentation. + StatementKind::Coverage( + CoverageKind::VirtualCounter { .. } + | CoverageKind::CondBitmapUpdate { .. } + | CoverageKind::TestVectorBitmapUpdate { .. }, + ) => bug!( + "Unexpected coverage statement found during coverage instrumentation: {statement:?}" + ), + } +} + +/// If the MIR `Terminator` has a span contributive to computing coverage spans, +/// return it; otherwise return `None`. +fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Span> { + match terminator.kind { + // These terminators have spans that don't positively contribute to computing a reasonable + // span of actually executed source code. (For example, SwitchInt terminators extracted from + // an `if condition { block }` has a span that includes the executed block, if true, + // but for coverage, the code region executed, up to *and* through the SwitchInt, + // actually stops before the if's block.) + TerminatorKind::Unreachable + | TerminatorKind::Assert { .. } + | TerminatorKind::Drop { .. } + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::Goto { .. } => None, + + // Call `func` operand can have a more specific span when part of a chain of calls + TerminatorKind::Call { ref func, .. } | TerminatorKind::TailCall { ref func, .. 
} => { + let mut span = terminator.source_info.span; + if let mir::Operand::Constant(constant) = func + && span.contains(constant.span) + { + span = constant.span; + } + Some(span) + } + + // Retain spans from all other terminators + TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::Yield { .. } + | TerminatorKind::CoroutineDrop + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::InlineAsm { .. } => Some(terminator.source_info.span), + } +} + +#[derive(Debug)] +pub(crate) struct Hole { + pub(crate) span: Span, +} + +impl Hole { + pub(crate) fn merge_if_overlapping_or_adjacent(&mut self, other: &mut Self) -> bool { + if !self.span.overlaps_or_adjacent(other.span) { + return false; + } + + self.span = self.span.to(other.span); + true + } +} + +#[derive(Debug)] +pub(crate) struct SpanFromMir { + /// A span that has been extracted from MIR and then "un-expanded" back to + /// within the current function's `body_span`. After various intermediate + /// processing steps, this span is emitted as part of the final coverage + /// mappings. + /// + /// With the exception of `fn_sig_span`, this should always be contained + /// within `body_span`. + pub(crate) span: Span, + pub(crate) expn_kind: Option<ExpnKind>, + pub(crate) bcb: BasicCoverageBlock, +} + +impl SpanFromMir { + pub(crate) fn for_fn_sig(fn_sig_span: Span) -> Self { + Self::new(fn_sig_span, None, START_BCB) + } + + pub(crate) fn new(span: Span, expn_kind: Option<ExpnKind>, bcb: BasicCoverageBlock) -> Self { + Self { span, expn_kind, bcb } + } + + pub(crate) fn into_covspan(self) -> Covspan { + let Self { span, expn_kind: _, bcb } = self; + Covspan { span, bcb } + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/tests.rs b/compiler/rustc_mir_transform/src/coverage/tests.rs new file mode 100644 index 00000000000..3c0053c610d --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/tests.rs @@ -0,0 +1,528 @@ +//! This crate hosts a selection of "unit tests" for components of the `InstrumentCoverage` MIR +//! pass. +//! +//! ```shell +//! ./x.py test --keep-stage 1 compiler/rustc_mir --test-args '--show-output coverage' +//! ``` +//! +//! The tests construct a few "mock" objects, as needed, to support the `InstrumentCoverage` +//! functions and algorithms. Mocked objects include instances of `mir::Body`; including +//! `Terminator`s of various `kind`s, and `Span` objects. Some functions used by or used on +//! real, runtime versions of these mocked-up objects have constraints (such as cross-thread +//! limitations) and deep dependencies on other elements of the full Rust compiler (which is +//! *not* constructed or mocked for these tests). +//! +//! Of particular note, attempting to simply print elements of the `mir::Body` with default +//! `Debug` formatting can fail because some `Debug` format implementations require the +//! `TyCtxt`, obtained via a static global variable that is *not* set for these tests. +//! Initializing the global type context is prohibitively complex for the scope and scale of these +//! tests (essentially requiring initializing the entire compiler). +//! +//! Also note, some basic features of `Span` also rely on the `Span`s own "session globals", which +//! are unrelated to the `TyCtxt` global. Without initializing the `Span` session globals, some +//! basic, coverage-specific features would be impossible to test, but thankfully initializing these +//! globals is comparatively simpler. 
The easiest way is to wrap the test in a closure argument +//! to: `rustc_span::create_default_session_globals_then(|| { test_here(); })`. + +use itertools::Itertools; +use rustc_data_structures::graph::{DirectedGraph, Successors}; +use rustc_index::{Idx, IndexVec}; +use rustc_middle::mir::*; +use rustc_middle::{bug, ty}; +use rustc_span::{BytePos, DUMMY_SP, Pos, Span}; + +use super::graph::{self, BasicCoverageBlock}; + +fn bcb(index: u32) -> BasicCoverageBlock { + BasicCoverageBlock::from_u32(index) +} + +// All `TEMP_BLOCK` targets should be replaced before calling `to_body() -> mir::Body`. +const TEMP_BLOCK: BasicBlock = BasicBlock::MAX; + +struct MockBlocks<'tcx> { + blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>>, + dummy_place: Place<'tcx>, + next_local: usize, +} + +impl<'tcx> MockBlocks<'tcx> { + fn new() -> Self { + Self { + blocks: IndexVec::new(), + dummy_place: Place { local: RETURN_PLACE, projection: ty::List::empty() }, + next_local: 0, + } + } + + fn new_temp(&mut self) -> Local { + let index = self.next_local; + self.next_local += 1; + Local::new(index) + } + + fn push(&mut self, kind: TerminatorKind<'tcx>) -> BasicBlock { + let next_lo = if let Some(last) = self.blocks.last_index() { + self.blocks[last].terminator().source_info.span.hi() + } else { + BytePos(1) + }; + let next_hi = next_lo + BytePos(1); + self.blocks.push(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + source_info: SourceInfo::outermost(Span::with_root_ctxt(next_lo, next_hi)), + kind, + }), + is_cleanup: false, + }) + } + + fn link(&mut self, from_block: BasicBlock, to_block: BasicBlock) { + match self.blocks[from_block].terminator_mut().kind { + TerminatorKind::Assert { ref mut target, .. } + | TerminatorKind::Call { target: Some(ref mut target), .. } + | TerminatorKind::Drop { ref mut target, .. } + | TerminatorKind::FalseEdge { real_target: ref mut target, .. } + | TerminatorKind::FalseUnwind { real_target: ref mut target, .. } + | TerminatorKind::Goto { ref mut target } + | TerminatorKind::Yield { resume: ref mut target, .. } => *target = to_block, + ref invalid => bug!("Invalid from_block: {:?}", invalid), + } + } + + fn add_block_from( + &mut self, + some_from_block: Option<BasicBlock>, + to_kind: TerminatorKind<'tcx>, + ) -> BasicBlock { + let new_block = self.push(to_kind); + if let Some(from_block) = some_from_block { + self.link(from_block, new_block); + } + new_block + } + + fn set_branch(&mut self, switchint: BasicBlock, branch_index: usize, to_block: BasicBlock) { + match self.blocks[switchint].terminator_mut().kind { + TerminatorKind::SwitchInt { ref mut targets, .. 
} => { + let mut branches = targets.iter().collect::<Vec<_>>(); + let otherwise = if branch_index == branches.len() { + to_block + } else { + let old_otherwise = targets.otherwise(); + if branch_index > branches.len() { + branches.push((branches.len() as u128, old_otherwise)); + while branches.len() < branch_index { + branches.push((branches.len() as u128, TEMP_BLOCK)); + } + to_block + } else { + branches[branch_index] = (branch_index as u128, to_block); + old_otherwise + } + }; + *targets = SwitchTargets::new(branches.into_iter(), otherwise); + } + ref invalid => bug!("Invalid BasicBlock kind or no to_block: {:?}", invalid), + } + } + + fn call(&mut self, some_from_block: Option<BasicBlock>) -> BasicBlock { + self.add_block_from( + some_from_block, + TerminatorKind::Call { + func: Operand::Copy(self.dummy_place.clone()), + args: [].into(), + destination: self.dummy_place.clone(), + target: Some(TEMP_BLOCK), + unwind: UnwindAction::Continue, + call_source: CallSource::Misc, + fn_span: DUMMY_SP, + }, + ) + } + + fn goto(&mut self, some_from_block: Option<BasicBlock>) -> BasicBlock { + self.add_block_from(some_from_block, TerminatorKind::Goto { target: TEMP_BLOCK }) + } + + fn switchint(&mut self, some_from_block: Option<BasicBlock>) -> BasicBlock { + let switchint_kind = TerminatorKind::SwitchInt { + discr: Operand::Move(Place::from(self.new_temp())), + targets: SwitchTargets::static_if(0, TEMP_BLOCK, TEMP_BLOCK), + }; + self.add_block_from(some_from_block, switchint_kind) + } + + fn return_(&mut self, some_from_block: Option<BasicBlock>) -> BasicBlock { + self.add_block_from(some_from_block, TerminatorKind::Return) + } + + fn to_body(self) -> Body<'tcx> { + Body::new_cfg_only(self.blocks) + } +} + +fn debug_basic_blocks(mir_body: &Body<'_>) -> String { + format!( + "{:?}", + mir_body + .basic_blocks + .iter_enumerated() + .map(|(bb, data)| { + let term = &data.terminator(); + let kind = &term.kind; + let span = term.source_info.span; + let sp = format!("(span:{},{})", span.lo().to_u32(), span.hi().to_u32()); + match kind { + TerminatorKind::Assert { target, .. } + | TerminatorKind::Call { target: Some(target), .. } + | TerminatorKind::Drop { target, .. } + | TerminatorKind::FalseEdge { real_target: target, .. } + | TerminatorKind::FalseUnwind { real_target: target, .. } + | TerminatorKind::Goto { target } + | TerminatorKind::Yield { resume: target, .. } => { + format!("{}{:?}:{} -> {:?}", sp, bb, kind.name(), target) + } + TerminatorKind::InlineAsm { targets, .. } => { + format!("{}{:?}:{} -> {:?}", sp, bb, kind.name(), targets) + } + TerminatorKind::SwitchInt { targets, .. 
} => { + format!("{}{:?}:{} -> {:?}", sp, bb, kind.name(), targets) + } + _ => format!("{}{:?}:{}", sp, bb, kind.name()), + } + }) + .collect::<Vec<_>>() + ) +} + +static PRINT_GRAPHS: bool = false; + +fn print_mir_graphviz(name: &str, mir_body: &Body<'_>) { + if PRINT_GRAPHS { + println!( + "digraph {} {{\n{}\n}}", + name, + mir_body + .basic_blocks + .iter_enumerated() + .map(|(bb, data)| { + format!( + " {:?} [label=\"{:?}: {}\"];\n{}", + bb, + bb, + data.terminator().kind.name(), + mir_body + .basic_blocks + .successors(bb) + .map(|successor| { format!(" {:?} -> {:?};", bb, successor) }) + .join("\n") + ) + }) + .join("\n") + ); + } +} + +fn print_coverage_graphviz(name: &str, mir_body: &Body<'_>, graph: &graph::CoverageGraph) { + if PRINT_GRAPHS { + println!( + "digraph {} {{\n{}\n}}", + name, + graph + .iter_enumerated() + .map(|(bcb, bcb_data)| { + format!( + " {:?} [label=\"{:?}: {}\"];\n{}", + bcb, + bcb, + mir_body[bcb_data.last_bb()].terminator().kind.name(), + graph + .successors(bcb) + .map(|successor| { format!(" {:?} -> {:?};", bcb, successor) }) + .join("\n") + ) + }) + .join("\n") + ); + } +} + +/// Create a mock `Body` with a simple flow. +fn goto_switchint<'a>() -> Body<'a> { + let mut blocks = MockBlocks::new(); + let start = blocks.call(None); + let goto = blocks.goto(Some(start)); + let switchint = blocks.switchint(Some(goto)); + let then_call = blocks.call(None); + let else_call = blocks.call(None); + blocks.set_branch(switchint, 0, then_call); + blocks.set_branch(switchint, 1, else_call); + blocks.return_(Some(then_call)); + blocks.return_(Some(else_call)); + + let mir_body = blocks.to_body(); + print_mir_graphviz("mir_goto_switchint", &mir_body); + /* Graphviz character plots created using: `graph-easy --as=boxart`: + ┌────────────────┐ + │ bb0: Call │ + └────────────────┘ + │ + │ + ▼ + ┌────────────────┐ + │ bb1: Goto │ + └────────────────┘ + │ + │ + ▼ + ┌─────────────┐ ┌────────────────┐ + │ bb4: Call │ ◀── │ bb2: SwitchInt │ + └─────────────┘ └────────────────┘ + │ │ + │ │ + ▼ ▼ + ┌─────────────┐ ┌────────────────┐ + │ bb6: Return │ │ bb3: Call │ + └─────────────┘ └────────────────┘ + │ + │ + ▼ + ┌────────────────┐ + │ bb5: Return │ + └────────────────┘ + */ + mir_body +} + +#[track_caller] +fn assert_successors( + graph: &graph::CoverageGraph, + bcb: BasicCoverageBlock, + expected_successors: &[BasicCoverageBlock], +) { + let mut successors = graph.successors[bcb].clone(); + successors.sort_unstable(); + assert_eq!(successors, expected_successors); +} + +#[test] +fn test_covgraph_goto_switchint() { + let mir_body = goto_switchint(); + if false { + eprintln!("basic_blocks = {}", debug_basic_blocks(&mir_body)); + } + let graph = graph::CoverageGraph::from_mir(&mir_body); + print_coverage_graphviz("covgraph_goto_switchint ", &mir_body, &graph); + /* + ┌──────────────┐ ┌─────────────────┐ + │ bcb2: Return │ ◀── │ bcb0: SwitchInt │ + └──────────────┘ └─────────────────┘ + │ + │ + ▼ + ┌─────────────────┐ + │ bcb1: Return │ + └─────────────────┘ + */ + assert_eq!(graph.num_nodes(), 3, "graph: {:?}", graph.iter_enumerated().collect::<Vec<_>>()); + + assert_successors(&graph, bcb(0), &[bcb(1), bcb(2)]); + assert_successors(&graph, bcb(1), &[]); + assert_successors(&graph, bcb(2), &[]); +} + +/// Create a mock `Body` with a loop. 
+fn switchint_then_loop_else_return<'a>() -> Body<'a> { + let mut blocks = MockBlocks::new(); + let start = blocks.call(None); + let switchint = blocks.switchint(Some(start)); + let then_call = blocks.call(None); + blocks.set_branch(switchint, 0, then_call); + let backedge_goto = blocks.goto(Some(then_call)); + blocks.link(backedge_goto, switchint); + let else_return = blocks.return_(None); + blocks.set_branch(switchint, 1, else_return); + + let mir_body = blocks.to_body(); + print_mir_graphviz("mir_switchint_then_loop_else_return", &mir_body); + /* + ┌────────────────┐ + │ bb0: Call │ + └────────────────┘ + │ + │ + ▼ + ┌─────────────┐ ┌────────────────┐ + │ bb4: Return │ ◀── │ bb1: SwitchInt │ ◀┐ + └─────────────┘ └────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌────────────────┐ │ + │ bb2: Call │ │ + └────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌────────────────┐ │ + │ bb3: Goto │ ─┘ + └────────────────┘ + */ + mir_body +} + +#[test] +fn test_covgraph_switchint_then_loop_else_return() { + let mir_body = switchint_then_loop_else_return(); + let graph = graph::CoverageGraph::from_mir(&mir_body); + print_coverage_graphviz("covgraph_switchint_then_loop_else_return", &mir_body, &graph); + /* + ┌─────────────────┐ + │ bcb0: Call │ + └─────────────────┘ + │ + │ + ▼ + ┌────────────┐ ┌─────────────────┐ + │ bcb3: Goto │ ◀── │ bcb1: SwitchInt │ ◀┐ + └────────────┘ └─────────────────┘ │ + │ │ │ + │ │ │ + │ ▼ │ + │ ┌─────────────────┐ │ + │ │ bcb2: Return │ │ + │ └─────────────────┘ │ + │ │ + └─────────────────────────────────────┘ + */ + assert_eq!(graph.num_nodes(), 4, "graph: {:?}", graph.iter_enumerated().collect::<Vec<_>>()); + + assert_successors(&graph, bcb(0), &[bcb(1)]); + assert_successors(&graph, bcb(1), &[bcb(2), bcb(3)]); + assert_successors(&graph, bcb(2), &[]); + assert_successors(&graph, bcb(3), &[bcb(1)]); +} + +/// Create a mock `Body` with nested loops. 
+fn switchint_loop_then_inner_loop_else_break<'a>() -> Body<'a> { + let mut blocks = MockBlocks::new(); + let start = blocks.call(None); + let switchint = blocks.switchint(Some(start)); + let then_call = blocks.call(None); + blocks.set_branch(switchint, 0, then_call); + let else_return = blocks.return_(None); + blocks.set_branch(switchint, 1, else_return); + + let inner_start = blocks.call(Some(then_call)); + let inner_switchint = blocks.switchint(Some(inner_start)); + let inner_then_call = blocks.call(None); + blocks.set_branch(inner_switchint, 0, inner_then_call); + let inner_backedge_goto = blocks.goto(Some(inner_then_call)); + blocks.link(inner_backedge_goto, inner_switchint); + let inner_else_break_goto = blocks.goto(None); + blocks.set_branch(inner_switchint, 1, inner_else_break_goto); + + let backedge_goto = blocks.goto(Some(inner_else_break_goto)); + blocks.link(backedge_goto, switchint); + + let mir_body = blocks.to_body(); + print_mir_graphviz("mir_switchint_loop_then_inner_loop_else_break", &mir_body); + /* + ┌────────────────┐ + │ bb0: Call │ + └────────────────┘ + │ + │ + ▼ + ┌─────────────┐ ┌────────────────┐ + │ bb3: Return │ ◀── │ bb1: SwitchInt │ ◀─────┐ + └─────────────┘ └────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌────────────────┐ │ + │ bb2: Call │ │ + └────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌────────────────┐ │ + │ bb4: Call │ │ + └────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌─────────────┐ ┌────────────────┐ │ + │ bb8: Goto │ ◀── │ bb5: SwitchInt │ ◀┐ │ + └─────────────┘ └────────────────┘ │ │ + │ │ │ │ + │ │ │ │ + ▼ ▼ │ │ + ┌─────────────┐ ┌────────────────┐ │ │ + │ bb9: Goto │ ─┐ │ bb6: Call │ │ │ + └─────────────┘ │ └────────────────┘ │ │ + │ │ │ │ + │ │ │ │ + │ ▼ │ │ + │ ┌────────────────┐ │ │ + │ │ bb7: Goto │ ─┘ │ + │ └────────────────┘ │ + │ │ + └───────────────────────────┘ + */ + mir_body +} + +#[test] +fn test_covgraph_switchint_loop_then_inner_loop_else_break() { + let mir_body = switchint_loop_then_inner_loop_else_break(); + let graph = graph::CoverageGraph::from_mir(&mir_body); + print_coverage_graphviz( + "covgraph_switchint_loop_then_inner_loop_else_break", + &mir_body, + &graph, + ); + /* + ┌─────────────────┐ + │ bcb0: Call │ + └─────────────────┘ + │ + │ + ▼ + ┌──────────────┐ ┌─────────────────┐ + │ bcb2: Return │ ◀── │ bcb1: SwitchInt │ ◀┐ + └──────────────┘ └─────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌─────────────────┐ │ + │ bcb3: Call │ │ + └─────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌──────────────┐ ┌─────────────────┐ │ + │ bcb6: Goto │ ◀── │ bcb4: SwitchInt │ ◀┼────┐ + └──────────────┘ └─────────────────┘ │ │ + │ │ │ │ + │ │ │ │ + │ ▼ │ │ + │ ┌─────────────────┐ │ │ + │ │ bcb5: Goto │ ─┘ │ + │ └─────────────────┘ │ + │ │ + └────────────────────────────────────────────┘ + */ + assert_eq!(graph.num_nodes(), 7, "graph: {:?}", graph.iter_enumerated().collect::<Vec<_>>()); + + assert_successors(&graph, bcb(0), &[bcb(1)]); + assert_successors(&graph, bcb(1), &[bcb(2), bcb(3)]); + assert_successors(&graph, bcb(2), &[]); + assert_successors(&graph, bcb(3), &[bcb(4)]); + assert_successors(&graph, bcb(4), &[bcb(5), bcb(6)]); + assert_successors(&graph, bcb(5), &[bcb(1)]); + assert_successors(&graph, bcb(6), &[bcb(4)]); +} diff --git a/compiler/rustc_mir_transform/src/coverage/unexpand.rs b/compiler/rustc_mir_transform/src/coverage/unexpand.rs new file mode 100644 index 00000000000..cb861544736 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/unexpand.rs @@ -0,0 +1,55 @@ +use rustc_span::{ExpnKind, Span}; + +/// Walks through the expansion ancestors 
of `original_span` to find a span that +/// is contained in `body_span` and has the same [syntax context] as `body_span`. +pub(crate) fn unexpand_into_body_span(original_span: Span, body_span: Span) -> Option<Span> { + // Because we don't need to return any extra ancestor information, + // we can just delegate directly to `find_ancestor_inside_same_ctxt`. + original_span.find_ancestor_inside_same_ctxt(body_span) +} + +/// Walks through the expansion ancestors of `original_span` to find a span that +/// is contained in `body_span` and has the same [syntax context] as `body_span`. +/// +/// If the returned span represents a bang-macro invocation (e.g. `foo!(..)`), +/// the returned symbol will be the name of that macro (e.g. `foo`). +pub(crate) fn unexpand_into_body_span_with_expn_kind( + original_span: Span, + body_span: Span, +) -> Option<(Span, Option<ExpnKind>)> { + let (span, prev) = unexpand_into_body_span_with_prev(original_span, body_span)?; + + let expn_kind = prev.map(|prev| prev.ctxt().outer_expn_data().kind); + + Some((span, expn_kind)) +} + +/// Walks through the expansion ancestors of `original_span` to find a span that +/// is contained in `body_span` and has the same [syntax context] as `body_span`. +/// The ancestor that was traversed just before the matching span (if any) is +/// also returned. +/// +/// For example, a return value of `Some((ancestor, Some(prev)))` means that: +/// - `ancestor == original_span.find_ancestor_inside_same_ctxt(body_span)` +/// - `prev.parent_callsite() == ancestor` +/// +/// [syntax context]: rustc_span::SyntaxContext +fn unexpand_into_body_span_with_prev( + original_span: Span, + body_span: Span, +) -> Option<(Span, Option<Span>)> { + let mut prev = None; + let mut curr = original_span; + + while !body_span.contains(curr) || !curr.eq_ctxt(body_span) { + prev = Some(curr); + curr = curr.parent_callsite()?; + } + + debug_assert_eq!(Some(curr), original_span.find_ancestor_inside_same_ctxt(body_span)); + if let Some(prev) = prev { + debug_assert_eq!(Some(curr), prev.parent_callsite()); + } + + Some((curr, prev)) +} diff --git a/compiler/rustc_mir_transform/src/cross_crate_inline.rs b/compiler/rustc_mir_transform/src/cross_crate_inline.rs new file mode 100644 index 00000000000..6d7b7e10ef6 --- /dev/null +++ b/compiler/rustc_mir_transform/src/cross_crate_inline.rs @@ -0,0 +1,161 @@ +use rustc_attr_data_structures::InlineAttr; +use rustc_hir::def::DefKind; +use rustc_hir::def_id::LocalDefId; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; +use rustc_middle::query::Providers; +use rustc_middle::ty::TyCtxt; +use rustc_session::config::{InliningThreshold, OptLevel}; +use rustc_span::sym; + +use crate::{inline, pass_manager as pm}; + +pub(super) fn provide(providers: &mut Providers) { + providers.cross_crate_inlinable = cross_crate_inlinable; +} + +fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { + let codegen_fn_attrs = tcx.codegen_fn_attrs(def_id); + // If this has an extern indicator, then this function is globally shared and thus will not + // generate cgu-internal copies which would make it cross-crate inlinable. + if codegen_fn_attrs.contains_extern_indicator() { + return false; + } + + // This just reproduces the logic from Instance::requires_inline. + match tcx.def_kind(def_id) { + DefKind::Ctor(..) | DefKind::Closure | DefKind::SyntheticCoroutineBody => return true, + DefKind::Fn | DefKind::AssocFn => {} + _ => return false, + } + + // From this point on, it is valid to return true or false. 
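+ // (Everything below is a best-effort heuristic: answering `true` makes the body
+ // available for inlining into other crates; it does not guarantee it will be inlined.)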
+ if tcx.sess.opts.unstable_opts.cross_crate_inline_threshold == InliningThreshold::Always { + return true; + } + + if tcx.has_attr(def_id, sym::rustc_intrinsic) { + // Intrinsic fallback bodies are always cross-crate inlineable. + // To ensure that the MIR inliner doesn't cluelessly try to inline fallback + // bodies even when the backend would implement something better, we stop + // the MIR inliner from ever inlining an intrinsic. + return true; + } + + // Obey source annotations first; this is important because it means we can use + // #[inline(never)] to force code generation. + match codegen_fn_attrs.inline { + InlineAttr::Never => return false, + InlineAttr::Hint | InlineAttr::Always | InlineAttr::Force { .. } => return true, + _ => {} + } + + // If the crate is likely to be mostly unused, use cross-crate inlining to defer codegen until + // the function is referenced, in order to skip codegen for unused functions. This is + // intentionally after the check for `inline(never)`, so that `inline(never)` wins. + if tcx.sess.opts.unstable_opts.hint_mostly_unused { + return true; + } + + let sig = tcx.fn_sig(def_id).instantiate_identity(); + for ty in sig.inputs().skip_binder().iter().chain(std::iter::once(&sig.output().skip_binder())) + { + // FIXME(f16_f128): in order to avoid crashes building `core`, always inline to skip + // codegen if the function is not used. + if ty == &tcx.types.f16 || ty == &tcx.types.f128 { + return true; + } + } + + // Don't do any inference when incremental compilation is enabled; the additional inlining that + // inference permits also creates more work for small edits. + if tcx.sess.opts.incremental.is_some() { + return false; + } + + // Don't do any inference if codegen optimizations are disabled and also MIR inlining is not + // enabled. This ensures that we do inference even if someone only passes -Zinline-mir, + // which is less confusing than having to also enable -Copt-level=1. + let inliner_will_run = pm::should_run_pass(tcx, &inline::Inline, pm::Optimizations::Allowed) + || inline::ForceInline::should_run_pass_for_callee(tcx, def_id.to_def_id()); + if matches!(tcx.sess.opts.optimize, OptLevel::No) && !inliner_will_run { + return false; + } + + if !tcx.is_mir_available(def_id) { + return false; + } + + let threshold = match tcx.sess.opts.unstable_opts.cross_crate_inline_threshold { + InliningThreshold::Always => return true, + InliningThreshold::Sometimes(threshold) => threshold, + InliningThreshold::Never => return false, + }; + + let mir = tcx.optimized_mir(def_id); + let mut checker = + CostChecker { tcx, callee_body: mir, calls: 0, statements: 0, landing_pads: 0, resumes: 0 }; + checker.visit_body(mir); + checker.calls == 0 + && checker.resumes == 0 + && checker.landing_pads == 0 + && checker.statements <= threshold +} + +struct CostChecker<'b, 'tcx> { + tcx: TyCtxt<'tcx>, + callee_body: &'b Body<'tcx>, + calls: usize, + statements: usize, + landing_pads: usize, + resumes: usize, +} + +impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { + fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) { + // Don't count StorageLive/StorageDead in the inlining cost. + match statement.kind { + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Deinit(_) + | StatementKind::Nop => {} + _ => self.statements += 1, + } + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) { + let tcx = self.tcx; + match terminator.kind { + TerminatorKind::Drop { ref place, unwind, .. 
} => { + let ty = place.ty(self.callee_body, tcx).ty; + if !ty.is_trivially_pure_clone_copy() { + self.calls += 1; + if let UnwindAction::Cleanup(_) = unwind { + self.landing_pads += 1; + } + } + } + TerminatorKind::Call { unwind, .. } => { + self.calls += 1; + if let UnwindAction::Cleanup(_) = unwind { + self.landing_pads += 1; + } + } + TerminatorKind::Assert { unwind, .. } => { + self.calls += 1; + if let UnwindAction::Cleanup(_) = unwind { + self.landing_pads += 1; + } + } + TerminatorKind::UnwindResume => self.resumes += 1, + TerminatorKind::InlineAsm { unwind, .. } => { + self.statements += 1; + if let UnwindAction::Cleanup(_) = unwind { + self.landing_pads += 1; + } + } + TerminatorKind::Return => {} + _ => self.statements += 1, + } + } +} diff --git a/compiler/rustc_mir_transform/src/ctfe_limit.rs b/compiler/rustc_mir_transform/src/ctfe_limit.rs new file mode 100644 index 00000000000..d0b313e149a --- /dev/null +++ b/compiler/rustc_mir_transform/src/ctfe_limit.rs @@ -0,0 +1,62 @@ +//! A pass that inserts the `ConstEvalCounter` instruction into any blocks that have a back edge +//! (thus indicating there is a loop in the CFG), or whose terminator is a function call. + +use rustc_data_structures::graph::dominators::Dominators; +use rustc_middle::mir::{ + BasicBlock, BasicBlockData, Body, Statement, StatementKind, TerminatorKind, +}; +use rustc_middle::ty::TyCtxt; +use tracing::instrument; + +pub(super) struct CtfeLimit; + +impl<'tcx> crate::MirPass<'tcx> for CtfeLimit { + #[instrument(skip(self, _tcx, body))] + fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let doms = body.basic_blocks.dominators(); + let indices: Vec<BasicBlock> = body + .basic_blocks + .iter_enumerated() + .filter_map(|(node, node_data)| { + if matches!(node_data.terminator().kind, TerminatorKind::Call { .. }) + // Back edges in a CFG indicate loops + || has_back_edge(doms, node, node_data) + { + Some(node) + } else { + None + } + }) + .collect(); + for index in indices { + insert_counter( + body.basic_blocks_mut() + .get_mut(index) + .expect("basic_blocks index {index} should exist"), + ); + } + } + + fn is_required(&self) -> bool { + true + } +} + +fn has_back_edge( + doms: &Dominators<BasicBlock>, + node: BasicBlock, + node_data: &BasicBlockData<'_>, +) -> bool { + if !doms.is_reachable(node) { + return false; + } + // Check if any of the dominators of the node are also the node's successor. + node_data.terminator().successors().any(|succ| doms.dominates(succ, node)) +} + +fn insert_counter(basic_block_data: &mut BasicBlockData<'_>) { + basic_block_data.statements.push(Statement { + source_info: basic_block_data.terminator().source_info, + kind: StatementKind::ConstEvalCounter, + }); +} diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs new file mode 100644 index 00000000000..fe53de31f75 --- /dev/null +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -0,0 +1,1103 @@ +//! A constant propagation optimization pass based on dataflow analysis. +//! +//! Currently, this pass only propagates scalar values. 
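To make the lattice that drives this pass concrete, here is a minimal, self-contained sketch of a flat-lattice join. The pass itself uses `FlatSet` from `rustc_mir_dataflow::lattice` via `State<FlatSet<Scalar>>`; the type below is a simplified stand-in for illustration only, not the compiler's implementation.

```rust
/// Simplified stand-in for the flat lattice used by the analysis; illustration only.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Flat<T> {
    /// Nothing known yet (uninitialized / unreachable).
    Bottom,
    /// Exactly one possible constant value.
    Elem(T),
    /// More than one possible value: give up on this place.
    Top,
}

impl<T: PartialEq> Flat<T> {
    /// Join of two incoming states at a control-flow merge point.
    fn join(self, other: Self) -> Self {
        match (self, other) {
            (Flat::Bottom, x) | (x, Flat::Bottom) => x,
            (Flat::Elem(a), Flat::Elem(b)) if a == b => Flat::Elem(a),
            _ => Flat::Top,
        }
    }
}

fn main() {
    // Agreeing assignments on both branches keep the place constant...
    assert_eq!(Flat::Elem(1).join(Flat::Elem(1)), Flat::Elem(1));
    // ...while conflicting assignments degrade to `Top` and are not propagated.
    assert_eq!(Flat::Elem(1).join(Flat::Elem(2)), Flat::<i32>::Top);
}
```

Only places whose state is still `Elem(..)` at a use site can be rewritten into constants; anything that joins to `Top` is left untouched.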
+ +use std::assert_matches::assert_matches; +use std::fmt::Formatter; + +use rustc_abi::{BackendRepr, FIRST_VARIANT, FieldIdx, Size, VariantIdx}; +use rustc_const_eval::const_eval::{DummyMachine, throw_machine_stop_str}; +use rustc_const_eval::interpret::{ + ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable, interp_ok, +}; +use rustc_data_structures::fx::FxHashMap; +use rustc_hir::def::DefKind; +use rustc_middle::bug; +use rustc_middle::mir::interpret::{InterpResult, Scalar}; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_mir_dataflow::fmt::DebugWithContext; +use rustc_mir_dataflow::lattice::{FlatSet, HasBottom}; +use rustc_mir_dataflow::value_analysis::{ + Map, PlaceIndex, State, TrackElem, ValueOrPlace, debug_with_context, +}; +use rustc_mir_dataflow::{Analysis, ResultsVisitor, visit_reachable_results}; +use rustc_span::DUMMY_SP; +use tracing::{debug, debug_span, instrument}; + +// These constants are somewhat random guesses and have not been optimized. +// If `tcx.sess.mir_opt_level() >= 4`, we ignore the limits (this can become very expensive). +const BLOCK_LIMIT: usize = 100; +const PLACE_LIMIT: usize = 100; + +pub(super) struct DataflowConstProp; + +impl<'tcx> crate::MirPass<'tcx> for DataflowConstProp { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 3 + } + + #[instrument(skip_all level = "debug")] + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!(def_id = ?body.source.def_id()); + if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT { + debug!("aborted dataflow const prop due too many basic blocks"); + return; + } + + // We want to have a somewhat linear runtime w.r.t. the number of statements/terminators. + // Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function + // applications, where `h` is the height of the lattice. Because the height of our lattice + // is linear w.r.t. the number of tracked places, this is `O(tracked_places * n)`. However, + // because every transfer function application could traverse the whole map, this becomes + // `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of + // map nodes is strongly correlated to the number of tracked places, this becomes more or + // less `O(n)` if we place a constant limit on the number of tracked places. + let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None }; + + // Decide which places to track during the analysis. + let map = Map::new(tcx, body, place_limit); + + // Perform the actual dataflow analysis. + let mut const_ = debug_span!("analyze") + .in_scope(|| ConstAnalysis::new(tcx, body, map).iterate_to_fixpoint(tcx, body, None)); + + // Collect results and patch the body afterwards. + let mut visitor = Collector::new(tcx, &body.local_decls); + debug_span!("collect").in_scope(|| { + visit_reachable_results(body, &mut const_.analysis, &const_.results, &mut visitor) + }); + let mut patch = visitor.patch; + debug_span!("patch").in_scope(|| patch.visit_body_preserves_cfg(body)); + } + + fn is_required(&self) -> bool { + false + } +} + +// Note: Currently, places that have their reference taken cannot be tracked. Although this would +// be possible, it has to rely on some aliasing model, which we are not ready to commit to yet. +// Because of that, we can assume that the only way to change the value behind a tracked place is +// by direct assignment. 
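A concrete illustration of that note (hypothetical user code, not part of this pass): once a place has had its reference taken, it can change without any direct assignment to it, so the analysis refuses to track it.

```rust
// Hypothetical source, for illustration only.
fn example() -> i32 {
    let a = 1;       // `a` is only ever assigned directly, so later uses can be folded to `1`.
    let mut b = 1;
    let r = &mut b;  // `b` has its reference taken...
    *r = 2;          // ...and is now modified without a direct assignment to `b`,
    a + b            // so `b` must not be tracked; only `a` is a safe constant here.
}

fn main() {
    assert_eq!(example(), 3);
}
```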
+struct ConstAnalysis<'a, 'tcx> { + map: Map<'tcx>, + tcx: TyCtxt<'tcx>, + local_decls: &'a LocalDecls<'tcx>, + ecx: InterpCx<'tcx, DummyMachine>, + typing_env: ty::TypingEnv<'tcx>, +} + +impl<'tcx> Analysis<'tcx> for ConstAnalysis<'_, 'tcx> { + type Domain = State<FlatSet<Scalar>>; + + const NAME: &'static str = "ConstAnalysis"; + + // The bottom state denotes uninitialized memory. Because we are only doing a sound + // approximation of the actual execution, we can also use this state for places where access + // would be UB. + fn bottom_value(&self, _body: &Body<'tcx>) -> Self::Domain { + State::Unreachable + } + + fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) { + // The initial state maps all tracked places of argument projections to ⊤ and the rest to ⊥. + assert_matches!(state, State::Unreachable); + *state = State::new_reachable(); + for arg in body.args_iter() { + state.flood(PlaceRef { local: arg, projection: &[] }, &self.map); + } + } + + fn apply_primary_statement_effect( + &mut self, + state: &mut Self::Domain, + statement: &Statement<'tcx>, + _location: Location, + ) { + if state.is_reachable() { + self.handle_statement(statement, state); + } + } + + fn apply_primary_terminator_effect<'mir>( + &mut self, + state: &mut Self::Domain, + terminator: &'mir Terminator<'tcx>, + _location: Location, + ) -> TerminatorEdges<'mir, 'tcx> { + if state.is_reachable() { + self.handle_terminator(terminator, state) + } else { + TerminatorEdges::None + } + } + + fn apply_call_return_effect( + &mut self, + state: &mut Self::Domain, + _block: BasicBlock, + return_places: CallReturnPlaces<'_, 'tcx>, + ) { + if state.is_reachable() { + self.handle_call_return(return_places, state) + } + } +} + +impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { + fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map<'tcx>) -> Self { + let typing_env = body.typing_env(tcx); + Self { + map, + tcx, + local_decls: &body.local_decls, + ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine), + typing_env, + } + } + + fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<FlatSet<Scalar>>) { + match &statement.kind { + StatementKind::Assign(box (place, rvalue)) => { + self.handle_assign(*place, rvalue, state); + } + StatementKind::SetDiscriminant { box place, variant_index } => { + self.handle_set_discriminant(*place, *variant_index, state); + } + StatementKind::Intrinsic(box intrinsic) => { + self.handle_intrinsic(intrinsic); + } + StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + // StorageLive leaves the local in an uninitialized state. + // StorageDead makes it UB to access the local afterwards. + state.flood_with( + Place::from(*local).as_ref(), + &self.map, + FlatSet::<Scalar>::BOTTOM, + ); + } + StatementKind::Deinit(box place) => { + // Deinit makes the place uninitialized. + state.flood_with(place.as_ref(), &self.map, FlatSet::<Scalar>::BOTTOM); + } + StatementKind::Retag(..) => { + // We don't track references. + } + StatementKind::ConstEvalCounter + | StatementKind::Nop + | StatementKind::FakeRead(..) + | StatementKind::PlaceMention(..) + | StatementKind::Coverage(..) + | StatementKind::BackwardIncompatibleDropHint { .. } + | StatementKind::AscribeUserType(..) => {} + } + } + + fn handle_intrinsic(&self, intrinsic: &NonDivergingIntrinsic<'tcx>) { + match intrinsic { + NonDivergingIntrinsic::Assume(..) => { + // Could use this, but ignoring it is sound. 
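+ // (`Assume` only provides extra information about its operand; dropping
+ // that information may lose precision but can never yield wrong constants.)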
+ } + NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping { + dst: _, + src: _, + count: _, + }) => { + // This statement represents `*dst = *src`, `count` times. + } + } + } + + fn handle_operand( + &self, + operand: &Operand<'tcx>, + state: &mut State<FlatSet<Scalar>>, + ) -> ValueOrPlace<FlatSet<Scalar>> { + match operand { + Operand::Constant(box constant) => { + ValueOrPlace::Value(self.handle_constant(constant, state)) + } + Operand::Copy(place) | Operand::Move(place) => { + // On move, we would ideally flood the place with bottom. But with the current + // framework this is not possible (similar to `InterpCx::eval_operand`). + self.map.find(place.as_ref()).map(ValueOrPlace::Place).unwrap_or(ValueOrPlace::TOP) + } + } + } + + /// The effect of a successful function call return should not be + /// applied here, see [`Analysis::apply_primary_terminator_effect`]. + fn handle_terminator<'mir>( + &self, + terminator: &'mir Terminator<'tcx>, + state: &mut State<FlatSet<Scalar>>, + ) -> TerminatorEdges<'mir, 'tcx> { + match &terminator.kind { + TerminatorKind::Call { .. } | TerminatorKind::InlineAsm { .. } => { + // Effect is applied by `handle_call_return`. + } + TerminatorKind::Drop { place, .. } => { + state.flood_with(place.as_ref(), &self.map, FlatSet::<Scalar>::BOTTOM); + } + TerminatorKind::Yield { .. } => { + // They would have an effect, but are not allowed in this phase. + bug!("encountered disallowed terminator"); + } + TerminatorKind::SwitchInt { discr, targets } => { + return self.handle_switch_int(discr, targets, state); + } + TerminatorKind::TailCall { .. } => { + // FIXME(explicit_tail_calls): determine if we need to do something here (probably + // not) + } + TerminatorKind::Goto { .. } + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::Assert { .. } + | TerminatorKind::CoroutineDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } => { + // These terminators have no effect on the analysis. + } + } + terminator.edges() + } + + fn handle_call_return( + &self, + return_places: CallReturnPlaces<'_, 'tcx>, + state: &mut State<FlatSet<Scalar>>, + ) { + return_places.for_each(|place| { + state.flood(place.as_ref(), &self.map); + }) + } + + fn handle_set_discriminant( + &self, + place: Place<'tcx>, + variant_index: VariantIdx, + state: &mut State<FlatSet<Scalar>>, + ) { + state.flood_discr(place.as_ref(), &self.map); + if self.map.find_discr(place.as_ref()).is_some() { + let enum_ty = place.ty(self.local_decls, self.tcx).ty; + if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) { + state.assign_discr( + place.as_ref(), + ValueOrPlace::Value(FlatSet::Elem(discr)), + &self.map, + ); + } + } + } + + fn handle_assign( + &self, + target: Place<'tcx>, + rvalue: &Rvalue<'tcx>, + state: &mut State<FlatSet<Scalar>>, + ) { + match rvalue { + Rvalue::Use(operand) => { + state.flood(target.as_ref(), &self.map); + if let Some(target) = self.map.find(target.as_ref()) { + self.assign_operand(state, target, operand); + } + } + Rvalue::CopyForDeref(rhs) => { + state.flood(target.as_ref(), &self.map); + if let Some(target) = self.map.find(target.as_ref()) { + self.assign_operand(state, target, &Operand::Copy(*rhs)); + } + } + Rvalue::Aggregate(kind, operands) => { + // If we assign `target = Enum::Variant#0(operand)`, + // we must make sure that all `target as Variant#i` are `Top`. 
+ state.flood(target.as_ref(), &self.map); + + let Some(target_idx) = self.map.find(target.as_ref()) else { return }; + + let (variant_target, variant_index) = match **kind { + AggregateKind::Tuple | AggregateKind::Closure(..) => (Some(target_idx), None), + AggregateKind::Adt(def_id, variant_index, ..) => { + match self.tcx.def_kind(def_id) { + DefKind::Struct => (Some(target_idx), None), + DefKind::Enum => ( + self.map.apply(target_idx, TrackElem::Variant(variant_index)), + Some(variant_index), + ), + _ => return, + } + } + _ => return, + }; + if let Some(variant_target_idx) = variant_target { + for (field_index, operand) in operands.iter_enumerated() { + if let Some(field) = + self.map.apply(variant_target_idx, TrackElem::Field(field_index)) + { + self.assign_operand(state, field, operand); + } + } + } + if let Some(variant_index) = variant_index + && let Some(discr_idx) = self.map.apply(target_idx, TrackElem::Discriminant) + { + // We are assigning the discriminant as part of an aggregate. + // This discriminant can only alias a variant field's value if the operand + // had an invalid value for that type. + // Using invalid values is UB, so we are allowed to perform the assignment + // without extra flooding. + let enum_ty = target.ty(self.local_decls, self.tcx).ty; + if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) { + state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map); + } + } + } + Rvalue::BinaryOp(op, box (left, right)) if op.is_overflowing() => { + // Flood everything now, so we can use `insert_value_idx` directly later. + state.flood(target.as_ref(), &self.map); + + let Some(target) = self.map.find(target.as_ref()) else { return }; + + let value_target = self.map.apply(target, TrackElem::Field(0_u32.into())); + let overflow_target = self.map.apply(target, TrackElem::Field(1_u32.into())); + + if value_target.is_some() || overflow_target.is_some() { + let (val, overflow) = self.binary_op(state, *op, left, right); + + if let Some(value_target) = value_target { + // We have flooded `target` earlier. + state.insert_value_idx(value_target, val, &self.map); + } + if let Some(overflow_target) = overflow_target { + // We have flooded `target` earlier. 
+ state.insert_value_idx(overflow_target, overflow, &self.map); + } + } + } + Rvalue::Cast( + CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize, _), + operand, + _, + ) => { + let pointer = self.handle_operand(operand, state); + state.assign(target.as_ref(), pointer, &self.map); + + if let Some(target_len) = self.map.find_len(target.as_ref()) + && let operand_ty = operand.ty(self.local_decls, self.tcx) + && let Some(operand_ty) = operand_ty.builtin_deref(true) + && let ty::Array(_, len) = operand_ty.kind() + && let Some(len) = Const::Ty(self.tcx.types.usize, *len) + .try_eval_scalar_int(self.tcx, self.typing_env) + { + state.insert_value_idx(target_len, FlatSet::Elem(len.into()), &self.map); + } + } + _ => { + let result = self.handle_rvalue(rvalue, state); + state.assign(target.as_ref(), result, &self.map); + } + } + } + + fn handle_rvalue( + &self, + rvalue: &Rvalue<'tcx>, + state: &mut State<FlatSet<Scalar>>, + ) -> ValueOrPlace<FlatSet<Scalar>> { + let val = match rvalue { + Rvalue::Len(place) => { + let place_ty = place.ty(self.local_decls, self.tcx); + if let ty::Array(_, len) = place_ty.ty.kind() { + Const::Ty(self.tcx.types.usize, *len) + .try_eval_scalar(self.tcx, self.typing_env) + .map_or(FlatSet::Top, FlatSet::Elem) + } else if let [ProjectionElem::Deref] = place.projection[..] { + state.get_len(place.local.into(), &self.map) + } else { + FlatSet::Top + } + } + Rvalue::Cast(CastKind::IntToInt | CastKind::IntToFloat, operand, ty) => { + let Ok(layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else { + return ValueOrPlace::Value(FlatSet::Top); + }; + match self.eval_operand(operand, state) { + FlatSet::Elem(op) => self + .ecx + .int_to_int_or_float(&op, layout) + .discard_err() + .map_or(FlatSet::Top, |result| self.wrap_immediate(*result)), + FlatSet::Bottom => FlatSet::Bottom, + FlatSet::Top => FlatSet::Top, + } + } + Rvalue::Cast(CastKind::FloatToInt | CastKind::FloatToFloat, operand, ty) => { + let Ok(layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else { + return ValueOrPlace::Value(FlatSet::Top); + }; + match self.eval_operand(operand, state) { + FlatSet::Elem(op) => self + .ecx + .float_to_float_or_int(&op, layout) + .discard_err() + .map_or(FlatSet::Top, |result| self.wrap_immediate(*result)), + FlatSet::Bottom => FlatSet::Bottom, + FlatSet::Top => FlatSet::Top, + } + } + Rvalue::Cast(CastKind::Transmute, operand, _) => { + match self.eval_operand(operand, state) { + FlatSet::Elem(op) => self.wrap_immediate(*op), + FlatSet::Bottom => FlatSet::Bottom, + FlatSet::Top => FlatSet::Top, + } + } + Rvalue::BinaryOp(op, box (left, right)) if !op.is_overflowing() => { + // Overflows must be ignored here. + // The overflowing operators are handled in `handle_assign`. 
+ let (val, _overflow) = self.binary_op(state, *op, left, right); + val + } + Rvalue::UnaryOp(op, operand) => match self.eval_operand(operand, state) { + FlatSet::Elem(value) => self + .ecx + .unary_op(*op, &value) + .discard_err() + .map_or(FlatSet::Top, |val| self.wrap_immediate(*val)), + FlatSet::Bottom => FlatSet::Bottom, + FlatSet::Top => FlatSet::Top, + }, + Rvalue::NullaryOp(null_op, ty) => { + let Ok(layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else { + return ValueOrPlace::Value(FlatSet::Top); + }; + let val = match null_op { + NullOp::SizeOf if layout.is_sized() => layout.size.bytes(), + NullOp::AlignOf if layout.is_sized() => layout.align.abi.bytes(), + NullOp::OffsetOf(fields) => self + .ecx + .tcx + .offset_of_subfield(self.typing_env, layout, fields.iter()) + .bytes(), + _ => return ValueOrPlace::Value(FlatSet::Top), + }; + FlatSet::Elem(Scalar::from_target_usize(val, &self.tcx)) + } + Rvalue::Discriminant(place) => state.get_discr(place.as_ref(), &self.map), + Rvalue::Use(operand) => return self.handle_operand(operand, state), + Rvalue::CopyForDeref(place) => { + return self.handle_operand(&Operand::Copy(*place), state); + } + Rvalue::Ref(..) | Rvalue::RawPtr(..) => { + // We don't track such places. + return ValueOrPlace::TOP; + } + Rvalue::Repeat(..) + | Rvalue::ThreadLocalRef(..) + | Rvalue::Cast(..) + | Rvalue::BinaryOp(..) + | Rvalue::Aggregate(..) + | Rvalue::ShallowInitBox(..) + | Rvalue::WrapUnsafeBinder(..) => { + // No modification is possible through these r-values. + return ValueOrPlace::TOP; + } + }; + ValueOrPlace::Value(val) + } + + fn handle_constant( + &self, + constant: &ConstOperand<'tcx>, + _state: &mut State<FlatSet<Scalar>>, + ) -> FlatSet<Scalar> { + constant + .const_ + .try_eval_scalar(self.tcx, self.typing_env) + .map_or(FlatSet::Top, FlatSet::Elem) + } + + fn handle_switch_int<'mir>( + &self, + discr: &'mir Operand<'tcx>, + targets: &'mir SwitchTargets, + state: &mut State<FlatSet<Scalar>>, + ) -> TerminatorEdges<'mir, 'tcx> { + let value = match self.handle_operand(discr, state) { + ValueOrPlace::Value(value) => value, + ValueOrPlace::Place(place) => state.get_idx(place, &self.map), + }; + match value { + // We are branching on uninitialized data, this is UB, treat it as unreachable. + // This allows the set of visited edges to grow monotonically with the lattice. + FlatSet::Bottom => TerminatorEdges::None, + FlatSet::Elem(scalar) => { + if let Ok(scalar_int) = scalar.try_to_scalar_int() { + TerminatorEdges::Single( + targets.target_for_value(scalar_int.to_bits_unchecked()), + ) + } else { + TerminatorEdges::SwitchInt { discr, targets } + } + } + FlatSet::Top => TerminatorEdges::SwitchInt { discr, targets }, + } + } + + /// The caller must have flooded `place`. 
+ fn assign_operand( + &self, + state: &mut State<FlatSet<Scalar>>, + place: PlaceIndex, + operand: &Operand<'tcx>, + ) { + match operand { + Operand::Copy(rhs) | Operand::Move(rhs) => { + if let Some(rhs) = self.map.find(rhs.as_ref()) { + state.insert_place_idx(place, rhs, &self.map); + } else if rhs.projection.first() == Some(&PlaceElem::Deref) + && let FlatSet::Elem(pointer) = state.get(rhs.local.into(), &self.map) + && let rhs_ty = self.local_decls[rhs.local].ty + && let Ok(rhs_layout) = + self.tcx.layout_of(self.typing_env.as_query_input(rhs_ty)) + { + let op = ImmTy::from_scalar(pointer, rhs_layout).into(); + self.assign_constant(state, place, op, rhs.projection); + } + } + Operand::Constant(box constant) => { + if let Some(constant) = + self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err() + { + self.assign_constant(state, place, constant, &[]); + } + } + } + } + + /// The caller must have flooded `place`. + /// + /// Perform: `place = operand.projection`. + #[instrument(level = "trace", skip(self, state))] + fn assign_constant( + &self, + state: &mut State<FlatSet<Scalar>>, + place: PlaceIndex, + mut operand: OpTy<'tcx>, + projection: &[PlaceElem<'tcx>], + ) { + for &(mut proj_elem) in projection { + if let PlaceElem::Index(index) = proj_elem { + if let FlatSet::Elem(index) = state.get(index.into(), &self.map) + && let Some(offset) = index.to_target_usize(&self.tcx).discard_err() + && let Some(min_length) = offset.checked_add(1) + { + proj_elem = PlaceElem::ConstantIndex { offset, min_length, from_end: false }; + } else { + return; + } + } + operand = if let Some(operand) = self.ecx.project(&operand, proj_elem).discard_err() { + operand + } else { + return; + } + } + + self.map.for_each_projection_value( + place, + operand, + &mut |elem, op| match elem { + TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(), + TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(), + TrackElem::Discriminant => { + let variant = self.ecx.read_discriminant(op).discard_err()?; + let discr_value = + self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?; + Some(discr_value.into()) + } + TrackElem::DerefLen => { + let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into(); + let len_usize = op.len(&self.ecx).discard_err()?; + let layout = self + .tcx + .layout_of(self.typing_env.as_query_input(self.tcx.types.usize)) + .unwrap(); + Some(ImmTy::from_uint(len_usize, layout).into()) + } + }, + &mut |place, op| { + if let Some(imm) = self.ecx.read_immediate_raw(op).discard_err() + && let Some(imm) = imm.right() + { + let elem = self.wrap_immediate(*imm); + state.insert_value_idx(place, elem, &self.map); + } + }, + ); + } + + fn binary_op( + &self, + state: &mut State<FlatSet<Scalar>>, + op: BinOp, + left: &Operand<'tcx>, + right: &Operand<'tcx>, + ) -> (FlatSet<Scalar>, FlatSet<Scalar>) { + let left = self.eval_operand(left, state); + let right = self.eval_operand(right, state); + + match (left, right) { + (FlatSet::Bottom, _) | (_, FlatSet::Bottom) => (FlatSet::Bottom, FlatSet::Bottom), + // Both sides are known, do the actual computation. + (FlatSet::Elem(left), FlatSet::Elem(right)) => { + match self.ecx.binary_op(op, &left, &right).discard_err() { + // Ideally this would return an Immediate, since it's sometimes + // a pair and sometimes not. But as a hack we always return a pair + // and just make the 2nd component `Bottom` when it does not exist. 
+ Some(val) => { + if matches!(val.layout.backend_repr, BackendRepr::ScalarPair(..)) { + let (val, overflow) = val.to_scalar_pair(); + (FlatSet::Elem(val), FlatSet::Elem(overflow)) + } else { + (FlatSet::Elem(val.to_scalar()), FlatSet::Bottom) + } + } + _ => (FlatSet::Top, FlatSet::Top), + } + } + // Exactly one side is known, attempt some algebraic simplifications. + (FlatSet::Elem(const_arg), _) | (_, FlatSet::Elem(const_arg)) => { + let layout = const_arg.layout; + if !matches!(layout.backend_repr, rustc_abi::BackendRepr::Scalar(..)) { + return (FlatSet::Top, FlatSet::Top); + } + + let arg_scalar = const_arg.to_scalar(); + let Some(arg_value) = arg_scalar.to_bits(layout.size).discard_err() else { + return (FlatSet::Top, FlatSet::Top); + }; + + match op { + BinOp::BitAnd if arg_value == 0 => (FlatSet::Elem(arg_scalar), FlatSet::Bottom), + BinOp::BitOr + if arg_value == layout.size.truncate(u128::MAX) + || (layout.ty.is_bool() && arg_value == 1) => + { + (FlatSet::Elem(arg_scalar), FlatSet::Bottom) + } + BinOp::Mul if layout.ty.is_integral() && arg_value == 0 => { + (FlatSet::Elem(arg_scalar), FlatSet::Elem(Scalar::from_bool(false))) + } + _ => (FlatSet::Top, FlatSet::Top), + } + } + (FlatSet::Top, FlatSet::Top) => (FlatSet::Top, FlatSet::Top), + } + } + + fn eval_operand( + &self, + op: &Operand<'tcx>, + state: &mut State<FlatSet<Scalar>>, + ) -> FlatSet<ImmTy<'tcx>> { + let value = match self.handle_operand(op, state) { + ValueOrPlace::Value(value) => value, + ValueOrPlace::Place(place) => state.get_idx(place, &self.map), + }; + match value { + FlatSet::Top => FlatSet::Top, + FlatSet::Elem(scalar) => { + let ty = op.ty(self.local_decls, self.tcx); + self.tcx + .layout_of(self.typing_env.as_query_input(ty)) + .map_or(FlatSet::Top, |layout| { + FlatSet::Elem(ImmTy::from_scalar(scalar, layout)) + }) + } + FlatSet::Bottom => FlatSet::Bottom, + } + } + + fn eval_discriminant(&self, enum_ty: Ty<'tcx>, variant_index: VariantIdx) -> Option<Scalar> { + if !enum_ty.is_enum() { + return None; + } + let enum_ty_layout = self.tcx.layout_of(self.typing_env.as_query_input(enum_ty)).ok()?; + let discr_value = + self.ecx.discriminant_for_variant(enum_ty_layout.ty, variant_index).discard_err()?; + Some(discr_value.to_scalar()) + } + + fn wrap_immediate(&self, imm: Immediate) -> FlatSet<Scalar> { + match imm { + Immediate::Scalar(scalar) => FlatSet::Elem(scalar), + Immediate::Uninit => FlatSet::Bottom, + _ => FlatSet::Top, + } + } +} + +/// This is used to visualize the dataflow analysis. +impl<'tcx> DebugWithContext<ConstAnalysis<'_, 'tcx>> for State<FlatSet<Scalar>> { + fn fmt_with(&self, ctxt: &ConstAnalysis<'_, 'tcx>, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + State::Reachable(values) => debug_with_context(values, None, &ctxt.map, f), + State::Unreachable => write!(f, "unreachable"), + } + } + + fn fmt_diff_with( + &self, + old: &Self, + ctxt: &ConstAnalysis<'_, 'tcx>, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + match (self, old) { + (State::Reachable(this), State::Reachable(old)) => { + debug_with_context(this, Some(old), &ctxt.map, f) + } + _ => Ok(()), // Consider printing something here. + } + } +} + +struct Patch<'tcx> { + tcx: TyCtxt<'tcx>, + + /// For a given MIR location, this stores the values of the operands used by that location. In + /// particular, this is before the effect, such that the operands of `_1 = _1 + _2` are + /// properly captured. (This may become UB soon, but it is currently emitted even by safe code.) 
+ before_effect: FxHashMap<(Location, Place<'tcx>), Const<'tcx>>, + + /// Stores the assigned values for assignments where the Rvalue is constant. + assignments: FxHashMap<Location, Const<'tcx>>, +} + +impl<'tcx> Patch<'tcx> { + pub(crate) fn new(tcx: TyCtxt<'tcx>) -> Self { + Self { tcx, before_effect: FxHashMap::default(), assignments: FxHashMap::default() } + } + + fn make_operand(&self, const_: Const<'tcx>) -> Operand<'tcx> { + Operand::Constant(Box::new(ConstOperand { span: DUMMY_SP, user_ty: None, const_ })) + } +} + +struct Collector<'a, 'tcx> { + patch: Patch<'tcx>, + local_decls: &'a LocalDecls<'tcx>, +} + +impl<'a, 'tcx> Collector<'a, 'tcx> { + pub(crate) fn new(tcx: TyCtxt<'tcx>, local_decls: &'a LocalDecls<'tcx>) -> Self { + Self { patch: Patch::new(tcx), local_decls } + } + + #[instrument(level = "trace", skip(self, ecx, map), ret)] + fn try_make_constant( + &self, + ecx: &mut InterpCx<'tcx, DummyMachine>, + place: Place<'tcx>, + state: &State<FlatSet<Scalar>>, + map: &Map<'tcx>, + ) -> Option<Const<'tcx>> { + let ty = place.ty(self.local_decls, self.patch.tcx).ty; + let layout = ecx.layout_of(ty).ok()?; + + if layout.is_zst() { + return Some(Const::zero_sized(ty)); + } + + if layout.is_unsized() { + return None; + } + + let place = map.find(place.as_ref())?; + if layout.backend_repr.is_scalar() + && let Some(value) = propagatable_scalar(place, state, map) + { + return Some(Const::Val(ConstValue::Scalar(value), ty)); + } + + if matches!(layout.backend_repr, BackendRepr::Scalar(..) | BackendRepr::ScalarPair(..)) { + let alloc_id = ecx + .intern_with_temp_alloc(layout, |ecx, dest| { + try_write_constant(ecx, dest, place, ty, state, map) + }) + .discard_err()?; + return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty)); + } + + None + } +} + +#[instrument(level = "trace", skip(map), ret)] +fn propagatable_scalar( + place: PlaceIndex, + state: &State<FlatSet<Scalar>>, + map: &Map<'_>, +) -> Option<Scalar> { + if let FlatSet::Elem(value) = state.get_idx(place, map) + && value.try_to_scalar_int().is_ok() + { + // Do not attempt to propagate pointers, as we may fail to preserve their identity. + Some(value) + } else { + None + } +} + +#[instrument(level = "trace", skip(ecx, state, map), ret)] +fn try_write_constant<'tcx>( + ecx: &mut InterpCx<'tcx, DummyMachine>, + dest: &PlaceTy<'tcx>, + place: PlaceIndex, + ty: Ty<'tcx>, + state: &State<FlatSet<Scalar>>, + map: &Map<'tcx>, +) -> InterpResult<'tcx> { + let layout = ecx.layout_of(ty)?; + + // Fast path for ZSTs. + if layout.is_zst() { + return interp_ok(()); + } + + // Fast path for scalars. + if layout.backend_repr.is_scalar() + && let Some(value) = propagatable_scalar(place, state, map) + { + return ecx.write_immediate(Immediate::Scalar(value), dest); + } + + match ty.kind() { + // ZSTs. Nothing to do. + ty::FnDef(..) => {} + + // Those are scalars, must be handled above. 
+ ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => + throw_machine_stop_str!("primitive type with provenance"), + + ty::Tuple(elem_tys) => { + for (i, elem) in elem_tys.iter().enumerate() { + let i = FieldIdx::from_usize(i); + let Some(field) = map.apply(place, TrackElem::Field(i)) else { + throw_machine_stop_str!("missing field in tuple") + }; + let field_dest = ecx.project_field(dest, i)?; + try_write_constant(ecx, &field_dest, field, elem, state, map)?; + } + } + + ty::Adt(def, args) => { + if def.is_union() { + throw_machine_stop_str!("cannot propagate unions") + } + + let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() { + let Some(discr) = map.apply(place, TrackElem::Discriminant) else { + throw_machine_stop_str!("missing discriminant for enum") + }; + let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else { + throw_machine_stop_str!("discriminant with provenance") + }; + let discr_bits = discr.to_bits(discr.size()); + let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else { + throw_machine_stop_str!("illegal discriminant for enum") + }; + let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else { + throw_machine_stop_str!("missing variant for enum") + }; + let variant_dest = ecx.project_downcast(dest, variant)?; + (variant, def.variant(variant), variant_place, variant_dest) + } else { + (FIRST_VARIANT, def.non_enum_variant(), place, dest.clone()) + }; + + for (i, field) in variant_def.fields.iter_enumerated() { + let ty = field.ty(*ecx.tcx, args); + let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else { + throw_machine_stop_str!("missing field in ADT") + }; + let field_dest = ecx.project_field(&variant_dest, i)?; + try_write_constant(ecx, &field_dest, field, ty, state, map)?; + } + ecx.write_discriminant(variant_idx, dest)?; + } + + // Unsupported for now. + ty::Array(_, _) + | ty::Pat(_, _) + + // Do not attempt to support indirection in constants. + | ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_) + + | ty::Never + | ty::Foreign(..) + | ty::Alias(..) + | ty::Param(_) + | ty::Bound(..) + | ty::Placeholder(..) + | ty::Closure(..) + | ty::CoroutineClosure(..) + | ty::Coroutine(..) + | ty::Dynamic(..) + | ty::UnsafeBinder(_) => throw_machine_stop_str!("unsupported type"), + + ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(), + } + + interp_ok(()) +} + +impl<'tcx> ResultsVisitor<'tcx, ConstAnalysis<'_, 'tcx>> for Collector<'_, 'tcx> { + #[instrument(level = "trace", skip(self, analysis, statement))] + fn visit_after_early_statement_effect( + &mut self, + analysis: &mut ConstAnalysis<'_, 'tcx>, + state: &State<FlatSet<Scalar>>, + statement: &Statement<'tcx>, + location: Location, + ) { + match &statement.kind { + StatementKind::Assign(box (_, rvalue)) => { + OperandCollector { + state, + visitor: self, + ecx: &mut analysis.ecx, + map: &analysis.map, + } + .visit_rvalue(rvalue, location); + } + _ => (), + } + } + + #[instrument(level = "trace", skip(self, analysis, statement))] + fn visit_after_primary_statement_effect( + &mut self, + analysis: &mut ConstAnalysis<'_, 'tcx>, + state: &State<FlatSet<Scalar>>, + statement: &Statement<'tcx>, + location: Location, + ) { + match statement.kind { + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(_)))) => { + // Don't overwrite the assignment if it already uses a constant (to keep the span). 
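+ // (The patched operand would be rebuilt with `DUMMY_SP`, so rewriting an
+ // assignment that is already constant would only throw away its span.)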
+ } + StatementKind::Assign(box (place, _)) => { + if let Some(value) = + self.try_make_constant(&mut analysis.ecx, place, state, &analysis.map) + { + self.patch.assignments.insert(location, value); + } + } + _ => (), + } + } + + fn visit_after_early_terminator_effect( + &mut self, + analysis: &mut ConstAnalysis<'_, 'tcx>, + state: &State<FlatSet<Scalar>>, + terminator: &Terminator<'tcx>, + location: Location, + ) { + OperandCollector { state, visitor: self, ecx: &mut analysis.ecx, map: &analysis.map } + .visit_terminator(terminator, location); + } +} + +impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + if let Some(value) = self.assignments.get(&location) { + match &mut statement.kind { + StatementKind::Assign(box (_, rvalue)) => { + *rvalue = Rvalue::Use(self.make_operand(*value)); + } + _ => bug!("found assignment info for non-assign statement"), + } + } else { + self.super_statement(statement, location); + } + } + + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) { + match operand { + Operand::Copy(place) | Operand::Move(place) => { + if let Some(value) = self.before_effect.get(&(location, *place)) { + *operand = self.make_operand(*value); + } else if !place.projection.is_empty() { + self.super_operand(operand, location) + } + } + Operand::Constant(_) => {} + } + } + + fn process_projection_elem( + &mut self, + elem: PlaceElem<'tcx>, + location: Location, + ) -> Option<PlaceElem<'tcx>> { + if let PlaceElem::Index(local) = elem { + let offset = self.before_effect.get(&(location, local.into()))?; + let offset = offset.try_to_scalar()?; + let offset = offset.to_target_usize(&self.tcx).discard_err()?; + let min_length = offset.checked_add(1)?; + Some(PlaceElem::ConstantIndex { offset, min_length, from_end: false }) + } else { + None + } + } +} + +struct OperandCollector<'a, 'b, 'tcx> { + state: &'a State<FlatSet<Scalar>>, + visitor: &'a mut Collector<'b, 'tcx>, + ecx: &'a mut InterpCx<'tcx, DummyMachine>, + map: &'a Map<'tcx>, +} + +impl<'tcx> Visitor<'tcx> for OperandCollector<'_, '_, 'tcx> { + fn visit_projection_elem( + &mut self, + _: PlaceRef<'tcx>, + elem: PlaceElem<'tcx>, + _: PlaceContext, + location: Location, + ) { + if let PlaceElem::Index(local) = elem + && let Some(value) = + self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map) + { + self.visitor.patch.before_effect.insert((location, local.into()), value); + } + } + + fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { + if let Some(place) = operand.place() { + if let Some(value) = + self.visitor.try_make_constant(self.ecx, place, self.state, self.map) + { + self.visitor.patch.before_effect.insert((location, place), value); + } else if !place.projection.is_empty() { + // Try to propagate into `Index` projections. + self.super_operand(operand, location) + } + } + } +} diff --git a/compiler/rustc_mir_transform/src/dead_store_elimination.rs b/compiler/rustc_mir_transform/src/dead_store_elimination.rs new file mode 100644 index 00000000000..eea2b0990d7 --- /dev/null +++ b/compiler/rustc_mir_transform/src/dead_store_elimination.rs @@ -0,0 +1,154 @@ +//! This module implements a dead store elimination (DSE) routine. +//! +//! This transformation was written specifically for the needs of dest prop. Although it is +//! perfectly sound to use it in any context that might need it, its behavior should not be changed +//! 
without analyzing the interaction this will have with dest prop. Specifically, in addition to +//! the soundness of this pass in general, dest prop needs it to satisfy two additional conditions: +//! +//! 1. It's idempotent, meaning that running this pass a second time immediately after running it a +//! first time will not cause any further changes. +//! 2. This idempotence persists across dest prop's main transform, in other words inserting any +//! number of iterations of dest prop between the first and second application of this transform +//! will still not cause any further changes. +//! + +use rustc_middle::bug; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_mir_dataflow::Analysis; +use rustc_mir_dataflow::debuginfo::debuginfo_locals; +use rustc_mir_dataflow::impls::{ + LivenessTransferFunction, MaybeTransitiveLiveLocals, borrowed_locals, +}; + +use crate::util::is_within_packed; + +/// Performs the optimization on the body +/// +/// The `borrowed` set must be a `DenseBitSet` of all the locals that are ever borrowed in this +/// body. It can be generated via the [`borrowed_locals`] function. +fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let borrowed_locals = borrowed_locals(body); + + // If the user requests complete debuginfo, mark the locals that appear in it as live, so + // we don't remove assignments to them. + let mut always_live = debuginfo_locals(body); + always_live.union(&borrowed_locals); + + let mut live = MaybeTransitiveLiveLocals::new(&always_live) + .iterate_to_fixpoint(tcx, body, None) + .into_results_cursor(body); + + // For blocks with a call terminator, if an argument copy can be turned into a move, + // record it as (block, argument index). + let mut call_operands_to_move = Vec::new(); + let mut patch = Vec::new(); + + for (bb, bb_data) in traversal::preorder(body) { + if let TerminatorKind::Call { ref args, .. } = bb_data.terminator().kind { + let loc = Location { block: bb, statement_index: bb_data.statements.len() }; + + // Position ourselves between the evaluation of `args` and the write to `destination`. + live.seek_to_block_end(bb); + let mut state = live.get().clone(); + + for (index, arg) in args.iter().map(|a| &a.node).enumerate().rev() { + if let Operand::Copy(place) = *arg + && !place.is_indirect() + // Do not skip the transformation if the local is in debuginfo, as we do + // not really lose any information for this purpose. + && !borrowed_locals.contains(place.local) + && !state.contains(place.local) + // If `place` is a projection of a disaligned field in a packed ADT, + // the move may be codegened as a pointer to that field. + // Using that disaligned pointer may trigger UB in the callee, + // so do nothing. + && is_within_packed(tcx, body, place).is_none() + { + call_operands_to_move.push((bb, index)); + } + + // Account that `arg` is read from, so we don't promote another argument to a move. + LivenessTransferFunction(&mut state).visit_operand(arg, loc); + } + } + + for (statement_index, statement) in bb_data.statements.iter().enumerate().rev() { + let loc = Location { block: bb, statement_index }; + if let StatementKind::Assign(assign) = &statement.kind { + if !assign.1.is_safe_to_remove() { + continue; + } + } + match &statement.kind { + StatementKind::Assign(box (place, _)) + | StatementKind::SetDiscriminant { place: box place, .. 
} + | StatementKind::Deinit(box place) => { + if !place.is_indirect() && !always_live.contains(place.local) { + live.seek_before_primary_effect(loc); + if !live.get().contains(place.local) { + patch.push(loc); + } + } + } + StatementKind::Retag(_, _) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Coverage(_) + | StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter + | StatementKind::PlaceMention(_) + | StatementKind::BackwardIncompatibleDropHint { .. } + | StatementKind::Nop => {} + + StatementKind::FakeRead(_) | StatementKind::AscribeUserType(_, _) => { + bug!("{:?} not found in this MIR phase!", statement.kind) + } + } + } + } + + if patch.is_empty() && call_operands_to_move.is_empty() { + return; + } + + let bbs = body.basic_blocks.as_mut_preserves_cfg(); + for Location { block, statement_index } in patch { + bbs[block].statements[statement_index].make_nop(); + } + for (block, argument_index) in call_operands_to_move { + let TerminatorKind::Call { ref mut args, .. } = bbs[block].terminator_mut().kind else { + bug!() + }; + let arg = &mut args[argument_index].node; + let Operand::Copy(place) = *arg else { bug!() }; + *arg = Operand::Move(place); + } +} + +pub(super) enum DeadStoreElimination { + Initial, + Final, +} + +impl<'tcx> crate::MirPass<'tcx> for DeadStoreElimination { + fn name(&self) -> &'static str { + match self { + DeadStoreElimination::Initial => "DeadStoreElimination-initial", + DeadStoreElimination::Final => "DeadStoreElimination-final", + } + } + + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 2 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + eliminate(tcx, body); + } + + fn is_required(&self) -> bool { + false + } +} diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs new file mode 100644 index 00000000000..a0db8bdb7ed --- /dev/null +++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs @@ -0,0 +1,196 @@ +//! Deduces supplementary parameter attributes from MIR. +//! +//! Deduced parameter attributes are those that can only be soundly determined by examining the +//! body of the function instead of just the signature. These can be useful for optimization +//! purposes on a best-effort basis. We compute them here and store them into the crate metadata so +//! dependent crates can use them. + +use rustc_hir::def_id::LocalDefId; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::{Body, Location, Operand, Place, RETURN_PLACE, Terminator, TerminatorKind}; +use rustc_middle::ty::{self, DeducedParamAttrs, Ty, TyCtxt}; +use rustc_session::config::OptLevel; + +/// A visitor that determines which arguments have been mutated. We can't use the mutability field +/// on LocalDecl for this because it has no meaning post-optimization. +struct DeduceReadOnly { + /// Each bit is indexed by argument number, starting at zero (so 0 corresponds to local decl + /// 1). The bit is true if the argument may have been mutated or false if we know it hasn't + /// been up to the point we're at. + mutable_args: DenseBitSet<usize>, +} + +impl DeduceReadOnly { + /// Returns a new DeduceReadOnly instance. 
+ fn new(arg_count: usize) -> Self { + Self { mutable_args: DenseBitSet::new_empty(arg_count) } + } +} + +impl<'tcx> Visitor<'tcx> for DeduceReadOnly { + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { + // We're only interested in arguments. + if place.local == RETURN_PLACE || place.local.index() > self.mutable_args.domain_size() { + return; + } + + let mark_as_mutable = match context { + PlaceContext::MutatingUse(..) => { + // This is a mutation, so mark it as such. + true + } + PlaceContext::NonMutatingUse(NonMutatingUseContext::RawBorrow) => { + // Whether mutating though a `&raw const` is allowed is still undecided, so we + // disable any sketchy `readonly` optimizations for now. But we only need to do + // this if the pointer would point into the argument. IOW: for indirect places, + // like `&raw (*local).field`, this surely cannot mutate `local`. + !place.is_indirect() + } + PlaceContext::NonMutatingUse(..) | PlaceContext::NonUse(..) => { + // Not mutating, so it's fine. + false + } + }; + + if mark_as_mutable { + self.mutable_args.insert(place.local.index() - 1); + } + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + // OK, this is subtle. Suppose that we're trying to deduce whether `x` in `f` is read-only + // and we have the following: + // + // fn f(x: BigStruct) { g(x) } + // fn g(mut y: BigStruct) { y.foo = 1 } + // + // If, at the generated MIR level, `f` turned into something like: + // + // fn f(_1: BigStruct) -> () { + // let mut _0: (); + // bb0: { + // _0 = g(move _1) -> bb1; + // } + // ... + // } + // + // then it would be incorrect to mark `x` (i.e. `_1`) as `readonly`, because `g`'s write to + // its copy of the indirect parameter would actually be a write directly to the pointer that + // `f` passes. Note that function arguments are the only situation in which this problem can + // arise: every other use of `move` in MIR doesn't actually write to the value it moves + // from. + if let TerminatorKind::Call { ref args, .. } = terminator.kind { + for arg in args { + if let Operand::Move(place) = arg.node { + let local = place.local; + if place.is_indirect() + || local == RETURN_PLACE + || local.index() > self.mutable_args.domain_size() + { + continue; + } + + self.mutable_args.insert(local.index() - 1); + } + } + }; + + self.super_terminator(terminator, location); + } +} + +/// Returns true if values of a given type will never be passed indirectly, regardless of ABI. +fn type_will_always_be_passed_directly(ty: Ty<'_>) -> bool { + matches!( + ty.kind(), + ty::Bool + | ty::Char + | ty::Float(..) + | ty::Int(..) + | ty::RawPtr(..) + | ty::Ref(..) + | ty::Slice(..) + | ty::Uint(..) + ) +} + +/// Returns the deduced parameter attributes for a function. +/// +/// Deduced parameter attributes are those that can only be soundly determined by examining the +/// body of the function instead of just the signature. These can be useful for optimization +/// purposes on a best-effort basis. We compute them here and store them into the crate metadata so +/// dependent crates can use them. +pub(super) fn deduced_param_attrs<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: LocalDefId, +) -> &'tcx [DeducedParamAttrs] { + // This computation is unfortunately rather expensive, so don't do it unless we're optimizing. + // Also skip it in incremental mode. 
+ if tcx.sess.opts.optimize == OptLevel::No || tcx.sess.opts.incremental.is_some() { + return &[]; + } + + // If the Freeze lang item isn't present, then don't bother. + if tcx.lang_items().freeze_trait().is_none() { + return &[]; + } + + // Codegen won't use this information for anything if all the function parameters are passed + // directly. Detect that and bail, for compilation speed. + let fn_ty = tcx.type_of(def_id).instantiate_identity(); + if matches!(fn_ty.kind(), ty::FnDef(..)) + && fn_ty + .fn_sig(tcx) + .inputs() + .skip_binder() + .iter() + .cloned() + .all(type_will_always_be_passed_directly) + { + return &[]; + } + + // Don't deduce any attributes for functions that have no MIR. + if !tcx.is_mir_available(def_id) { + return &[]; + } + + // Grab the optimized MIR. Analyze it to determine which arguments have been mutated. + let body: &Body<'tcx> = tcx.optimized_mir(def_id); + let mut deduce_read_only = DeduceReadOnly::new(body.arg_count); + deduce_read_only.visit_body(body); + + // Set the `readonly` attribute for every argument that we concluded is immutable and that + // contains no UnsafeCells. + // + // FIXME: This is overly conservative around generic parameters: `is_freeze()` will always + // return false for them. For a description of alternatives that could do a better job here, + // see [1]. + // + // [1]: https://github.com/rust-lang/rust/pull/103172#discussion_r999139997 + let typing_env = body.typing_env(tcx); + let mut deduced_param_attrs = tcx.arena.alloc_from_iter( + body.local_decls.iter().skip(1).take(body.arg_count).enumerate().map( + |(arg_index, local_decl)| DeducedParamAttrs { + read_only: !deduce_read_only.mutable_args.contains(arg_index) + // We must normalize here to reveal opaques and normalize + // their generic parameters, otherwise we'll see exponential + // blow-up in compile times: #113372 + && tcx + .normalize_erasing_regions(typing_env, local_decl.ty) + .is_freeze(tcx, typing_env), + }, + ), + ); + + // Trailing parameters past the size of the `deduced_param_attrs` array are assumed to have the + // default set of attributes, so we don't have to store them explicitly. Pop them off to save a + // few bytes in metadata. 
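    // For example, deduced attributes of `[readonly, default, default]` are stored as just
    // `[readonly]`.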
+ while deduced_param_attrs.last() == Some(&DeducedParamAttrs::default()) { + let last_index = deduced_param_attrs.len() - 1; + deduced_param_attrs = &mut deduced_param_attrs[0..last_index]; + } + + deduced_param_attrs +} diff --git a/compiler/rustc_mir_transform/src/deref_separator.rs b/compiler/rustc_mir_transform/src/deref_separator.rs new file mode 100644 index 00000000000..bc914ea6564 --- /dev/null +++ b/compiler/rustc_mir_transform/src/deref_separator.rs @@ -0,0 +1,89 @@ +use rustc_middle::mir::visit::NonUseContext::VarDebugInfo; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext}; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +use crate::patch::MirPatch; + +pub(super) struct Derefer; + +struct DerefChecker<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + patcher: MirPatch<'tcx>, + local_decls: &'a LocalDecls<'tcx>, +} + +impl<'a, 'tcx> MutVisitor<'tcx> for DerefChecker<'a, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, cntxt: PlaceContext, loc: Location) { + if !place.projection.is_empty() + && cntxt != PlaceContext::NonUse(VarDebugInfo) + && place.projection[1..].contains(&ProjectionElem::Deref) + { + let mut place_local = place.local; + let mut last_len = 0; + let mut last_deref_idx = 0; + + for (idx, elem) in place.projection[0..].iter().enumerate() { + if *elem == ProjectionElem::Deref { + last_deref_idx = idx; + } + } + + for (idx, (p_ref, p_elem)) in place.iter_projections().enumerate() { + if !p_ref.projection.is_empty() && p_elem == ProjectionElem::Deref { + let ty = p_ref.ty(self.local_decls, self.tcx).ty; + let temp = self.patcher.new_local_with_info( + ty, + self.local_decls[p_ref.local].source_info.span, + LocalInfo::DerefTemp, + ); + + // We are adding current p_ref's projections to our + // temp value, excluding projections we already covered. + let deref_place = Place::from(place_local) + .project_deeper(&p_ref.projection[last_len..], self.tcx); + + self.patcher.add_assign( + loc, + Place::from(temp), + Rvalue::CopyForDeref(deref_place), + ); + place_local = temp; + last_len = p_ref.projection.len(); + + // Change `Place` only if we are actually at the Place's last deref + if idx == last_deref_idx { + let temp_place = + Place::from(temp).project_deeper(&place.projection[idx..], self.tcx); + *place = temp_place; + } + } + } + } + } +} + +pub(super) fn deref_finder<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let patch = MirPatch::new(body); + let mut checker = DerefChecker { tcx, patcher: patch, local_decls: &body.local_decls }; + + for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { + checker.visit_basic_block_data(bb, data); + } + + checker.patcher.apply(body); +} + +impl<'tcx> crate::MirPass<'tcx> for Derefer { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + deref_finder(tcx, body); + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs new file mode 100644 index 00000000000..4c94a6c524e --- /dev/null +++ b/compiler/rustc_mir_transform/src/dest_prop.rs @@ -0,0 +1,820 @@ +//! Propagates assignment destinations backwards in the CFG to eliminate redundant assignments. +//! +//! # Motivation +//! +//! MIR building can insert a lot of redundant copies, and Rust code in general often tends to move +//! values around a lot. The result is a lot of assignments of the form `dest = {move} src;` in MIR. +//! 
MIR building for constants in particular tends to create additional locals that are only used +//! inside a single block to shuffle a value around unnecessarily. +//! +//! LLVM by itself is not good enough at eliminating these redundant copies (eg. see +//! <https://github.com/rust-lang/rust/issues/32966>), so this leaves some performance on the table +//! that we can regain by implementing an optimization for removing these assign statements in rustc +//! itself. When this optimization runs fast enough, it can also speed up the constant evaluation +//! and code generation phases of rustc due to the reduced number of statements and locals. +//! +//! # The Optimization +//! +//! Conceptually, this optimization is "destination propagation". It is similar to the Named Return +//! Value Optimization, or NRVO, known from the C++ world, except that it isn't limited to return +//! values or the return place `_0`. On a very high level, independent of the actual implementation +//! details, it does the following: +//! +//! 1) Identify `dest = src;` statements with values for `dest` and `src` whose storage can soundly +//! be merged. +//! 2) Replace all mentions of `src` with `dest` ("unifying" them and propagating the destination +//! backwards). +//! 3) Delete the `dest = src;` statement (by making it a `nop`). +//! +//! Step 1) is by far the hardest, so it is explained in more detail below. +//! +//! ## Soundness +//! +//! We have a pair of places `p` and `q`, whose memory we would like to merge. In order for this to +//! be sound, we need to check a number of conditions: +//! +//! * `p` and `q` must both be *constant* - it does not make much sense to talk about merging them +//! if they do not consistently refer to the same place in memory. This is satisfied if they do +//! not contain any indirection through a pointer or any indexing projections. +//! +//! * `p` and `q` must have the **same type**. If we replace a local with a subtype or supertype, +//! we may end up with a different vtable for that local. See the `subtyping-impacts-selection` +//! tests for an example where that causes issues. +//! +//! * We need to make sure that the goal of "merging the memory" is actually structurally possible +//! in MIR. For example, even if all the other conditions are satisfied, there is no way to +//! "merge" `_5.foo` and `_6.bar`. For now, we ensure this by requiring that both `p` and `q` are +//! locals with no further projections. Future iterations of this pass should improve on this. +//! +//! * Finally, we want `p` and `q` to use the same memory - however, we still need to make sure that +//! each of them has enough "ownership" of that memory to continue "doing its job." More +//! precisely, what we will check is that whenever the program performs a write to `p`, then it +//! does not currently care about what the value in `q` is (and vice versa). We formalize the +//! notion of "does not care what the value in `q` is" by checking the *liveness* of `q`. +//! +//! Because of the difficulty of computing liveness of places that have their address taken, we do +//! not even attempt to do it. Any places that are in a local that has its address taken is +//! excluded from the optimization. +//! +//! The first two conditions are simple structural requirements on the `Assign` statements that can +//! be trivially checked. The third requirement however is more difficult and costly to check. +//! +//! ## Future Improvements +//! +//! 
There are a number of ways in which this pass could be improved in the future: +//! +//! * Merging storage liveness ranges instead of removing storage statements completely. This may +//! improve stack usage. +//! +//! * Allow merging locals into places with projections, eg `_5` into `_6.foo`. +//! +//! * Liveness analysis with more precision than whole locals at a time. The smaller benefit of this +//! is that it would allow us to dest prop at "sub-local" levels in some cases. The bigger benefit +//! of this is that such liveness analysis can report more accurate results about whole locals at +//! a time. For example, consider: +//! +//! ```ignore (syntax-highlighting-only) +//! _1 = u; +//! // unrelated code +//! _1.f1 = v; +//! _2 = _1.f1; +//! ``` +//! +//! Because the current analysis only thinks in terms of locals, it does not have enough +//! information to report that `_1` is dead in the "unrelated code" section. +//! +//! * Liveness analysis enabled by alias analysis. This would allow us to not just bail on locals +//! that ever have their address taken. Of course that requires actually having alias analysis +//! (and a model to build it on), so this might be a bit of a ways off. +//! +//! * Various perf improvements. There are a bunch of comments in here marked `PERF` with ideas for +//! how to do things more efficiently. However, the complexity of the pass as a whole should be +//! kept in mind. +//! +//! ## Previous Work +//! +//! A [previous attempt][attempt 1] at implementing an optimization like this turned out to be a +//! significant regression in compiler performance. Fixing the regressions introduced a lot of +//! undesirable complexity to the implementation. +//! +//! A [subsequent approach][attempt 2] tried to avoid the costly computation by limiting itself to +//! acyclic CFGs, but still turned out to be far too costly to run due to suboptimal performance +//! within individual basic blocks, requiring a walk across the entire block for every assignment +//! found within the block. For the `tuple-stress` benchmark, which has 458745 statements in a +//! single block, this proved to be far too costly. +//! +//! [Another approach after that][attempt 3] was much closer to correct, but had some soundness +//! issues - it was failing to consider stores outside live ranges, and failed to uphold some of the +//! requirements that MIR has for non-overlapping places within statements. However, it also had +//! performance issues caused by `O(l² * s)` runtime, where `l` is the number of locals and `s` is +//! the number of statements and terminators. +//! +//! Since the first attempt at this, the compiler has improved dramatically, and new analysis +//! frameworks have been added that should make this approach viable without requiring a limited +//! approach that only works for some classes of CFGs: +//! - rustc now has a powerful dataflow analysis framework that can handle forwards and backwards +//! analyses efficiently. +//! - Layout optimizations for coroutines have been added to improve code generation for +//! async/await, which are very similar in spirit to what this optimization does. +//! +//! Also, rustc now has a simple NRVO pass (see `nrvo.rs`), which handles a subset of the cases that +//! this destination propagation pass handles, proving that similar optimizations can be performed +//! on MIR. +//! +//! ## Pre/Post Optimization +//! +//! It is recommended to run `SimplifyCfg` and then `SimplifyLocals` some time after this pass, as +//! 
it replaces the eliminated assign statements with `nop`s and leaves unused locals behind. +//! +//! [liveness]: https://en.wikipedia.org/wiki/Live_variable_analysis +//! [attempt 1]: https://github.com/rust-lang/rust/pull/47954 +//! [attempt 2]: https://github.com/rust-lang/rust/pull/71003 +//! [attempt 3]: https://github.com/rust-lang/rust/pull/72632 + +use rustc_data_structures::fx::{FxIndexMap, IndexEntry, IndexOccupiedEntry}; +use rustc_index::bit_set::DenseBitSet; +use rustc_index::interval::SparseIntervalMatrix; +use rustc_middle::bug; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; +use rustc_middle::mir::{ + Body, HasLocalDecls, InlineAsmOperand, Local, LocalKind, Location, Operand, PassWhere, Place, + Rvalue, Statement, StatementKind, TerminatorKind, dump_mir, traversal, +}; +use rustc_middle::ty::TyCtxt; +use rustc_mir_dataflow::Analysis; +use rustc_mir_dataflow::impls::MaybeLiveLocals; +use rustc_mir_dataflow::points::{DenseLocationMap, PointIndex, save_as_intervals}; +use tracing::{debug, trace}; + +pub(super) struct DestinationPropagation; + +impl<'tcx> crate::MirPass<'tcx> for DestinationPropagation { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + // For now, only run at MIR opt level 3. Two things need to be changed before this can be + // turned on by default: + // 1. Because of the overeager removal of storage statements, this can cause stack space + // regressions. This opt is not the place to fix this though, it's a more general + // problem in MIR. + // 2. Despite being an overall perf improvement, this still causes a 30% regression in + // keccak. We can temporarily fix this by bounding function size, but in the long term + // we should fix this by being smarter about invalidating analysis results. + sess.mir_opt_level() >= 3 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let def_id = body.source.def_id(); + let mut candidates = Candidates::default(); + let mut write_info = WriteInfo::default(); + trace!(func = ?tcx.def_path_str(def_id)); + + let borrowed = rustc_mir_dataflow::impls::borrowed_locals(body); + + let live = MaybeLiveLocals.iterate_to_fixpoint(tcx, body, Some("MaybeLiveLocals-DestProp")); + let points = DenseLocationMap::new(body); + let mut live = save_as_intervals(&points, body, live.analysis, live.results); + + // In order to avoid having to collect data for every single pair of locals in the body, we + // do not allow doing more than one merge for places that are derived from the same local at + // once. To avoid missed opportunities, we instead iterate to a fixed point - we'll refer to + // each of these iterations as a "round." + // + // Reaching a fixed point could in theory take up to `min(l, s)` rounds - however, we do not + // expect to see MIR like that. To verify this, a test was run against `[rust-lang/regex]` - + // the average MIR body saw 1.32 full iterations of this loop. The most that was hit were 30 + // for a single function. Only 80/2801 (2.9%) of functions saw at least 5. + // + // [rust-lang/regex]: + // https://github.com/rust-lang/regex/tree/b5372864e2df6a2f5e543a556a62197f50ca3650 + let mut round_count = 0; + loop { + // PERF: Can we do something smarter than recalculating the candidates and liveness + // results? 
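            // For now, every round rebuilds the candidate set from scratch and re-runs the
            // conflict filtering below; only the liveness intervals are updated incrementally
            // (via `union_rows` after each merge).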
+ candidates.reset_and_find(body, &borrowed); + trace!(?candidates); + dest_prop_mir_dump(tcx, body, &points, &live, round_count); + + FilterInformation::filter_liveness( + &mut candidates, + &points, + &live, + &mut write_info, + body, + ); + + // Because we only filter once per round, it is unsound to use a local for more than + // one merge operation within a single round of optimizations. We store here which ones + // we have already used. + let mut merged_locals: DenseBitSet<Local> = + DenseBitSet::new_empty(body.local_decls.len()); + + // This is the set of merges we will apply this round. It is a subset of the candidates. + let mut merges = FxIndexMap::default(); + + for (src, candidates) in candidates.c.iter() { + if merged_locals.contains(*src) { + continue; + } + let Some(dest) = candidates.iter().find(|dest| !merged_locals.contains(**dest)) + else { + continue; + }; + + // Replace `src` by `dest` everywhere. + merges.insert(*src, *dest); + merged_locals.insert(*src); + merged_locals.insert(*dest); + + // Update liveness information based on the merge we just performed. + // Every location where `src` was live, `dest` will be live. + live.union_rows(*src, *dest); + } + trace!(merging = ?merges); + + if merges.is_empty() { + break; + } + round_count += 1; + + apply_merges(body, tcx, merges, merged_locals); + } + + trace!(round_count); + } + + fn is_required(&self) -> bool { + false + } +} + +#[derive(Debug, Default)] +struct Candidates { + /// The set of candidates we are considering in this optimization. + /// + /// We will always merge the key into at most one of its values. + /// + /// Whether a place ends up in the key or the value does not correspond to whether it appears as + /// the lhs or rhs of any assignment. As a matter of fact, the places in here might never appear + /// in an assignment at all. This happens because if we see an assignment like this: + /// + /// ```ignore (syntax-highlighting-only) + /// _1.0 = _2.0 + /// ``` + /// + /// We will still report that we would like to merge `_1` and `_2` in an attempt to allow us to + /// remove that assignment. + c: FxIndexMap<Local, Vec<Local>>, + + /// A reverse index of the `c` set; if the `c` set contains `a => Place { local: b, proj }`, + /// then this contains `b => a`. + // PERF: Possibly these should be `SmallVec`s? + reverse: FxIndexMap<Local, Vec<Local>>, +} + +////////////////////////////////////////////////////////// +// Merging +// +// Applies the actual optimization + +fn apply_merges<'tcx>( + body: &mut Body<'tcx>, + tcx: TyCtxt<'tcx>, + merges: FxIndexMap<Local, Local>, + merged_locals: DenseBitSet<Local>, +) { + let mut merger = Merger { tcx, merges, merged_locals }; + merger.visit_body_preserves_cfg(body); +} + +struct Merger<'tcx> { + tcx: TyCtxt<'tcx>, + merges: FxIndexMap<Local, Local>, + merged_locals: DenseBitSet<Local>, +} + +impl<'tcx> MutVisitor<'tcx> for Merger<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) { + if let Some(dest) = self.merges.get(local) { + *local = *dest; + } + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + match &statement.kind { + // FIXME: Don't delete storage statements, but "merge" the storage ranges instead. 
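            // Once two locals are merged, neither local's original storage range is guaranteed to
            // cover every use of the merged local, so we conservatively drop the storage markers
            // altogether, which makes the merged local storage-live for the whole body.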
+ StatementKind::StorageDead(local) | StatementKind::StorageLive(local) + if self.merged_locals.contains(*local) => + { + statement.make_nop(); + return; + } + _ => (), + }; + self.super_statement(statement, location); + match &statement.kind { + StatementKind::Assign(box (dest, rvalue)) => { + match rvalue { + Rvalue::CopyForDeref(place) + | Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => { + // These might've been turned into self-assignments by the replacement + // (this includes the original statement we wanted to eliminate). + if dest == place { + debug!("{:?} turned into self-assignment, deleting", location); + statement.make_nop(); + } + } + _ => {} + } + } + + _ => {} + } + } +} + +////////////////////////////////////////////////////////// +// Liveness filtering +// +// This section enforces bullet point 2 + +struct FilterInformation<'a, 'tcx> { + body: &'a Body<'tcx>, + points: &'a DenseLocationMap, + live: &'a SparseIntervalMatrix<Local, PointIndex>, + candidates: &'a mut Candidates, + write_info: &'a mut WriteInfo, + at: Location, +} + +// We first implement some utility functions which we will expose removing candidates according to +// different needs. Throughout the liveness filtering, the `candidates` are only ever accessed +// through these methods, and not directly. +impl Candidates { + /// Collects the candidates for merging. + /// + /// This is responsible for enforcing the first and third bullet point. + fn reset_and_find<'tcx>(&mut self, body: &Body<'tcx>, borrowed: &DenseBitSet<Local>) { + self.c.clear(); + self.reverse.clear(); + let mut visitor = FindAssignments { body, candidates: &mut self.c, borrowed }; + visitor.visit_body(body); + // Deduplicate candidates. + for (_, cands) in self.c.iter_mut() { + cands.sort(); + cands.dedup(); + } + // Generate the reverse map. + for (src, cands) in self.c.iter() { + for dest in cands.iter().copied() { + self.reverse.entry(dest).or_default().push(*src); + } + } + } + + /// Just `Vec::retain`, but the condition is inverted and we add debugging output + fn vec_filter_candidates( + src: Local, + v: &mut Vec<Local>, + mut f: impl FnMut(Local) -> CandidateFilter, + at: Location, + ) { + v.retain(|dest| { + let remove = f(*dest); + if remove == CandidateFilter::Remove { + trace!("eliminating {:?} => {:?} due to conflict at {:?}", src, dest, at); + } + remove == CandidateFilter::Keep + }); + } + + /// `vec_filter_candidates` but for an `Entry` + fn entry_filter_candidates( + mut entry: IndexOccupiedEntry<'_, Local, Vec<Local>>, + p: Local, + f: impl FnMut(Local) -> CandidateFilter, + at: Location, + ) { + let candidates = entry.get_mut(); + Self::vec_filter_candidates(p, candidates, f, at); + if candidates.len() == 0 { + // FIXME(#120456) - is `swap_remove` correct? + entry.swap_remove(); + } + } + + /// For all candidates `(p, q)` or `(q, p)` removes the candidate if `f(q)` says to do so + fn filter_candidates_by( + &mut self, + p: Local, + mut f: impl FnMut(Local) -> CandidateFilter, + at: Location, + ) { + // Cover the cases where `p` appears as a `src` + if let IndexEntry::Occupied(entry) = self.c.entry(p) { + Self::entry_filter_candidates(entry, p, &mut f, at); + } + // And the cases where `p` appears as a `dest` + let Some(srcs) = self.reverse.get_mut(&p) else { + return; + }; + // We use `retain` here to remove the elements from the reverse set if we've removed the + // matching candidate in the forward set. 
+ srcs.retain(|src| { + if f(*src) == CandidateFilter::Keep { + return true; + } + let IndexEntry::Occupied(entry) = self.c.entry(*src) else { + return false; + }; + Self::entry_filter_candidates( + entry, + *src, + |dest| { + if dest == p { CandidateFilter::Remove } else { CandidateFilter::Keep } + }, + at, + ); + false + }); + } +} + +#[derive(Copy, Clone, PartialEq, Eq)] +enum CandidateFilter { + Keep, + Remove, +} + +impl<'a, 'tcx> FilterInformation<'a, 'tcx> { + /// Filters the set of candidates to remove those that conflict. + /// + /// The steps we take are exactly those that are outlined at the top of the file. For each + /// statement/terminator, we collect the set of locals that are written to in that + /// statement/terminator, and then we remove all pairs of candidates that contain one such local + /// and another one that is live. + /// + /// We need to be careful about the ordering of operations within each statement/terminator + /// here. Many statements might write and read from more than one place, and we need to consider + /// them all. The strategy for doing this is as follows: We first gather all the places that are + /// written to within the statement/terminator via `WriteInfo`. Then, we use the liveness + /// analysis from *before* the statement/terminator (in the control flow sense) to eliminate + /// candidates - this is because we want to conservatively treat a pair of locals that is both + /// read and written in the statement/terminator to be conflicting, and the liveness analysis + /// before the statement/terminator will correctly report locals that are read in the + /// statement/terminator to be live. We are additionally conservative by treating all written to + /// locals as also being read from. + fn filter_liveness( + candidates: &mut Candidates, + points: &DenseLocationMap, + live: &SparseIntervalMatrix<Local, PointIndex>, + write_info: &mut WriteInfo, + body: &Body<'tcx>, + ) { + let mut this = FilterInformation { + body, + points, + live, + candidates, + // We don't actually store anything at this scope, we just keep things here to be able + // to reuse the allocation. + write_info, + // Doesn't matter what we put here, will be overwritten before being used + at: Location::START, + }; + this.internal_filter_liveness(); + } + + fn internal_filter_liveness(&mut self) { + for (block, data) in traversal::preorder(self.body) { + self.at = Location { block, statement_index: data.statements.len() }; + self.write_info.for_terminator(&data.terminator().kind); + self.apply_conflicts(); + + for (i, statement) in data.statements.iter().enumerate().rev() { + self.at = Location { block, statement_index: i }; + self.write_info.for_statement(&statement.kind, self.body); + self.apply_conflicts(); + } + } + } + + fn apply_conflicts(&mut self) { + let writes = &self.write_info.writes; + for p in writes { + let other_skip = self.write_info.skip_pair.and_then(|(a, b)| { + if a == *p { + Some(b) + } else if b == *p { + Some(a) + } else { + None + } + }); + let at = self.points.point_from_location(self.at); + self.candidates.filter_candidates_by( + *p, + |q| { + if Some(q) == other_skip { + return CandidateFilter::Keep; + } + // It is possible that a local may be live for less than the + // duration of a statement This happens in the case of function + // calls or inline asm. Because of this, we also mark locals as + // conflicting when both of them are written to in the same + // statement. 
+ if self.live.contains(q, at) || writes.contains(&q) { + CandidateFilter::Remove + } else { + CandidateFilter::Keep + } + }, + self.at, + ); + } + } +} + +/// Describes where a statement/terminator writes to +#[derive(Default, Debug)] +struct WriteInfo { + writes: Vec<Local>, + /// If this pair of locals is a candidate pair, completely skip processing it during this + /// statement. All other candidates are unaffected. + skip_pair: Option<(Local, Local)>, +} + +impl WriteInfo { + fn for_statement<'tcx>(&mut self, statement: &StatementKind<'tcx>, body: &Body<'tcx>) { + self.reset(); + match statement { + StatementKind::Assign(box (lhs, rhs)) => { + self.add_place(*lhs); + match rhs { + Rvalue::Use(op) => { + self.add_operand(op); + self.consider_skipping_for_assign_use(*lhs, op, body); + } + Rvalue::Repeat(op, _) => { + self.add_operand(op); + } + Rvalue::Cast(_, op, _) + | Rvalue::UnaryOp(_, op) + | Rvalue::ShallowInitBox(op, _) => { + self.add_operand(op); + } + Rvalue::BinaryOp(_, ops) => { + for op in [&ops.0, &ops.1] { + self.add_operand(op); + } + } + Rvalue::Aggregate(_, ops) => { + for op in ops { + self.add_operand(op); + } + } + Rvalue::WrapUnsafeBinder(op, _) => { + self.add_operand(op); + } + Rvalue::ThreadLocalRef(_) + | Rvalue::NullaryOp(_, _) + | Rvalue::Ref(_, _, _) + | Rvalue::RawPtr(_, _) + | Rvalue::Len(_) + | Rvalue::Discriminant(_) + | Rvalue::CopyForDeref(_) => {} + } + } + // Retags are technically also reads, but reporting them as a write suffices + StatementKind::SetDiscriminant { place, .. } + | StatementKind::Deinit(place) + | StatementKind::Retag(_, place) => { + self.add_place(**place); + } + StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter + | StatementKind::Nop + | StatementKind::Coverage(_) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::BackwardIncompatibleDropHint { .. } + | StatementKind::PlaceMention(_) => {} + StatementKind::FakeRead(_) | StatementKind::AscribeUserType(_, _) => { + bug!("{:?} not found in this MIR phase", statement) + } + } + } + + fn consider_skipping_for_assign_use<'tcx>( + &mut self, + lhs: Place<'tcx>, + rhs: &Operand<'tcx>, + body: &Body<'tcx>, + ) { + let Some(rhs) = rhs.place() else { return }; + if let Some(pair) = places_to_candidate_pair(lhs, rhs, body) { + self.skip_pair = Some(pair); + } + } + + fn for_terminator<'tcx>(&mut self, terminator: &TerminatorKind<'tcx>) { + self.reset(); + match terminator { + TerminatorKind::SwitchInt { discr: op, .. } + | TerminatorKind::Assert { cond: op, .. } => { + self.add_operand(op); + } + TerminatorKind::Call { destination, func, args, .. } => { + self.add_place(*destination); + self.add_operand(func); + for arg in args { + self.add_operand(&arg.node); + } + } + TerminatorKind::TailCall { func, args, .. } => { + self.add_operand(func); + for arg in args { + self.add_operand(&arg.node); + } + } + TerminatorKind::InlineAsm { operands, .. } => { + for asm_operand in operands { + match asm_operand { + InlineAsmOperand::In { value, .. } => { + self.add_operand(value); + } + InlineAsmOperand::Out { place, .. } => { + if let Some(place) = place { + self.add_place(*place); + } + } + // Note that the `late` field in `InOut` is about whether the registers used + // for these things overlap, and is of absolutely no interest to us. + InlineAsmOperand::InOut { in_value, out_place, .. } => { + if let Some(place) = out_place { + self.add_place(*place); + } + self.add_operand(in_value); + } + InlineAsmOperand::Const { .. 
} + | InlineAsmOperand::SymFn { .. } + | InlineAsmOperand::SymStatic { .. } + | InlineAsmOperand::Label { .. } => {} + } + } + } + TerminatorKind::Goto { .. } + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::Unreachable { .. } => (), + TerminatorKind::Drop { .. } => { + // `Drop`s create a `&mut` and so are not considered + } + TerminatorKind::Yield { .. } + | TerminatorKind::CoroutineDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } => { + bug!("{:?} not found in this MIR phase", terminator) + } + } + } + + fn add_place(&mut self, place: Place<'_>) { + self.writes.push(place.local); + } + + fn add_operand<'tcx>(&mut self, op: &Operand<'tcx>) { + match op { + // FIXME(JakobDegen): In a previous version, the `Move` case was incorrectly treated as + // being a read only. This was unsound, however we cannot add a regression test because + // it is not possible to set this off with current MIR. Once we have that ability, a + // regression test should be added. + Operand::Move(p) => self.add_place(*p), + Operand::Copy(_) | Operand::Constant(_) => (), + } + } + + fn reset(&mut self) { + self.writes.clear(); + self.skip_pair = None; + } +} + +///////////////////////////////////////////////////// +// Candidate accumulation + +/// If the pair of places is being considered for merging, returns the candidate which would be +/// merged in order to accomplish this. +/// +/// The contract here is in one direction - there is a guarantee that merging the locals that are +/// outputted by this function would result in an assignment between the inputs becoming a +/// self-assignment. However, there is no guarantee that the returned pair is actually suitable for +/// merging - candidate collection must still check this independently. +/// +/// This output is unique for each unordered pair of input places. +fn places_to_candidate_pair<'tcx>( + a: Place<'tcx>, + b: Place<'tcx>, + body: &Body<'tcx>, +) -> Option<(Local, Local)> { + let (mut a, mut b) = if a.projection.len() == 0 && b.projection.len() == 0 { + (a.local, b.local) + } else { + return None; + }; + + // By sorting, we make sure we're input order independent + if a > b { + std::mem::swap(&mut a, &mut b); + } + + // We could now return `(a, b)`, but then we miss some candidates in the case where `a` can't be + // used as a `src`. + if is_local_required(a, body) { + std::mem::swap(&mut a, &mut b); + } + // We could check `is_local_required` again here, but there's no need - after all, we make no + // promise that the candidate pair is actually valid + Some((a, b)) +} + +struct FindAssignments<'a, 'tcx> { + body: &'a Body<'tcx>, + candidates: &'a mut FxIndexMap<Local, Vec<Local>>, + borrowed: &'a DenseBitSet<Local>, +} + +impl<'tcx> Visitor<'tcx> for FindAssignments<'_, 'tcx> { + fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) { + if let StatementKind::Assign(box ( + lhs, + Rvalue::CopyForDeref(rhs) | Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), + )) = &statement.kind + { + let Some((src, dest)) = places_to_candidate_pair(*lhs, *rhs, self.body) else { + return; + }; + + // As described at the top of the file, we do not go near things that have + // their address taken. + if self.borrowed.contains(src) || self.borrowed.contains(dest) { + return; + } + + // As described at the top of this file, we do not touch locals which have + // different types. 
+ let src_ty = self.body.local_decls()[src].ty; + let dest_ty = self.body.local_decls()[dest].ty; + if src_ty != dest_ty { + // FIXME(#112651): This can be removed afterwards. Also update the module description. + trace!("skipped `{src:?} = {dest:?}` due to subtyping: {src_ty} != {dest_ty}"); + return; + } + + // Also, we need to make sure that MIR actually allows the `src` to be removed + if is_local_required(src, self.body) { + return; + } + + // We may insert duplicates here, but that's fine + self.candidates.entry(src).or_default().push(dest); + } + } +} + +/// Some locals are part of the function's interface and can not be removed. +/// +/// Note that these locals *can* still be merged with non-required locals by removing that other +/// local. +fn is_local_required(local: Local, body: &Body<'_>) -> bool { + match body.local_kind(local) { + LocalKind::Arg | LocalKind::ReturnPointer => true, + LocalKind::Temp => false, + } +} + +///////////////////////////////////////////////////////// +// MIR Dump + +fn dest_prop_mir_dump<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + points: &DenseLocationMap, + live: &SparseIntervalMatrix<Local, PointIndex>, + round: usize, +) { + let locals_live_at = |location| { + let location = points.point_from_location(location); + live.rows().filter(|&r| live.contains(r, location)).collect::<Vec<_>>() + }; + dump_mir(tcx, false, "DestinationPropagation-dataflow", &round, body, |pass_where, w| { + if let PassWhere::BeforeLocation(loc) = pass_where { + writeln!(w, " // live: {:?}", locals_live_at(loc))?; + } + + Ok(()) + }); +} diff --git a/compiler/rustc_mir_transform/src/dump_mir.rs b/compiler/rustc_mir_transform/src/dump_mir.rs new file mode 100644 index 00000000000..e4fcbaa483d --- /dev/null +++ b/compiler/rustc_mir_transform/src/dump_mir.rs @@ -0,0 +1,39 @@ +//! This pass just dumps MIR at a specified point. 
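//!
//! The `Marker` pass below deliberately has an empty `run_pass`: it only exists to give the
//! pass pipeline a named point at which the MIR can be dumped. `emit_mir` writes the
//! pretty-printed MIR of the whole crate when MIR output is requested (e.g. `--emit=mir`).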
+ +use std::fs::File; +use std::io; + +use rustc_middle::mir::{Body, write_mir_pretty}; +use rustc_middle::ty::TyCtxt; +use rustc_session::config::{OutFileName, OutputType}; + +pub(super) struct Marker(pub &'static str); + +impl<'tcx> crate::MirPass<'tcx> for Marker { + fn name(&self) -> &'static str { + self.0 + } + + fn run_pass(&self, _tcx: TyCtxt<'tcx>, _body: &mut Body<'tcx>) {} + + fn is_required(&self) -> bool { + false + } +} + +pub fn emit_mir(tcx: TyCtxt<'_>) -> io::Result<()> { + match tcx.output_filenames(()).path(OutputType::Mir) { + OutFileName::Stdout => { + let mut f = io::stdout(); + write_mir_pretty(tcx, None, &mut f)?; + } + OutFileName::Real(path) => { + let mut f = File::create_buffered(&path)?; + write_mir_pretty(tcx, None, &mut f)?; + if tcx.sess.opts.json_artifact_notifications { + tcx.dcx().emit_artifact_notification(&path, "mir"); + } + } + } + Ok(()) +} diff --git a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs new file mode 100644 index 00000000000..da88e5c698b --- /dev/null +++ b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs @@ -0,0 +1,407 @@ +use std::fmt::Debug; + +use rustc_middle::mir::*; +use rustc_middle::ty::{Ty, TyCtxt}; +use tracing::trace; + +use super::simplify::simplify_cfg; +use crate::patch::MirPatch; + +/// This pass optimizes something like +/// ```ignore (syntax-highlighting-only) +/// let x: Option<()>; +/// let y: Option<()>; +/// match (x,y) { +/// (Some(_), Some(_)) => {0}, +/// (None, None) => {2}, +/// _ => {1} +/// } +/// ``` +/// into something like +/// ```ignore (syntax-highlighting-only) +/// let x: Option<()>; +/// let y: Option<()>; +/// let discriminant_x = std::mem::discriminant(x); +/// let discriminant_y = std::mem::discriminant(y); +/// if discriminant_x == discriminant_y { +/// match x { +/// Some(_) => 0, +/// None => 2, +/// } +/// } else { +/// 1 +/// } +/// ``` +/// +/// Specifically, it looks for instances of control flow like this: +/// ```text +/// +/// ================= +/// | BB1 | +/// |---------------| ============================ +/// | ... | /------> | BBC | +/// |---------------| | |--------------------------| +/// | switchInt(Q) | | | _cl = discriminant(P) | +/// | c | --------/ |--------------------------| +/// | d | -------\ | switchInt(_cl) | +/// | ... | | | c | ---> BBC.2 +/// | otherwise | --\ | /--- | otherwise | +/// ================= | | | ============================ +/// | | | +/// ================= | | | +/// | BBU | <-| | | ============================ +/// |---------------| \-------> | BBD | +/// |---------------| | |--------------------------| +/// | unreachable | | | _dl = discriminant(P) | +/// ================= | |--------------------------| +/// | | switchInt(_dl) | +/// ================= | | d | ---> BBD.2 +/// | BB9 | <--------------- | otherwise | +/// |---------------| ============================ +/// | ... | +/// ================= +/// ``` +/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the +/// code: +/// - `BB1` is `parent` and `BBC, BBD` are children +/// - `P` is `child_place` +/// - `child_ty` is the type of `_cl`. +/// - `Q` is `parent_op`. +/// - `parent_ty` is the type of `Q`. +/// - `BB9` is `destination` +/// All this is then transformed into: +/// ```text +/// +/// ======================= +/// | BB1 | +/// |---------------------| ============================ +/// | ... 
| /------> | BBEq | +/// | _s = discriminant(P)| | |--------------------------| +/// | _t = Ne(Q, _s) | | |--------------------------| +/// |---------------------| | | switchInt(Q) | +/// | switchInt(_t) | | | c | ---> BBC.2 +/// | false | --------/ | d | ---> BBD.2 +/// | otherwise | /--------- | otherwise | +/// ======================= | ============================ +/// | +/// ================= | +/// | BB9 | <-----------/ +/// |---------------| +/// | ... | +/// ================= +/// ``` +pub(super) struct EarlyOtherwiseBranch; + +impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 2 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("running EarlyOtherwiseBranch on {:?}", body.source); + + let mut should_cleanup = false; + + // Also consider newly generated bbs in the same pass + for parent in body.basic_blocks.indices() { + let bbs = &*body.basic_blocks; + let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue }; + + trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}"); + + should_cleanup = true; + + let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } = + &bbs[parent].terminator().kind + else { + unreachable!() + }; + // Always correct since we can only switch on `Copy` types + let parent_op = match parent_op { + Operand::Move(x) => Operand::Copy(*x), + Operand::Copy(x) => Operand::Copy(*x), + Operand::Constant(x) => Operand::Constant(x.clone()), + }; + let parent_ty = parent_op.ty(body.local_decls(), tcx); + let statements_before = bbs[parent].statements.len(); + let parent_end = Location { block: parent, statement_index: statements_before }; + + let mut patch = MirPatch::new(body); + + let second_operand = if opt_data.need_hoist_discriminant { + // create temp to store second discriminant in, `_s` in example above + let second_discriminant_temp = + patch.new_temp(opt_data.child_ty, opt_data.child_source.span); + + // create assignment of discriminant + patch.add_assign( + parent_end, + Place::from(second_discriminant_temp), + Rvalue::Discriminant(opt_data.child_place), + ); + Operand::Move(Place::from(second_discriminant_temp)) + } else { + Operand::Copy(opt_data.child_place) + }; + + // create temp to store inequality comparison between the two discriminants, `_t` in + // example above + let nequal = BinOp::Ne; + let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty); + let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span); + + // create inequality comparison + let comp_rvalue = + Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand))); + patch.add_statement( + parent_end, + StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))), + ); + + let eq_new_targets = parent_targets.iter().map(|(value, child)| { + let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind + else { + unreachable!() + }; + (value, targets.target_for_value(value)) + }); + // The otherwise either is the same target branch or an unreachable. 
+ let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise()); + + // Create `bbEq` in example above + let eq_switch = BasicBlockData::new( + Some(Terminator { + source_info: bbs[parent].terminator().source_info, + kind: TerminatorKind::SwitchInt { + // switch on the first discriminant, so we can mark the second one as dead + discr: parent_op, + targets: eq_targets, + }, + }), + bbs[parent].is_cleanup, + ); + + let eq_bb = patch.new_block(eq_switch); + + // Jump to it on the basis of the inequality comparison + let true_case = opt_data.destination; + let false_case = eq_bb; + patch.patch_terminator( + parent, + TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case), + ); + + patch.apply(body); + } + + // Since this optimization adds new basic blocks and invalidates others, + // clean up the cfg to make it nicer for other passes + if should_cleanup { + simplify_cfg(tcx, body); + } + } + + fn is_required(&self) -> bool { + false + } +} + +#[derive(Debug)] +struct OptimizationData<'tcx> { + destination: BasicBlock, + child_place: Place<'tcx>, + child_ty: Ty<'tcx>, + child_source: SourceInfo, + need_hoist_discriminant: bool, +} + +fn evaluate_candidate<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + parent: BasicBlock, +) -> Option<OptimizationData<'tcx>> { + let bbs = &body.basic_blocks; + // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled. + if bbs[parent].is_cleanup { + return None; + } + let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind + else { + return None; + }; + let parent_ty = parent_discr.ty(body.local_decls(), tcx); + let (_, child) = targets.iter().next()?; + + let Terminator { + kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr }, + source_info, + } = bbs[child].terminator() + else { + return None; + }; + let child_ty = child_discr.ty(body.local_decls(), tcx); + if child_ty != parent_ty { + return None; + } + + // We only handle: + // ``` + // bb4: { + // _8 = discriminant((_3.1: Enum1)); + // switchInt(move _8) -> [2: bb7, otherwise: bb1]; + // } + // ``` + // and + // ``` + // bb2: { + // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1]; + // } + // ``` + if bbs[child].statements.len() > 1 { + return None; + } + + // When thie BB has exactly one statement, this statement should be discriminant. + let need_hoist_discriminant = bbs[child].statements.len() == 1; + let child_place = if need_hoist_discriminant { + if !bbs[targets.otherwise()].is_empty_unreachable() { + // Someone could write code like this: + // ```rust + // let Q = val; + // if discriminant(P) == otherwise { + // let ptr = &mut Q as *mut _ as *mut u8; + // // It may be difficult for us to effectively determine whether values are valid. + // // Invalid values can come from all sorts of corners. + // unsafe { *ptr = 10; } + // } + // + // match P { + // A => match Q { + // A => { + // // code + // } + // _ => { + // // don't use Q + // } + // } + // _ => { + // // don't use Q + // } + // }; + // ``` + // + // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an + // invalid value, which is UB. + // In order to fix this, **we would either need to show that the discriminant computation of + // `place` is computed in all branches**. + // FIXME(#95162) For the moment, we adopt a conservative approach and + // consider only the `otherwise` branch has no statements and an unreachable terminator. 
+ return None; + } + // Handle: + // ``` + // bb4: { + // _8 = discriminant((_3.1: Enum1)); + // switchInt(move _8) -> [2: bb7, otherwise: bb1]; + // } + // ``` + let [ + Statement { + kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))), + .. + }, + ] = bbs[child].statements.as_slice() + else { + return None; + }; + *child_place + } else { + // Handle: + // ``` + // bb2: { + // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1]; + // } + // ``` + let Operand::Copy(child_place) = child_discr else { + return None; + }; + *child_place + }; + let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable() + { + child_targets.otherwise() + } else { + targets.otherwise() + }; + + // Verify that the optimization is legal for each branch + for (value, child) in targets.iter() { + if !verify_candidate_branch( + &bbs[child], + value, + child_place, + destination, + need_hoist_discriminant, + ) { + return None; + } + } + Some(OptimizationData { + destination, + child_place, + child_ty, + child_source: *source_info, + need_hoist_discriminant, + }) +} + +fn verify_candidate_branch<'tcx>( + branch: &BasicBlockData<'tcx>, + value: u128, + place: Place<'tcx>, + destination: BasicBlock, + need_hoist_discriminant: bool, +) -> bool { + // In order for the optimization to be correct, the terminator must be a `SwitchInt`. + let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else { + return false; + }; + if need_hoist_discriminant { + // If we need hoist discriminant, the branch must have exactly one statement. + let [statement] = branch.statements.as_slice() else { + return false; + }; + // The statement must assign the discriminant of `place`. + let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) = + statement.kind + else { + return false; + }; + if from_place != place { + return false; + } + // The assignment must invalidate a local that terminate on a `SwitchInt`. + if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) { + return false; + } + } else { + // If we don't need hoist discriminant, the branch must not have any statements. + if !branch.statements.is_empty() { + return false; + } + // The place on `SwitchInt` must be the same. + if *switch_op != Operand::Copy(place) { + return false; + } + } + // It must fall through to `destination` if the switch misses. + if destination != targets.otherwise() { + return false; + } + // It must have exactly one branch for value `value` and have no more branches. + let mut iter = targets.iter(); + let (Some((target_value, _)), None) = (iter.next(), iter.next()) else { + return false; + }; + target_value == value +} diff --git a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs new file mode 100644 index 00000000000..5c344a80688 --- /dev/null +++ b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs @@ -0,0 +1,158 @@ +//! This pass transforms derefs of Box into a deref of the pointer inside Box. +//! +//! Box is not actually a pointer so it is incorrect to dereference it directly. 
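As a rough model of the rewrite this pass performs (the type and field names below are stand-ins of my own, not the real standard-library internals): a `Box<T>` stores its allocation behind a `Unique<T>` that wraps a `NonNull<T>`, so a deref of the box becomes two field projections, a cast to a raw pointer, and a deref of that pointer.

```rust
use std::ptr::NonNull;

// Hypothetical stand-ins for the Unique/NonNull chain inside Box<T>.
struct FakeNonNull<T>(NonNull<T>);
struct FakeUnique<T> { pointer: FakeNonNull<T> }
struct FakeBox<T>(FakeUnique<T>);

// Conceptually what `*b` is elaborated into: project to field 0, then field 0,
// turn the result into a raw pointer, and dereference that instead.
fn read(b: &FakeBox<u32>) -> u32 {
    let raw: *const u32 = b.0.pointer.0.as_ptr();
    unsafe { *raw }
}
```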
+ +use rustc_abi::FieldIdx; +use rustc_hir::def_id::DefId; +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::*; +use rustc_middle::span_bug; +use rustc_middle::ty::{Ty, TyCtxt}; + +use crate::patch::MirPatch; + +/// Constructs the types used when accessing a Box's pointer +fn build_ptr_tys<'tcx>( + tcx: TyCtxt<'tcx>, + pointee: Ty<'tcx>, + unique_did: DefId, + nonnull_did: DefId, +) -> (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>) { + let args = tcx.mk_args(&[pointee.into()]); + let unique_ty = tcx.type_of(unique_did).instantiate(tcx, args); + let nonnull_ty = tcx.type_of(nonnull_did).instantiate(tcx, args); + let ptr_ty = Ty::new_imm_ptr(tcx, pointee); + + (unique_ty, nonnull_ty, ptr_ty) +} + +/// Constructs the projection needed to access a Box's pointer +pub(super) fn build_projection<'tcx>( + unique_ty: Ty<'tcx>, + nonnull_ty: Ty<'tcx>, +) -> [PlaceElem<'tcx>; 2] { + [PlaceElem::Field(FieldIdx::ZERO, unique_ty), PlaceElem::Field(FieldIdx::ZERO, nonnull_ty)] +} + +struct ElaborateBoxDerefVisitor<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + unique_did: DefId, + nonnull_did: DefId, + local_decls: &'a mut LocalDecls<'tcx>, + patch: MirPatch<'tcx>, +} + +impl<'a, 'tcx> MutVisitor<'tcx> for ElaborateBoxDerefVisitor<'a, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place( + &mut self, + place: &mut Place<'tcx>, + context: visit::PlaceContext, + location: Location, + ) { + let tcx = self.tcx; + + let base_ty = self.local_decls[place.local].ty; + + // Derefer ensures that derefs are always the first projection + if let Some(PlaceElem::Deref) = place.projection.first() + && let Some(boxed_ty) = base_ty.boxed_ty() + { + let source_info = self.local_decls[place.local].source_info; + + let (unique_ty, nonnull_ty, ptr_ty) = + build_ptr_tys(tcx, boxed_ty, self.unique_did, self.nonnull_did); + + let ptr_local = self.patch.new_temp(ptr_ty, source_info.span); + + self.patch.add_assign( + location, + Place::from(ptr_local), + Rvalue::Cast( + CastKind::Transmute, + Operand::Copy( + Place::from(place.local) + .project_deeper(&build_projection(unique_ty, nonnull_ty), tcx), + ), + ptr_ty, + ), + ); + + place.local = ptr_local; + } + + self.super_place(place, context, location); + } +} + +pub(super) struct ElaborateBoxDerefs; + +impl<'tcx> crate::MirPass<'tcx> for ElaborateBoxDerefs { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // If box is not present, this pass doesn't need to do anything. 
+ let Some(def_id) = tcx.lang_items().owned_box() else { return }; + + let unique_did = tcx.adt_def(def_id).non_enum_variant().fields[FieldIdx::ZERO].did; + + let Some(nonnull_def) = tcx.type_of(unique_did).instantiate_identity().ty_adt_def() else { + span_bug!(tcx.def_span(unique_did), "expected Box to contain Unique") + }; + + let nonnull_did = nonnull_def.non_enum_variant().fields[FieldIdx::ZERO].did; + + let patch = MirPatch::new(body); + + let local_decls = &mut body.local_decls; + + let mut visitor = + ElaborateBoxDerefVisitor { tcx, unique_did, nonnull_did, local_decls, patch }; + + for (block, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { + visitor.visit_basic_block_data(block, data); + } + + visitor.patch.apply(body); + + for debug_info in body.var_debug_info.iter_mut() { + if let VarDebugInfoContents::Place(place) = &mut debug_info.value { + let mut new_projections: Option<Vec<_>> = None; + + for (base, elem) in place.iter_projections() { + let base_ty = base.ty(&body.local_decls, tcx).ty; + + if let PlaceElem::Deref = elem + && let Some(boxed_ty) = base_ty.boxed_ty() + { + // Clone the projections before us, since now we need to mutate them. + let new_projections = + new_projections.get_or_insert_with(|| base.projection.to_vec()); + + let (unique_ty, nonnull_ty, ptr_ty) = + build_ptr_tys(tcx, boxed_ty, unique_did, nonnull_did); + + new_projections.extend_from_slice(&build_projection(unique_ty, nonnull_ty)); + // While we can't project into `NonNull<_>` in a basic block + // due to MCP#807, this is debug info where it's fine. + new_projections.push(PlaceElem::Field(FieldIdx::ZERO, ptr_ty)); + new_projections.push(PlaceElem::Deref); + } else if let Some(new_projections) = new_projections.as_mut() { + // Keep building up our projections list once we've started it. + new_projections.push(elem); + } + } + + // Store the mutated projections if we actually changed something. + if let Some(new_projections) = new_projections { + place.projection = tcx.mk_place_elems(&new_projections); + } + } + } + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/elaborate_drop.rs b/compiler/rustc_mir_transform/src/elaborate_drop.rs new file mode 100644 index 00000000000..c9bc52c6c7e --- /dev/null +++ b/compiler/rustc_mir_transform/src/elaborate_drop.rs @@ -0,0 +1,1475 @@ +use std::{fmt, iter, mem}; + +use rustc_abi::{FIRST_VARIANT, FieldIdx, VariantIdx}; +use rustc_hir::def::DefKind; +use rustc_hir::lang_items::LangItem; +use rustc_index::Idx; +use rustc_middle::mir::*; +use rustc_middle::ty::adjustment::PointerCoercion; +use rustc_middle::ty::util::IntTypeExt; +use rustc_middle::ty::{self, GenericArg, GenericArgsRef, Ty, TyCtxt}; +use rustc_middle::{bug, span_bug, traits}; +use rustc_span::DUMMY_SP; +use rustc_span::source_map::{Spanned, dummy_spanned}; +use tracing::{debug, instrument}; + +use crate::patch::MirPatch; + +/// Describes how/if a value should be dropped. +#[derive(Debug)] +pub(crate) enum DropStyle { + /// The value is already dead at the drop location, no drop will be executed. + Dead, + + /// The value is known to always be initialized at the drop location, drop will always be + /// executed. + Static, + + /// Whether the value needs to be dropped depends on its drop flag. + Conditional, + + /// An "open" drop is one where only the fields of a value are dropped. + /// + /// For example, this happens when moving out of a struct field: The rest of the struct will be + /// dropped in such an "open" drop. 
It is also used to generate drop glue for the individual + /// components of a value, for example for dropping array elements. + Open, +} + +/// Which drop flags to affect/check with an operation. +#[derive(Debug)] +pub(crate) enum DropFlagMode { + /// Only affect the top-level drop flag, not that of any contained fields. + Shallow, + /// Affect all nested drop flags in addition to the top-level one. + Deep, +} + +/// Describes if unwinding is necessary and where to unwind to if a panic occurs. +#[derive(Copy, Clone, Debug)] +pub(crate) enum Unwind { + /// Unwind to this block. + To(BasicBlock), + /// Already in an unwind path, any panic will cause an abort. + InCleanup, +} + +impl Unwind { + fn is_cleanup(self) -> bool { + match self { + Unwind::To(..) => false, + Unwind::InCleanup => true, + } + } + + fn into_action(self) -> UnwindAction { + match self { + Unwind::To(bb) => UnwindAction::Cleanup(bb), + Unwind::InCleanup => UnwindAction::Terminate(UnwindTerminateReason::InCleanup), + } + } + + fn map<F>(self, f: F) -> Self + where + F: FnOnce(BasicBlock) -> BasicBlock, + { + match self { + Unwind::To(bb) => Unwind::To(f(bb)), + Unwind::InCleanup => Unwind::InCleanup, + } + } +} + +pub(crate) trait DropElaborator<'a, 'tcx>: fmt::Debug { + /// The type representing paths that can be moved out of. + /// + /// Users can move out of individual fields of a struct, such as `a.b.c`. This type is used to + /// represent such move paths. Sometimes tracking individual move paths is not necessary, in + /// which case this may be set to (for example) `()`. + type Path: Copy + fmt::Debug; + + // Accessors + + fn patch_ref(&self) -> &MirPatch<'tcx>; + fn patch(&mut self) -> &mut MirPatch<'tcx>; + fn body(&self) -> &'a Body<'tcx>; + fn tcx(&self) -> TyCtxt<'tcx>; + fn typing_env(&self) -> ty::TypingEnv<'tcx>; + fn allow_async_drops(&self) -> bool; + + fn terminator_loc(&self, bb: BasicBlock) -> Location; + + // Drop logic + + /// Returns how `path` should be dropped, given `mode`. + fn drop_style(&self, path: Self::Path, mode: DropFlagMode) -> DropStyle; + + /// Returns the drop flag of `path` as a MIR `Operand` (or `None` if `path` has no drop flag). + fn get_drop_flag(&mut self, path: Self::Path) -> Option<Operand<'tcx>>; + + /// Modifies the MIR patch so that the drop flag of `path` (if any) is cleared at `location`. + /// + /// If `mode` is deep, drop flags of all child paths should also be cleared by inserting + /// additional statements. + fn clear_drop_flag(&mut self, location: Location, path: Self::Path, mode: DropFlagMode); + + // Subpaths + + /// Returns the subpath of a field of `path` (or `None` if there is no dedicated subpath). + /// + /// If this returns `None`, `field` will not get a dedicated drop flag. + fn field_subpath(&self, path: Self::Path, field: FieldIdx) -> Option<Self::Path>; + + /// Returns the subpath of a dereference of `path` (or `None` if there is no dedicated subpath). + /// + /// If this returns `None`, `*path` will not get a dedicated drop flag. + /// + /// This is only relevant for `Box<T>`, where the contained `T` can be moved out of the box. + fn deref_subpath(&self, path: Self::Path) -> Option<Self::Path>; + + /// Returns the subpath of downcasting `path` to one of its variants. + /// + /// If this returns `None`, the downcast of `path` will not get a dedicated drop flag. + fn downcast_subpath(&self, path: Self::Path, variant: VariantIdx) -> Option<Self::Path>; + + /// Returns the subpath of indexing a fixed-size array `path`. 
+ /// + /// If this returns `None`, elements of `path` will not get a dedicated drop flag. + /// + /// This is only relevant for array patterns, which can move out of individual array elements. + fn array_subpath(&self, path: Self::Path, index: u64, size: u64) -> Option<Self::Path>; +} + +#[derive(Debug)] +struct DropCtxt<'a, 'b, 'tcx, D> +where + D: DropElaborator<'b, 'tcx>, +{ + elaborator: &'a mut D, + + source_info: SourceInfo, + + place: Place<'tcx>, + path: D::Path, + succ: BasicBlock, + unwind: Unwind, + dropline: Option<BasicBlock>, +} + +/// "Elaborates" a drop of `place`/`path` and patches `bb`'s terminator to execute it. +/// +/// The passed `elaborator` is used to determine what should happen at the drop terminator. It +/// decides whether the drop can be statically determined or whether it needs a dynamic drop flag, +/// and whether the drop is "open", ie. should be expanded to drop all subfields of the dropped +/// value. +/// +/// When this returns, the MIR patch in the `elaborator` contains the necessary changes. +pub(crate) fn elaborate_drop<'b, 'tcx, D>( + elaborator: &mut D, + source_info: SourceInfo, + place: Place<'tcx>, + path: D::Path, + succ: BasicBlock, + unwind: Unwind, + bb: BasicBlock, + dropline: Option<BasicBlock>, +) where + D: DropElaborator<'b, 'tcx>, + 'tcx: 'b, +{ + DropCtxt { elaborator, source_info, place, path, succ, unwind, dropline }.elaborate_drop(bb) +} + +impl<'a, 'b, 'tcx, D> DropCtxt<'a, 'b, 'tcx, D> +where + D: DropElaborator<'b, 'tcx>, + 'tcx: 'b, +{ + #[instrument(level = "trace", skip(self), ret)] + fn place_ty(&self, place: Place<'tcx>) -> Ty<'tcx> { + if place.local < self.elaborator.body().local_decls.next_index() { + place.ty(self.elaborator.body(), self.tcx()).ty + } else { + // We don't have a slice with all the locals, since some are in the patch. + PlaceTy::from_ty(self.elaborator.patch_ref().local_ty(place.local)) + .multi_projection_ty(self.elaborator.tcx(), place.projection) + .ty + } + } + + fn tcx(&self) -> TyCtxt<'tcx> { + self.elaborator.tcx() + } + + // Generates three blocks: + // * #1:pin_obj_bb: call Pin<ObjTy>::new_unchecked(&mut obj) + // * #2:call_drop_bb: fut = call obj.<AsyncDrop::drop>() OR call async_drop_in_place<T>(obj) + // * #3:drop_term_bb: drop (obj, fut, ...) + // We keep async drop unexpanded to poll-loop here, to expand it later, at StateTransform - + // into states expand. + // call_destructor_only - to call only AsyncDrop::drop, not full async_drop_in_place glue + fn build_async_drop( + &mut self, + place: Place<'tcx>, + drop_ty: Ty<'tcx>, + bb: Option<BasicBlock>, + succ: BasicBlock, + unwind: Unwind, + dropline: Option<BasicBlock>, + call_destructor_only: bool, + ) -> BasicBlock { + let tcx = self.tcx(); + let span = self.source_info.span; + + let pin_obj_bb = bb.unwrap_or_else(|| { + self.elaborator.patch().new_block(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + // Temporary terminator, will be replaced by patch + source_info: self.source_info, + kind: TerminatorKind::Return, + }), + is_cleanup: false, + }) + }); + + let (fut_ty, drop_fn_def_id, trait_args) = if call_destructor_only { + // Resolving obj.<AsyncDrop::drop>() + let trait_ref = + ty::TraitRef::new(tcx, tcx.require_lang_item(LangItem::AsyncDrop, span), [drop_ty]); + let (drop_trait, trait_args) = match tcx.codegen_select_candidate( + ty::TypingEnv::fully_monomorphized().as_query_input(trait_ref), + ) { + Ok(traits::ImplSource::UserDefined(traits::ImplSourceUserDefinedData { + impl_def_id, + args, + .. 
+ })) => (*impl_def_id, *args), + impl_source => { + span_bug!(span, "invalid `AsyncDrop` impl_source: {:?}", impl_source); + } + }; + // impl_item_refs may be empty if drop fn is not implemented in 'impl AsyncDrop for ...' + // (#140974). + // Such code will report error, so just generate sync drop here and return + let Some(drop_fn_def_id) = tcx + .associated_item_def_ids(drop_trait) + .first() + .and_then(|def_id| { + if tcx.def_kind(def_id) == DefKind::AssocFn + && tcx.check_args_compatible(*def_id, trait_args) + { + Some(def_id) + } else { + None + } + }) + .copied() + else { + tcx.dcx().span_delayed_bug( + self.elaborator.body().span, + "AsyncDrop type without correct `async fn drop(...)`.", + ); + self.elaborator.patch().patch_terminator( + pin_obj_bb, + TerminatorKind::Drop { + place, + target: succ, + unwind: unwind.into_action(), + replace: false, + drop: None, + async_fut: None, + }, + ); + return pin_obj_bb; + }; + let drop_fn = Ty::new_fn_def(tcx, drop_fn_def_id, trait_args); + let sig = drop_fn.fn_sig(tcx); + let sig = tcx.instantiate_bound_regions_with_erased(sig); + (sig.output(), drop_fn_def_id, trait_args) + } else { + // Resolving async_drop_in_place<T> function for drop_ty + let drop_fn_def_id = tcx.require_lang_item(LangItem::AsyncDropInPlace, span); + let trait_args = tcx.mk_args(&[drop_ty.into()]); + let sig = tcx.fn_sig(drop_fn_def_id).instantiate(tcx, trait_args); + let sig = tcx.instantiate_bound_regions_with_erased(sig); + (sig.output(), drop_fn_def_id, trait_args) + }; + + let fut = Place::from(self.new_temp(fut_ty)); + + // #1:pin_obj_bb >>> obj_ref = &mut obj + let obj_ref_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, drop_ty); + let obj_ref_place = Place::from(self.new_temp(obj_ref_ty)); + + let term_loc = self.elaborator.terminator_loc(pin_obj_bb); + self.elaborator.patch().add_assign( + term_loc, + obj_ref_place, + Rvalue::Ref( + tcx.lifetimes.re_erased, + BorrowKind::Mut { kind: MutBorrowKind::Default }, + place, + ), + ); + + // pin_obj_place preparation + let pin_obj_new_unchecked_fn = Ty::new_fn_def( + tcx, + tcx.require_lang_item(LangItem::PinNewUnchecked, span), + [GenericArg::from(obj_ref_ty)], + ); + let pin_obj_ty = pin_obj_new_unchecked_fn.fn_sig(tcx).output().no_bound_vars().unwrap(); + let pin_obj_place = Place::from(self.new_temp(pin_obj_ty)); + let pin_obj_new_unchecked_fn = Operand::Constant(Box::new(ConstOperand { + span, + user_ty: None, + const_: Const::zero_sized(pin_obj_new_unchecked_fn), + })); + + // #3:drop_term_bb + let drop_term_bb = self.new_block( + unwind, + TerminatorKind::Drop { + place, + target: succ, + unwind: unwind.into_action(), + replace: false, + drop: dropline, + async_fut: Some(fut.local), + }, + ); + + // #2:call_drop_bb + let mut call_statements = Vec::new(); + let drop_arg = if call_destructor_only { + pin_obj_place + } else { + let ty::Adt(adt_def, adt_args) = pin_obj_ty.kind() else { + bug!(); + }; + let obj_ptr_ty = Ty::new_mut_ptr(tcx, drop_ty); + let unwrap_ty = adt_def.non_enum_variant().fields[FieldIdx::ZERO].ty(tcx, adt_args); + let obj_ref_place = Place::from(self.new_temp(unwrap_ty)); + call_statements.push(self.assign( + obj_ref_place, + Rvalue::Use(Operand::Copy(tcx.mk_place_field( + pin_obj_place, + FieldIdx::ZERO, + unwrap_ty, + ))), + )); + + let obj_ptr_place = Place::from(self.new_temp(obj_ptr_ty)); + + let addr = Rvalue::RawPtr(RawPtrKind::Mut, tcx.mk_place_deref(obj_ref_place)); + call_statements.push(self.assign(obj_ptr_place, addr)); + obj_ptr_place + }; + call_statements.push(Statement { 
+ source_info: self.source_info, + kind: StatementKind::StorageLive(fut.local), + }); + + let call_drop_bb = self.new_block_with_statements( + unwind, + call_statements, + TerminatorKind::Call { + func: Operand::function_handle(tcx, drop_fn_def_id, trait_args, span), + args: [Spanned { node: Operand::Move(drop_arg), span: DUMMY_SP }].into(), + destination: fut, + target: Some(drop_term_bb), + unwind: unwind.into_action(), + call_source: CallSource::Misc, + fn_span: self.source_info.span, + }, + ); + + // StorageDead(fut) in self.succ block (at the begin) + self.elaborator.patch().add_statement( + Location { block: self.succ, statement_index: 0 }, + StatementKind::StorageDead(fut.local), + ); + // StorageDead(fut) in unwind block (at the begin) + if let Unwind::To(block) = unwind { + self.elaborator.patch().add_statement( + Location { block, statement_index: 0 }, + StatementKind::StorageDead(fut.local), + ); + } + // StorageDead(fut) in dropline block (at the begin) + if let Some(block) = dropline { + self.elaborator.patch().add_statement( + Location { block, statement_index: 0 }, + StatementKind::StorageDead(fut.local), + ); + } + + // #1:pin_obj_bb >>> call Pin<ObjTy>::new_unchecked(&mut obj) + self.elaborator.patch().patch_terminator( + pin_obj_bb, + TerminatorKind::Call { + func: pin_obj_new_unchecked_fn, + args: [dummy_spanned(Operand::Move(obj_ref_place))].into(), + destination: pin_obj_place, + target: Some(call_drop_bb), + unwind: unwind.into_action(), + call_source: CallSource::Misc, + fn_span: span, + }, + ); + pin_obj_bb + } + + fn build_drop(&mut self, bb: BasicBlock) { + let drop_ty = self.place_ty(self.place); + if self.tcx().features().async_drop() + && self.elaborator.body().coroutine.is_some() + && self.elaborator.allow_async_drops() + && !self.elaborator.patch_ref().block(self.elaborator.body(), bb).is_cleanup + && drop_ty.needs_async_drop(self.tcx(), self.elaborator.typing_env()) + { + self.build_async_drop( + self.place, + drop_ty, + Some(bb), + self.succ, + self.unwind, + self.dropline, + false, + ); + } else { + self.elaborator.patch().patch_terminator( + bb, + TerminatorKind::Drop { + place: self.place, + target: self.succ, + unwind: self.unwind.into_action(), + replace: false, + drop: None, + async_fut: None, + }, + ); + } + } + + /// This elaborates a single drop instruction, located at `bb`, and + /// patches over it. + /// + /// The elaborated drop checks the drop flags to only drop what + /// is initialized. + /// + /// In addition, the relevant drop flags also need to be cleared + /// to avoid double-drops. However, in the middle of a complex + /// drop, one must avoid clearing some of the flags before they + /// are read, as that would cause a memory leak. + /// + /// In particular, when dropping an ADT, multiple fields may be + /// joined together under the `rest` subpath. They are all controlled + /// by the primary drop flag, but only the last rest-field dropped + /// should clear it (and it must also not clear anything else). + // + // FIXME: I think we should just control the flags externally, + // and then we do not need this machinery. 
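To make the drop-flag machinery described above concrete, here is a hand-written, source-level sketch of what elaboration arranges for a local that is only initialized on some paths (my own example; the real pass operates on MIR and tracks flags per move path, not per variable):

```rust
use std::mem::MaybeUninit;

// A conditionally initialized `String` gets a runtime drop flag. The flag is
// tested before running the destructor and cleared afterwards, so the value
// is dropped exactly once or not at all.
fn maybe_build(make: bool) {
    let mut drop_flag = false;
    let mut s = MaybeUninit::<String>::uninit();

    if make {
        s.write(String::from("hello"));
        drop_flag = true; // `s` is now initialized
    }

    // ... code that only touches `s` on paths where `drop_flag` is true ...

    if drop_flag {
        // Conditional drop: run the destructor only if `s` was initialized,
        // then clear the flag so a later drop cannot run it again.
        unsafe { s.assume_init_drop() };
        drop_flag = false;
    }
    let _ = drop_flag;
}
```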
+ #[instrument(level = "debug")] + fn elaborate_drop(&mut self, bb: BasicBlock) { + match self.elaborator.drop_style(self.path, DropFlagMode::Deep) { + DropStyle::Dead => { + self.elaborator + .patch() + .patch_terminator(bb, TerminatorKind::Goto { target: self.succ }); + } + DropStyle::Static => { + self.build_drop(bb); + } + DropStyle::Conditional => { + let drop_bb = self.complete_drop(self.succ, self.unwind); + self.elaborator + .patch() + .patch_terminator(bb, TerminatorKind::Goto { target: drop_bb }); + } + DropStyle::Open => { + let drop_bb = self.open_drop(); + self.elaborator + .patch() + .patch_terminator(bb, TerminatorKind::Goto { target: drop_bb }); + } + } + } + + /// Returns the place and move path for each field of `variant`, + /// (the move path is `None` if the field is a rest field). + fn move_paths_for_fields( + &self, + base_place: Place<'tcx>, + variant_path: D::Path, + variant: &'tcx ty::VariantDef, + args: GenericArgsRef<'tcx>, + ) -> Vec<(Place<'tcx>, Option<D::Path>)> { + variant + .fields + .iter_enumerated() + .map(|(field_idx, field)| { + let subpath = self.elaborator.field_subpath(variant_path, field_idx); + let tcx = self.tcx(); + + assert_eq!(self.elaborator.typing_env().typing_mode, ty::TypingMode::PostAnalysis); + let field_ty = match tcx.try_normalize_erasing_regions( + self.elaborator.typing_env(), + field.ty(tcx, args), + ) { + Ok(t) => t, + Err(_) => Ty::new_error( + self.tcx(), + self.tcx().dcx().span_delayed_bug( + self.elaborator.body().span, + "Error normalizing in drop elaboration.", + ), + ), + }; + + (tcx.mk_place_field(base_place, field_idx, field_ty), subpath) + }) + .collect() + } + + fn drop_subpath( + &mut self, + place: Place<'tcx>, + path: Option<D::Path>, + succ: BasicBlock, + unwind: Unwind, + dropline: Option<BasicBlock>, + ) -> BasicBlock { + if let Some(path) = path { + debug!("drop_subpath: for std field {:?}", place); + + DropCtxt { + elaborator: self.elaborator, + source_info: self.source_info, + path, + place, + succ, + unwind, + dropline, + } + .elaborated_drop_block() + } else { + debug!("drop_subpath: for rest field {:?}", place); + + DropCtxt { + elaborator: self.elaborator, + source_info: self.source_info, + place, + succ, + unwind, + dropline, + // Using `self.path` here to condition the drop on + // our own drop flag. + path: self.path, + } + .complete_drop(succ, unwind) + } + } + + /// Creates one-half of the drop ladder for a list of fields, and return + /// the list of steps in it in reverse order, with the first step + /// dropping 0 fields and so on. + /// + /// `unwind_ladder` is such a list of steps in reverse order, + /// which is called if the matching step of the drop glue panics. + /// + /// `dropline_ladder` is a similar list of steps in reverse order, + /// which is called if the matching step of the drop glue will contain async drop + /// (expanded later to Yield) and the containing coroutine will be dropped at this point. + fn drop_halfladder( + &mut self, + unwind_ladder: &[Unwind], + dropline_ladder: &[Option<BasicBlock>], + mut succ: BasicBlock, + fields: &[(Place<'tcx>, Option<D::Path>)], + ) -> Vec<BasicBlock> { + iter::once(succ) + .chain(itertools::izip!(fields.iter().rev(), unwind_ladder, dropline_ladder).map( + |(&(place, path), &unwind_succ, &dropline_to)| { + succ = self.drop_subpath(place, path, succ, unwind_succ, dropline_to); + succ + }, + )) + .collect() + } + + fn drop_ladder_bottom(&mut self) -> (BasicBlock, Unwind, Option<BasicBlock>) { + // Clear the "master" drop flag at the end. 
This is needed + // because the "master" drop protects the ADT's discriminant, + // which is invalidated after the ADT is dropped. + ( + self.drop_flag_reset_block(DropFlagMode::Shallow, self.succ, self.unwind), + self.unwind, + self.dropline, + ) + } + + /// Creates a full drop ladder, consisting of 2 connected half-drop-ladders + /// + /// For example, with 3 fields, the drop ladder is + /// + /// .d0: + /// ELAB(drop location.0 [target=.d1, unwind=.c1]) + /// .d1: + /// ELAB(drop location.1 [target=.d2, unwind=.c2]) + /// .d2: + /// ELAB(drop location.2 [target=`self.succ`, unwind=`self.unwind`]) + /// .c1: + /// ELAB(drop location.1 [target=.c2]) + /// .c2: + /// ELAB(drop location.2 [target=`self.unwind`]) + /// + /// For possible-async drops in coroutines we also need dropline ladder + /// .d0 (mainline): + /// ELAB(drop location.0 [target=.d1, unwind=.c1, drop=.e1]) + /// .d1 (mainline): + /// ELAB(drop location.1 [target=.d2, unwind=.c2, drop=.e2]) + /// .d2 (mainline): + /// ELAB(drop location.2 [target=`self.succ`, unwind=`self.unwind`, drop=`self.drop`]) + /// .c1 (unwind): + /// ELAB(drop location.1 [target=.c2]) + /// .c2 (unwind): + /// ELAB(drop location.2 [target=`self.unwind`]) + /// .e1 (dropline): + /// ELAB(drop location.1 [target=.e2, unwind=.c2]) + /// .e2 (dropline): + /// ELAB(drop location.2 [target=`self.drop`, unwind=`self.unwind`]) + /// + /// NOTE: this does not clear the master drop flag, so you need + /// to point succ/unwind on a `drop_ladder_bottom`. + fn drop_ladder( + &mut self, + fields: Vec<(Place<'tcx>, Option<D::Path>)>, + succ: BasicBlock, + unwind: Unwind, + dropline: Option<BasicBlock>, + ) -> (BasicBlock, Unwind, Option<BasicBlock>) { + debug!("drop_ladder({:?}, {:?})", self, fields); + assert!( + if unwind.is_cleanup() { dropline.is_none() } else { true }, + "Dropline is set for cleanup drop ladder" + ); + + let mut fields = fields; + fields.retain(|&(place, _)| { + self.place_ty(place).needs_drop(self.tcx(), self.elaborator.typing_env()) + }); + + debug!("drop_ladder - fields needing drop: {:?}", fields); + + let dropline_ladder: Vec<Option<BasicBlock>> = vec![None; fields.len() + 1]; + let unwind_ladder = vec![Unwind::InCleanup; fields.len() + 1]; + let unwind_ladder: Vec<_> = if let Unwind::To(succ) = unwind { + let halfladder = self.drop_halfladder(&unwind_ladder, &dropline_ladder, succ, &fields); + halfladder.into_iter().map(Unwind::To).collect() + } else { + unwind_ladder + }; + let dropline_ladder: Vec<_> = if let Some(succ) = dropline { + let halfladder = self.drop_halfladder(&unwind_ladder, &dropline_ladder, succ, &fields); + halfladder.into_iter().map(Some).collect() + } else { + dropline_ladder + }; + + let normal_ladder = self.drop_halfladder(&unwind_ladder, &dropline_ladder, succ, &fields); + + ( + *normal_ladder.last().unwrap(), + *unwind_ladder.last().unwrap(), + *dropline_ladder.last().unwrap(), + ) + } + + fn open_drop_for_tuple(&mut self, tys: &[Ty<'tcx>]) -> BasicBlock { + debug!("open_drop_for_tuple({:?}, {:?})", self, tys); + + let fields = tys + .iter() + .enumerate() + .map(|(i, &ty)| { + ( + self.tcx().mk_place_field(self.place, FieldIdx::new(i), ty), + self.elaborator.field_subpath(self.path, FieldIdx::new(i)), + ) + }) + .collect(); + + let (succ, unwind, dropline) = self.drop_ladder_bottom(); + self.drop_ladder(fields, succ, unwind, dropline).0 + } + + /// Drops the T contained in a `Box<T>` if it has not been moved out of + #[instrument(level = "debug", ret)] + fn open_drop_for_box_contents( + &mut self, + adt: 
ty::AdtDef<'tcx>, + args: GenericArgsRef<'tcx>, + succ: BasicBlock, + unwind: Unwind, + dropline: Option<BasicBlock>, + ) -> BasicBlock { + // drop glue is sent straight to codegen + // box cannot be directly dereferenced + let unique_ty = adt.non_enum_variant().fields[FieldIdx::ZERO].ty(self.tcx(), args); + let unique_variant = unique_ty.ty_adt_def().unwrap().non_enum_variant(); + let nonnull_ty = unique_variant.fields[FieldIdx::ZERO].ty(self.tcx(), args); + let ptr_ty = Ty::new_imm_ptr(self.tcx(), args[0].expect_ty()); + + let unique_place = self.tcx().mk_place_field(self.place, FieldIdx::ZERO, unique_ty); + let nonnull_place = self.tcx().mk_place_field(unique_place, FieldIdx::ZERO, nonnull_ty); + + let ptr_local = self.new_temp(ptr_ty); + + let interior = self.tcx().mk_place_deref(Place::from(ptr_local)); + let interior_path = self.elaborator.deref_subpath(self.path); + + let do_drop_bb = self.drop_subpath(interior, interior_path, succ, unwind, dropline); + + let setup_bbd = BasicBlockData { + statements: vec![self.assign( + Place::from(ptr_local), + Rvalue::Cast(CastKind::Transmute, Operand::Copy(nonnull_place), ptr_ty), + )], + terminator: Some(Terminator { + kind: TerminatorKind::Goto { target: do_drop_bb }, + source_info: self.source_info, + }), + is_cleanup: unwind.is_cleanup(), + }; + self.elaborator.patch().new_block(setup_bbd) + } + + #[instrument(level = "debug", ret)] + fn open_drop_for_adt( + &mut self, + adt: ty::AdtDef<'tcx>, + args: GenericArgsRef<'tcx>, + ) -> BasicBlock { + if adt.variants().is_empty() { + return self.elaborator.patch().new_block(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + source_info: self.source_info, + kind: TerminatorKind::Unreachable, + }), + is_cleanup: self.unwind.is_cleanup(), + }); + } + + let skip_contents = adt.is_union() || adt.is_manually_drop(); + let contents_drop = if skip_contents { + (self.succ, self.unwind, self.dropline) + } else { + self.open_drop_for_adt_contents(adt, args) + }; + + if adt.is_box() { + // we need to drop the inside of the box before running the destructor + let succ = self.destructor_call_block_sync((contents_drop.0, contents_drop.1)); + let unwind = contents_drop + .1 + .map(|unwind| self.destructor_call_block_sync((unwind, Unwind::InCleanup))); + let dropline = contents_drop + .2 + .map(|dropline| self.destructor_call_block_sync((dropline, contents_drop.1))); + + self.open_drop_for_box_contents(adt, args, succ, unwind, dropline) + } else if adt.has_dtor(self.tcx()) { + self.destructor_call_block(contents_drop) + } else { + contents_drop.0 + } + } + + fn open_drop_for_adt_contents( + &mut self, + adt: ty::AdtDef<'tcx>, + args: GenericArgsRef<'tcx>, + ) -> (BasicBlock, Unwind, Option<BasicBlock>) { + let (succ, unwind, dropline) = self.drop_ladder_bottom(); + if !adt.is_enum() { + let fields = + self.move_paths_for_fields(self.place, self.path, adt.variant(FIRST_VARIANT), args); + self.drop_ladder(fields, succ, unwind, dropline) + } else { + self.open_drop_for_multivariant(adt, args, succ, unwind, dropline) + } + } + + fn open_drop_for_multivariant( + &mut self, + adt: ty::AdtDef<'tcx>, + args: GenericArgsRef<'tcx>, + succ: BasicBlock, + unwind: Unwind, + dropline: Option<BasicBlock>, + ) -> (BasicBlock, Unwind, Option<BasicBlock>) { + let mut values = Vec::with_capacity(adt.variants().len()); + let mut normal_blocks = Vec::with_capacity(adt.variants().len()); + let mut unwind_blocks = + if unwind.is_cleanup() { None } else { Some(Vec::with_capacity(adt.variants().len())) }; + let mut 
dropline_blocks = + if dropline.is_none() { None } else { Some(Vec::with_capacity(adt.variants().len())) }; + + let mut have_otherwise_with_drop_glue = false; + let mut have_otherwise = false; + let tcx = self.tcx(); + + for (variant_index, discr) in adt.discriminants(tcx) { + let variant = &adt.variant(variant_index); + let subpath = self.elaborator.downcast_subpath(self.path, variant_index); + + if let Some(variant_path) = subpath { + let base_place = tcx.mk_place_elem( + self.place, + ProjectionElem::Downcast(Some(variant.name), variant_index), + ); + let fields = self.move_paths_for_fields(base_place, variant_path, variant, args); + values.push(discr.val); + if let Unwind::To(unwind) = unwind { + // We can't use the half-ladder from the original + // drop ladder, because this breaks the + // "funclet can't have 2 successor funclets" + // requirement from MSVC: + // + // switch unwind-switch + // / \ / \ + // v1.0 v2.0 v2.0-unwind v1.0-unwind + // | | / | + // v1.1-unwind v2.1-unwind | + // ^ | + // \-------------------------------/ + // + // Create a duplicate half-ladder to avoid that. We + // could technically only do this on MSVC, but I + // I want to minimize the divergence between MSVC + // and non-MSVC. + + let unwind_blocks = unwind_blocks.as_mut().unwrap(); + let unwind_ladder = vec![Unwind::InCleanup; fields.len() + 1]; + let dropline_ladder: Vec<Option<BasicBlock>> = vec![None; fields.len() + 1]; + let halfladder = + self.drop_halfladder(&unwind_ladder, &dropline_ladder, unwind, &fields); + unwind_blocks.push(halfladder.last().cloned().unwrap()); + } + let (normal, _, drop_bb) = self.drop_ladder(fields, succ, unwind, dropline); + normal_blocks.push(normal); + if dropline.is_some() { + dropline_blocks.as_mut().unwrap().push(drop_bb.unwrap()); + } + } else { + have_otherwise = true; + + let typing_env = self.elaborator.typing_env(); + let have_field_with_drop_glue = variant + .fields + .iter() + .any(|field| field.ty(tcx, args).needs_drop(tcx, typing_env)); + if have_field_with_drop_glue { + have_otherwise_with_drop_glue = true; + } + } + } + + if !have_otherwise { + values.pop(); + } else if !have_otherwise_with_drop_glue { + normal_blocks.push(self.goto_block(succ, unwind)); + if let Unwind::To(unwind) = unwind { + unwind_blocks.as_mut().unwrap().push(self.goto_block(unwind, Unwind::InCleanup)); + } + } else { + normal_blocks.push(self.drop_block(succ, unwind)); + if let Unwind::To(unwind) = unwind { + unwind_blocks.as_mut().unwrap().push(self.drop_block(unwind, Unwind::InCleanup)); + } + } + + ( + self.adt_switch_block(adt, normal_blocks, &values, succ, unwind), + unwind.map(|unwind| { + self.adt_switch_block( + adt, + unwind_blocks.unwrap(), + &values, + unwind, + Unwind::InCleanup, + ) + }), + dropline.map(|dropline| { + self.adt_switch_block(adt, dropline_blocks.unwrap(), &values, dropline, unwind) + }), + ) + } + + fn adt_switch_block( + &mut self, + adt: ty::AdtDef<'tcx>, + blocks: Vec<BasicBlock>, + values: &[u128], + succ: BasicBlock, + unwind: Unwind, + ) -> BasicBlock { + // If there are multiple variants, then if something + // is present within the enum the discriminant, tracked + // by the rest path, must be initialized. + // + // Additionally, we do not want to switch on the + // discriminant after it is free-ed, because that + // way lies only trouble. 
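As a source-level picture of the dispatch that these per-variant blocks and the discriminant switch implement (illustrative only; the enum and function are hypothetical, and the real glue is emitted as a `SwitchInt` over drop ladders rather than a `match`):

```rust
// One drop ladder per variant, selected by reading the discriminant once.
enum Payload {
    Text(String),
    Bytes(Vec<u8>),
    Empty,
}

unsafe fn conceptual_drop_glue(p: *mut Payload) {
    unsafe {
        match &mut *p {
            Payload::Text(s) => std::ptr::drop_in_place(s),
            Payload::Bytes(b) => std::ptr::drop_in_place(b),
            Payload::Empty => {}
        }
    }
}
```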
+ let discr_ty = adt.repr().discr_type().to_ty(self.tcx()); + let discr = Place::from(self.new_temp(discr_ty)); + let discr_rv = Rvalue::Discriminant(self.place); + let switch_block = BasicBlockData { + statements: vec![self.assign(discr, discr_rv)], + terminator: Some(Terminator { + source_info: self.source_info, + kind: TerminatorKind::SwitchInt { + discr: Operand::Move(discr), + targets: SwitchTargets::new( + values.iter().copied().zip(blocks.iter().copied()), + *blocks.last().unwrap(), + ), + }, + }), + is_cleanup: unwind.is_cleanup(), + }; + let switch_block = self.elaborator.patch().new_block(switch_block); + self.drop_flag_test_block(switch_block, succ, unwind) + } + + fn destructor_call_block_sync(&mut self, (succ, unwind): (BasicBlock, Unwind)) -> BasicBlock { + debug!("destructor_call_block_sync({:?}, {:?})", self, succ); + let tcx = self.tcx(); + let drop_trait = tcx.require_lang_item(LangItem::Drop, DUMMY_SP); + let drop_fn = tcx.associated_item_def_ids(drop_trait)[0]; + let ty = self.place_ty(self.place); + + let ref_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, ty); + let ref_place = self.new_temp(ref_ty); + let unit_temp = Place::from(self.new_temp(tcx.types.unit)); + + let result = BasicBlockData { + statements: vec![self.assign( + Place::from(ref_place), + Rvalue::Ref( + tcx.lifetimes.re_erased, + BorrowKind::Mut { kind: MutBorrowKind::Default }, + self.place, + ), + )], + terminator: Some(Terminator { + kind: TerminatorKind::Call { + func: Operand::function_handle( + tcx, + drop_fn, + [ty.into()], + self.source_info.span, + ), + args: [Spanned { node: Operand::Move(Place::from(ref_place)), span: DUMMY_SP }] + .into(), + destination: unit_temp, + target: Some(succ), + unwind: unwind.into_action(), + call_source: CallSource::Misc, + fn_span: self.source_info.span, + }, + source_info: self.source_info, + }), + is_cleanup: unwind.is_cleanup(), + }; + + let destructor_block = self.elaborator.patch().new_block(result); + + let block_start = Location { block: destructor_block, statement_index: 0 }; + self.elaborator.clear_drop_flag(block_start, self.path, DropFlagMode::Shallow); + + self.drop_flag_test_block(destructor_block, succ, unwind) + } + + fn destructor_call_block( + &mut self, + (succ, unwind, dropline): (BasicBlock, Unwind, Option<BasicBlock>), + ) -> BasicBlock { + debug!("destructor_call_block({:?}, {:?})", self, succ); + let ty = self.place_ty(self.place); + if self.tcx().features().async_drop() + && self.elaborator.body().coroutine.is_some() + && self.elaborator.allow_async_drops() + && !unwind.is_cleanup() + && ty.is_async_drop(self.tcx(), self.elaborator.typing_env()) + { + let destructor_block = + self.build_async_drop(self.place, ty, None, succ, unwind, dropline, true); + + let block_start = Location { block: destructor_block, statement_index: 0 }; + self.elaborator.clear_drop_flag(block_start, self.path, DropFlagMode::Shallow); + + self.drop_flag_test_block(destructor_block, succ, unwind) + } else { + self.destructor_call_block_sync((succ, unwind)) + } + } + + /// Create a loop that drops an array: + /// + /// ```text + /// loop-block: + /// can_go = cur == len + /// if can_go then succ else drop-block + /// drop-block: + /// ptr = &raw mut P[cur] + /// cur = cur + 1 + /// drop(ptr) + /// ``` + fn drop_loop( + &mut self, + succ: BasicBlock, + cur: Local, + len: Local, + ety: Ty<'tcx>, + unwind: Unwind, + dropline: Option<BasicBlock>, + ) -> BasicBlock { + let copy = |place: Place<'tcx>| Operand::Copy(place); + let move_ = |place: Place<'tcx>| 
Operand::Move(place); + let tcx = self.tcx(); + + let ptr_ty = Ty::new_mut_ptr(tcx, ety); + let ptr = Place::from(self.new_temp(ptr_ty)); + let can_go = Place::from(self.new_temp(tcx.types.bool)); + let one = self.constant_usize(1); + + let drop_block = BasicBlockData { + statements: vec![ + self.assign( + ptr, + Rvalue::RawPtr(RawPtrKind::Mut, tcx.mk_place_index(self.place, cur)), + ), + self.assign( + cur.into(), + Rvalue::BinaryOp(BinOp::Add, Box::new((move_(cur.into()), one))), + ), + ], + is_cleanup: unwind.is_cleanup(), + terminator: Some(Terminator { + source_info: self.source_info, + // this gets overwritten by drop elaboration. + kind: TerminatorKind::Unreachable, + }), + }; + let drop_block = self.elaborator.patch().new_block(drop_block); + + let loop_block = BasicBlockData { + statements: vec![self.assign( + can_go, + Rvalue::BinaryOp(BinOp::Eq, Box::new((copy(Place::from(cur)), copy(len.into())))), + )], + is_cleanup: unwind.is_cleanup(), + terminator: Some(Terminator { + source_info: self.source_info, + kind: TerminatorKind::if_(move_(can_go), succ, drop_block), + }), + }; + let loop_block = self.elaborator.patch().new_block(loop_block); + + let place = tcx.mk_place_deref(ptr); + if self.tcx().features().async_drop() + && self.elaborator.body().coroutine.is_some() + && self.elaborator.allow_async_drops() + && !unwind.is_cleanup() + && ety.needs_async_drop(self.tcx(), self.elaborator.typing_env()) + { + self.build_async_drop( + place, + ety, + Some(drop_block), + loop_block, + unwind, + dropline, + false, + ); + } else { + self.elaborator.patch().patch_terminator( + drop_block, + TerminatorKind::Drop { + place, + target: loop_block, + unwind: unwind.into_action(), + replace: false, + drop: None, + async_fut: None, + }, + ); + } + loop_block + } + + fn open_drop_for_array( + &mut self, + array_ty: Ty<'tcx>, + ety: Ty<'tcx>, + opt_size: Option<u64>, + ) -> BasicBlock { + debug!("open_drop_for_array({:?}, {:?}, {:?})", array_ty, ety, opt_size); + let tcx = self.tcx(); + + if let Some(size) = opt_size { + enum ProjectionKind<Path> { + Drop(std::ops::Range<u64>), + Keep(u64, Path), + } + // Previously, we'd make a projection for every element in the array and create a drop + // ladder if any `array_subpath` was `Some`, i.e. moving out with an array pattern. 
+ // This caused huge memory usage when generating the drops for large arrays, so we instead + // record the *subslices* which are dropped and the *indexes* which are kept + let mut drop_ranges = vec![]; + let mut dropping = true; + let mut start = 0; + for i in 0..size { + let path = self.elaborator.array_subpath(self.path, i, size); + if dropping && path.is_some() { + drop_ranges.push(ProjectionKind::Drop(start..i)); + dropping = false; + } else if !dropping && path.is_none() { + dropping = true; + start = i; + } + if let Some(path) = path { + drop_ranges.push(ProjectionKind::Keep(i, path)); + } + } + if !drop_ranges.is_empty() { + if dropping { + drop_ranges.push(ProjectionKind::Drop(start..size)); + } + let fields = drop_ranges + .iter() + .rev() + .map(|p| { + let (project, path) = match p { + ProjectionKind::Drop(r) => ( + ProjectionElem::Subslice { + from: r.start, + to: r.end, + from_end: false, + }, + None, + ), + &ProjectionKind::Keep(offset, path) => ( + ProjectionElem::ConstantIndex { + offset, + min_length: size, + from_end: false, + }, + Some(path), + ), + }; + (tcx.mk_place_elem(self.place, project), path) + }) + .collect::<Vec<_>>(); + let (succ, unwind, dropline) = self.drop_ladder_bottom(); + return self.drop_ladder(fields, succ, unwind, dropline).0; + } + } + + let array_ptr_ty = Ty::new_mut_ptr(tcx, array_ty); + let array_ptr = self.new_temp(array_ptr_ty); + + let slice_ty = Ty::new_slice(tcx, ety); + let slice_ptr_ty = Ty::new_mut_ptr(tcx, slice_ty); + let slice_ptr = self.new_temp(slice_ptr_ty); + + let mut delegate_block = BasicBlockData { + statements: vec![ + self.assign(Place::from(array_ptr), Rvalue::RawPtr(RawPtrKind::Mut, self.place)), + self.assign( + Place::from(slice_ptr), + Rvalue::Cast( + CastKind::PointerCoercion( + PointerCoercion::Unsize, + CoercionSource::Implicit, + ), + Operand::Move(Place::from(array_ptr)), + slice_ptr_ty, + ), + ), + ], + is_cleanup: self.unwind.is_cleanup(), + terminator: None, + }; + + let array_place = mem::replace( + &mut self.place, + Place::from(slice_ptr).project_deeper(&[PlaceElem::Deref], tcx), + ); + let slice_block = self.drop_loop_trio_for_slice(ety); + self.place = array_place; + + delegate_block.terminator = Some(Terminator { + source_info: self.source_info, + kind: TerminatorKind::Goto { target: slice_block }, + }); + self.elaborator.patch().new_block(delegate_block) + } + + /// Creates a trio of drop-loops of `place`, which drops its contents, even + /// in the case of 1 panic or in the case of coroutine drop + fn drop_loop_trio_for_slice(&mut self, ety: Ty<'tcx>) -> BasicBlock { + debug!("drop_loop_trio_for_slice({:?})", ety); + let tcx = self.tcx(); + let len = self.new_temp(tcx.types.usize); + let cur = self.new_temp(tcx.types.usize); + + let unwind = self + .unwind + .map(|unwind| self.drop_loop(unwind, cur, len, ety, Unwind::InCleanup, None)); + + let dropline = + self.dropline.map(|dropline| self.drop_loop(dropline, cur, len, ety, unwind, None)); + + let loop_block = self.drop_loop(self.succ, cur, len, ety, unwind, dropline); + + let [PlaceElem::Deref] = self.place.projection.as_slice() else { + span_bug!( + self.source_info.span, + "Expected place for slice drop shim to be *_n, but it's {:?}", + self.place, + ); + }; + + let zero = self.constant_usize(0); + let block = BasicBlockData { + statements: vec![ + self.assign( + len.into(), + Rvalue::UnaryOp( + UnOp::PtrMetadata, + Operand::Copy(Place::from(self.place.local)), + ), + ), + self.assign(cur.into(), Rvalue::Use(zero)), + ], + is_cleanup: 
unwind.is_cleanup(), + terminator: Some(Terminator { + source_info: self.source_info, + kind: TerminatorKind::Goto { target: loop_block }, + }), + }; + + let drop_block = self.elaborator.patch().new_block(block); + // FIXME(#34708): handle partially-dropped array/slice elements. + let reset_block = self.drop_flag_reset_block(DropFlagMode::Deep, drop_block, unwind); + self.drop_flag_test_block(reset_block, self.succ, unwind) + } + + /// The slow-path - create an "open", elaborated drop for a type + /// which is moved-out-of only partially, and patch `bb` to a jump + /// to it. This must not be called on ADTs with a destructor, + /// as these can't be moved-out-of, except for `Box<T>`, which is + /// special-cased. + /// + /// This creates a "drop ladder" that drops the needed fields of the + /// ADT, both in the success case or if one of the destructors fail. + fn open_drop(&mut self) -> BasicBlock { + let ty = self.place_ty(self.place); + match ty.kind() { + ty::Closure(_, args) => self.open_drop_for_tuple(args.as_closure().upvar_tys()), + ty::CoroutineClosure(_, args) => { + self.open_drop_for_tuple(args.as_coroutine_closure().upvar_tys()) + } + // Note that `elaborate_drops` only drops the upvars of a coroutine, + // and this is ok because `open_drop` here can only be reached + // within that own coroutine's resume function. + // This should only happen for the self argument on the resume function. + // It effectively only contains upvars until the coroutine transformation runs. + // See librustc_body/transform/coroutine.rs for more details. + ty::Coroutine(_, args) => self.open_drop_for_tuple(args.as_coroutine().upvar_tys()), + ty::Tuple(fields) => self.open_drop_for_tuple(fields), + ty::Adt(def, args) => self.open_drop_for_adt(*def, args), + ty::Dynamic(..) => self.complete_drop(self.succ, self.unwind), + ty::Array(ety, size) => { + let size = size.try_to_target_usize(self.tcx()); + self.open_drop_for_array(ty, *ety, size) + } + ty::Slice(ety) => self.drop_loop_trio_for_slice(*ety), + + ty::UnsafeBinder(_) => { + // Unsafe binders may elaborate drops if their inner type isn't copy. + // This is enforced in typeck, so this should never happen. + self.tcx().dcx().span_delayed_bug( + self.source_info.span, + "open drop for unsafe binder shouldn't be encountered", + ); + self.elaborator.patch().new_block(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + source_info: self.source_info, + kind: TerminatorKind::Unreachable, + }), + is_cleanup: self.unwind.is_cleanup(), + }) + } + + _ => span_bug!(self.source_info.span, "open drop from non-ADT `{:?}`", ty), + } + } + + fn complete_drop(&mut self, succ: BasicBlock, unwind: Unwind) -> BasicBlock { + debug!("complete_drop(succ={:?}, unwind={:?})", succ, unwind); + + let drop_block = self.drop_block(succ, unwind); + + self.drop_flag_test_block(drop_block, succ, unwind) + } + + /// Creates a block that resets the drop flag. If `mode` is deep, all children drop flags will + /// also be cleared. + fn drop_flag_reset_block( + &mut self, + mode: DropFlagMode, + succ: BasicBlock, + unwind: Unwind, + ) -> BasicBlock { + debug!("drop_flag_reset_block({:?},{:?})", self, mode); + + if unwind.is_cleanup() { + // The drop flag isn't read again on the unwind path, so don't + // bother setting it. 
+ return succ; + } + let block = self.new_block(unwind, TerminatorKind::Goto { target: succ }); + let block_start = Location { block, statement_index: 0 }; + self.elaborator.clear_drop_flag(block_start, self.path, mode); + block + } + + fn elaborated_drop_block(&mut self) -> BasicBlock { + debug!("elaborated_drop_block({:?})", self); + let blk = self.drop_block_simple(self.succ, self.unwind); + self.elaborate_drop(blk); + blk + } + + fn drop_block_simple(&mut self, target: BasicBlock, unwind: Unwind) -> BasicBlock { + let block = TerminatorKind::Drop { + place: self.place, + target, + unwind: unwind.into_action(), + replace: false, + drop: self.dropline, + async_fut: None, + }; + self.new_block(unwind, block) + } + + fn drop_block(&mut self, target: BasicBlock, unwind: Unwind) -> BasicBlock { + let drop_ty = self.place_ty(self.place); + if self.tcx().features().async_drop() + && self.elaborator.body().coroutine.is_some() + && self.elaborator.allow_async_drops() + && !unwind.is_cleanup() + && drop_ty.needs_async_drop(self.tcx(), self.elaborator.typing_env()) + { + self.build_async_drop( + self.place, + drop_ty, + None, + self.succ, + unwind, + self.dropline, + false, + ) + } else { + let block = TerminatorKind::Drop { + place: self.place, + target, + unwind: unwind.into_action(), + replace: false, + drop: None, + async_fut: None, + }; + self.new_block(unwind, block) + } + } + + fn goto_block(&mut self, target: BasicBlock, unwind: Unwind) -> BasicBlock { + let block = TerminatorKind::Goto { target }; + self.new_block(unwind, block) + } + + /// Returns the block to jump to in order to test the drop flag and execute the drop. + /// + /// Depending on the required `DropStyle`, this might be a generated block with an `if` + /// terminator (for dynamic/open drops), or it might be `on_set` or `on_unset` itself, in case + /// the drop can be statically determined. 
+ fn drop_flag_test_block( + &mut self, + on_set: BasicBlock, + on_unset: BasicBlock, + unwind: Unwind, + ) -> BasicBlock { + let style = self.elaborator.drop_style(self.path, DropFlagMode::Shallow); + debug!( + "drop_flag_test_block({:?},{:?},{:?},{:?}) - {:?}", + self, on_set, on_unset, unwind, style + ); + + match style { + DropStyle::Dead => on_unset, + DropStyle::Static => on_set, + DropStyle::Conditional | DropStyle::Open => { + let flag = self.elaborator.get_drop_flag(self.path).unwrap(); + let term = TerminatorKind::if_(flag, on_set, on_unset); + self.new_block(unwind, term) + } + } + } + + fn new_block(&mut self, unwind: Unwind, k: TerminatorKind<'tcx>) -> BasicBlock { + self.elaborator.patch().new_block(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { source_info: self.source_info, kind: k }), + is_cleanup: unwind.is_cleanup(), + }) + } + + fn new_block_with_statements( + &mut self, + unwind: Unwind, + statements: Vec<Statement<'tcx>>, + k: TerminatorKind<'tcx>, + ) -> BasicBlock { + self.elaborator.patch().new_block(BasicBlockData { + statements, + terminator: Some(Terminator { source_info: self.source_info, kind: k }), + is_cleanup: unwind.is_cleanup(), + }) + } + + fn new_temp(&mut self, ty: Ty<'tcx>) -> Local { + self.elaborator.patch().new_temp(ty, self.source_info.span) + } + + fn constant_usize(&self, val: u16) -> Operand<'tcx> { + Operand::Constant(Box::new(ConstOperand { + span: self.source_info.span, + user_ty: None, + const_: Const::from_usize(self.tcx(), val.into()), + })) + } + + fn assign(&self, lhs: Place<'tcx>, rhs: Rvalue<'tcx>) -> Statement<'tcx> { + Statement { + source_info: self.source_info, + kind: StatementKind::Assign(Box::new((lhs, rhs))), + } + } +} diff --git a/compiler/rustc_mir_transform/src/elaborate_drops.rs b/compiler/rustc_mir_transform/src/elaborate_drops.rs new file mode 100644 index 00000000000..42c8cb0b906 --- /dev/null +++ b/compiler/rustc_mir_transform/src/elaborate_drops.rs @@ -0,0 +1,518 @@ +use std::fmt; + +use rustc_abi::{FieldIdx, VariantIdx}; +use rustc_index::IndexVec; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_mir_dataflow::impls::{MaybeInitializedPlaces, MaybeUninitializedPlaces}; +use rustc_mir_dataflow::move_paths::{LookupResult, MoveData, MovePathIndex}; +use rustc_mir_dataflow::{ + Analysis, DropFlagState, MoveDataTypingEnv, ResultsCursor, on_all_children_bits, + on_lookup_result_bits, +}; +use rustc_span::Span; +use tracing::{debug, instrument}; + +use crate::deref_separator::deref_finder; +use crate::elaborate_drop::{DropElaborator, DropFlagMode, DropStyle, Unwind, elaborate_drop}; +use crate::patch::MirPatch; + +/// During MIR building, Drop terminators are inserted in every place where a drop may occur. +/// However, in this phase, the presence of these terminators does not guarantee that a destructor +/// will run, as the target of the drop may be uninitialized. +/// In general, the compiler cannot determine at compile time whether a destructor will run or not. +/// +/// At a high level, this pass refines Drop to only run the destructor if the +/// target is initialized. The way this is achieved is by inserting drop flags for every variable +/// that may be dropped, and then using those flags to determine whether a destructor should run. +/// Once this is complete, Drop terminators in the MIR correspond to a call to the "drop glue" or +/// "drop shim" for the type of the dropped place. 
+/// +/// This pass relies on dropped places having an associated move path, which is then used to +/// determine the initialization status of the place and its descendants. +/// It's worth noting that a MIR containing a Drop without an associated move path is probably ill +/// formed, as it would allow running a destructor on a place behind a reference: +/// +/// ```text +/// fn drop_term<T>(t: &mut T) { +/// mir! { +/// { +/// Drop(*t, exit) +/// } +/// exit = { +/// Return() +/// } +/// } +/// } +/// ``` +pub(super) struct ElaborateDrops; + +impl<'tcx> crate::MirPass<'tcx> for ElaborateDrops { + #[instrument(level = "trace", skip(self, tcx, body))] + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!("elaborate_drops({:?} @ {:?})", body.source, body.span); + // FIXME(#132279): This is used during the phase transition from analysis + // to runtime, so we have to manually specify the correct typing mode. + let typing_env = ty::TypingEnv::post_analysis(tcx, body.source.def_id()); + // For types that do not need dropping, the behaviour is trivial. So we only need to track + // init/uninit for types that do need dropping. + let move_data = MoveData::gather_moves(body, tcx, |ty| ty.needs_drop(tcx, typing_env)); + let elaborate_patch = { + let env = MoveDataTypingEnv { move_data, typing_env }; + + let mut inits = MaybeInitializedPlaces::new(tcx, body, &env.move_data) + .skipping_unreachable_unwind() + .iterate_to_fixpoint(tcx, body, Some("elaborate_drops")) + .into_results_cursor(body); + let dead_unwinds = compute_dead_unwinds(body, &mut inits); + + let uninits = MaybeUninitializedPlaces::new(tcx, body, &env.move_data) + .mark_inactive_variants_as_uninit() + .skipping_unreachable_unwind(dead_unwinds) + .iterate_to_fixpoint(tcx, body, Some("elaborate_drops")) + .into_results_cursor(body); + + let drop_flags = IndexVec::from_elem(None, &env.move_data.move_paths); + ElaborateDropsCtxt { + tcx, + body, + env: &env, + init_data: InitializationData { inits, uninits }, + drop_flags, + patch: MirPatch::new(body), + } + .elaborate() + }; + elaborate_patch.apply(body); + deref_finder(tcx, body); + } + + fn is_required(&self) -> bool { + true + } +} + +/// Records unwind edges which are known to be unreachable, because they are in `drop` terminators +/// that can't drop anything. +#[instrument(level = "trace", skip(body, flow_inits), ret)] +fn compute_dead_unwinds<'a, 'tcx>( + body: &'a Body<'tcx>, + flow_inits: &mut ResultsCursor<'a, 'tcx, MaybeInitializedPlaces<'a, 'tcx>>, +) -> DenseBitSet<BasicBlock> { + // We only need to do this pass once, because unwind edges can only + // reach cleanup blocks, which can't have unwind edges themselves. + let mut dead_unwinds = DenseBitSet::new_empty(body.basic_blocks.len()); + for (bb, bb_data) in body.basic_blocks.iter_enumerated() { + let TerminatorKind::Drop { place, unwind: UnwindAction::Cleanup(_), .. 
} = + bb_data.terminator().kind + else { + continue; + }; + + flow_inits.seek_before_primary_effect(body.terminator_loc(bb)); + if flow_inits.analysis().is_unwind_dead(place, flow_inits.get()) { + dead_unwinds.insert(bb); + } + } + + dead_unwinds +} + +struct InitializationData<'a, 'tcx> { + inits: ResultsCursor<'a, 'tcx, MaybeInitializedPlaces<'a, 'tcx>>, + uninits: ResultsCursor<'a, 'tcx, MaybeUninitializedPlaces<'a, 'tcx>>, +} + +impl InitializationData<'_, '_> { + fn seek_before(&mut self, loc: Location) { + self.inits.seek_before_primary_effect(loc); + self.uninits.seek_before_primary_effect(loc); + } + + fn maybe_init_uninit(&self, path: MovePathIndex) -> (bool, bool) { + (self.inits.get().contains(path), self.uninits.get().contains(path)) + } +} + +impl<'a, 'tcx> DropElaborator<'a, 'tcx> for ElaborateDropsCtxt<'a, 'tcx> { + type Path = MovePathIndex; + + fn patch_ref(&self) -> &MirPatch<'tcx> { + &self.patch + } + + fn patch(&mut self) -> &mut MirPatch<'tcx> { + &mut self.patch + } + + fn body(&self) -> &'a Body<'tcx> { + self.body + } + + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn typing_env(&self) -> ty::TypingEnv<'tcx> { + self.env.typing_env + } + + fn allow_async_drops(&self) -> bool { + true + } + + fn terminator_loc(&self, bb: BasicBlock) -> Location { + self.patch.terminator_loc(self.body, bb) + } + + #[instrument(level = "debug", skip(self), ret)] + fn drop_style(&self, path: Self::Path, mode: DropFlagMode) -> DropStyle { + let ((maybe_init, maybe_uninit), multipart) = match mode { + DropFlagMode::Shallow => (self.init_data.maybe_init_uninit(path), false), + DropFlagMode::Deep => { + let mut some_maybe_init = false; + let mut some_maybe_uninit = false; + let mut children_count = 0; + on_all_children_bits(self.move_data(), path, |child| { + let (maybe_init, maybe_uninit) = self.init_data.maybe_init_uninit(child); + debug!("elaborate_drop: state({:?}) = {:?}", child, (maybe_init, maybe_uninit)); + some_maybe_init |= maybe_init; + some_maybe_uninit |= maybe_uninit; + children_count += 1; + }); + ((some_maybe_init, some_maybe_uninit), children_count != 1) + } + }; + match (maybe_init, maybe_uninit, multipart) { + (false, _, _) => DropStyle::Dead, + (true, false, _) => DropStyle::Static, + (true, true, false) => DropStyle::Conditional, + (true, true, true) => DropStyle::Open, + } + } + + fn clear_drop_flag(&mut self, loc: Location, path: Self::Path, mode: DropFlagMode) { + match mode { + DropFlagMode::Shallow => { + self.set_drop_flag(loc, path, DropFlagState::Absent); + } + DropFlagMode::Deep => { + on_all_children_bits(self.move_data(), path, |child| { + self.set_drop_flag(loc, child, DropFlagState::Absent) + }); + } + } + } + + fn field_subpath(&self, path: Self::Path, field: FieldIdx) -> Option<Self::Path> { + rustc_mir_dataflow::move_path_children_matching(self.move_data(), path, |e| match e { + ProjectionElem::Field(idx, _) => idx == field, + _ => false, + }) + } + + fn array_subpath(&self, path: Self::Path, index: u64, size: u64) -> Option<Self::Path> { + rustc_mir_dataflow::move_path_children_matching(self.move_data(), path, |e| match e { + ProjectionElem::ConstantIndex { offset, min_length, from_end } => { + debug_assert!(size == min_length, "min_length should be exact for arrays"); + assert!(!from_end, "from_end should not be used for array element ConstantIndex"); + offset == index + } + _ => false, + }) + } + + fn deref_subpath(&self, path: Self::Path) -> Option<Self::Path> { + rustc_mir_dataflow::move_path_children_matching(self.move_data(), path, |e| { + 
e == ProjectionElem::Deref + }) + } + + fn downcast_subpath(&self, path: Self::Path, variant: VariantIdx) -> Option<Self::Path> { + rustc_mir_dataflow::move_path_children_matching(self.move_data(), path, |e| match e { + ProjectionElem::Downcast(_, idx) => idx == variant, + _ => false, + }) + } + + fn get_drop_flag(&mut self, path: Self::Path) -> Option<Operand<'tcx>> { + self.drop_flag(path).map(Operand::Copy) + } +} + +struct ElaborateDropsCtxt<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + body: &'a Body<'tcx>, + env: &'a MoveDataTypingEnv<'tcx>, + init_data: InitializationData<'a, 'tcx>, + drop_flags: IndexVec<MovePathIndex, Option<Local>>, + patch: MirPatch<'tcx>, +} + +impl fmt::Debug for ElaborateDropsCtxt<'_, '_> { + fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { + Ok(()) + } +} + +impl<'a, 'tcx> ElaborateDropsCtxt<'a, 'tcx> { + fn move_data(&self) -> &'a MoveData<'tcx> { + &self.env.move_data + } + + fn create_drop_flag(&mut self, index: MovePathIndex, span: Span) { + let patch = &mut self.patch; + debug!("create_drop_flag({:?})", self.body.span); + self.drop_flags[index].get_or_insert_with(|| patch.new_temp(self.tcx.types.bool, span)); + } + + fn drop_flag(&mut self, index: MovePathIndex) -> Option<Place<'tcx>> { + self.drop_flags[index].map(Place::from) + } + + /// create a patch that elaborates all drops in the input + /// MIR. + fn elaborate(mut self) -> MirPatch<'tcx> { + self.collect_drop_flags(); + + self.elaborate_drops(); + + self.drop_flags_on_init(); + self.drop_flags_for_fn_rets(); + self.drop_flags_for_args(); + self.drop_flags_for_locs(); + + self.patch + } + + fn collect_drop_flags(&mut self) { + for (bb, data) in self.body.basic_blocks.iter_enumerated() { + let terminator = data.terminator(); + let TerminatorKind::Drop { ref place, .. } = terminator.kind else { continue }; + + let path = self.move_data().rev_lookup.find(place.as_ref()); + debug!("collect_drop_flags: {:?}, place {:?} ({:?})", bb, place, path); + + match path { + LookupResult::Exact(path) => { + self.init_data.seek_before(self.body.terminator_loc(bb)); + on_all_children_bits(self.move_data(), path, |child| { + let (maybe_init, maybe_uninit) = self.init_data.maybe_init_uninit(child); + debug!( + "collect_drop_flags: collecting {:?} from {:?}@{:?} - {:?}", + child, + place, + path, + (maybe_init, maybe_uninit) + ); + if maybe_init && maybe_uninit { + self.create_drop_flag(child, terminator.source_info.span) + } + }); + } + LookupResult::Parent(None) => {} + LookupResult::Parent(Some(parent)) => { + if self.body.local_decls[place.local].is_deref_temp() { + continue; + } + + self.init_data.seek_before(self.body.terminator_loc(bb)); + let (_maybe_init, maybe_uninit) = self.init_data.maybe_init_uninit(parent); + if maybe_uninit { + self.tcx.dcx().span_delayed_bug( + terminator.source_info.span, + format!( + "drop of untracked, uninitialized value {bb:?}, place {place:?} ({path:?})" + ), + ); + } + } + }; + } + } + + fn elaborate_drops(&mut self) { + // This function should mirror what `collect_drop_flags` does. + for (bb, data) in self.body.basic_blocks.iter_enumerated() { + let terminator = data.terminator(); + let TerminatorKind::Drop { place, target, unwind, replace, drop, async_fut: _ } = + terminator.kind + else { + continue; + }; + + // This place does not need dropping. It does not have an associated move-path, so the + // match below will conservatively keep an unconditional drop. As that drop is useless, + // just remove it here and now. 
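+            // (For example, a `Drop` of an `i32` or of a shared reference falls into this
+            // category: dropping such a value is a no-op, so the terminator can simply become
+            // a `Goto` to its target.)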
+ if !place + .ty(&self.body.local_decls, self.tcx) + .ty + .needs_drop(self.tcx, self.typing_env()) + { + self.patch.patch_terminator(bb, TerminatorKind::Goto { target }); + continue; + } + + let path = self.move_data().rev_lookup.find(place.as_ref()); + match path { + LookupResult::Exact(path) => { + let unwind = match unwind { + _ if data.is_cleanup => Unwind::InCleanup, + UnwindAction::Cleanup(cleanup) => Unwind::To(cleanup), + UnwindAction::Continue => Unwind::To(self.patch.resume_block()), + UnwindAction::Unreachable => { + Unwind::To(self.patch.unreachable_cleanup_block()) + } + UnwindAction::Terminate(reason) => { + debug_assert_ne!( + reason, + UnwindTerminateReason::InCleanup, + "we are not in a cleanup block, InCleanup reason should be impossible" + ); + Unwind::To(self.patch.terminate_block(reason)) + } + }; + self.init_data.seek_before(self.body.terminator_loc(bb)); + elaborate_drop( + self, + terminator.source_info, + place, + path, + target, + unwind, + bb, + drop, + ) + } + LookupResult::Parent(None) => {} + LookupResult::Parent(Some(_)) => { + if !replace { + self.tcx.dcx().span_bug( + terminator.source_info.span, + format!("drop of untracked value {bb:?}"), + ); + } + // A drop and replace behind a pointer/array/whatever. + // The borrow checker requires that these locations are initialized before the + // assignment, so we just leave an unconditional drop. + assert!(!data.is_cleanup); + } + } + } + } + + fn constant_bool(&self, span: Span, val: bool) -> Rvalue<'tcx> { + Rvalue::Use(Operand::Constant(Box::new(ConstOperand { + span, + user_ty: None, + const_: Const::from_bool(self.tcx, val), + }))) + } + + fn set_drop_flag(&mut self, loc: Location, path: MovePathIndex, val: DropFlagState) { + if let Some(flag) = self.drop_flags[path] { + let span = self.patch.source_info_for_location(self.body, loc).span; + let val = self.constant_bool(span, val.value()); + self.patch.add_assign(loc, Place::from(flag), val); + } + } + + fn drop_flags_on_init(&mut self) { + let loc = Location::START; + let span = self.patch.source_info_for_location(self.body, loc).span; + let false_ = self.constant_bool(span, false); + for flag in self.drop_flags.iter().flatten() { + self.patch.add_assign(loc, Place::from(*flag), false_.clone()); + } + } + + fn drop_flags_for_fn_rets(&mut self) { + for (bb, data) in self.body.basic_blocks.iter_enumerated() { + if let TerminatorKind::Call { + destination, + target: Some(tgt), + unwind: UnwindAction::Cleanup(_), + .. + } = data.terminator().kind + { + assert!(!self.patch.is_term_patched(bb)); + + let loc = Location { block: tgt, statement_index: 0 }; + let path = self.move_data().rev_lookup.find(destination.as_ref()); + on_lookup_result_bits(self.move_data(), path, |child| { + self.set_drop_flag(loc, child, DropFlagState::Present) + }); + } + } + } + + fn drop_flags_for_args(&mut self) { + let loc = Location::START; + rustc_mir_dataflow::drop_flag_effects_for_function_entry( + self.body, + &self.env.move_data, + |path, ds| { + self.set_drop_flag(loc, path, ds); + }, + ) + } + + fn drop_flags_for_locs(&mut self) { + // We intentionally iterate only over the *old* basic blocks. + // + // Basic blocks created by drop elaboration update their + // drop flags by themselves, to avoid the drop flags being + // clobbered before they are read. 
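+        // (`self.body` is still the unpatched body at this point; blocks added through
+        // `self.patch` only materialize when the patch is applied, so this loop cannot
+        // observe them.)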
+ + for (bb, data) in self.body.basic_blocks.iter_enumerated() { + debug!("drop_flags_for_locs({:?})", data); + for i in 0..(data.statements.len() + 1) { + debug!("drop_flag_for_locs: stmt {}", i); + if i == data.statements.len() { + match data.terminator().kind { + TerminatorKind::Drop { .. } => { + // drop elaboration should handle that by itself + continue; + } + TerminatorKind::UnwindResume => { + // It is possible for `Resume` to be patched + // (in particular it can be patched to be replaced with + // a Goto; see `MirPatch::new`). + } + _ => { + assert!(!self.patch.is_term_patched(bb)); + } + } + } + let loc = Location { block: bb, statement_index: i }; + rustc_mir_dataflow::drop_flag_effects_for_location( + self.body, + &self.env.move_data, + loc, + |path, ds| self.set_drop_flag(loc, path, ds), + ) + } + + // There may be a critical edge after this call, + // so mark the return as initialized *before* the + // call. + if let TerminatorKind::Call { + destination, + target: Some(_), + unwind: + UnwindAction::Continue | UnwindAction::Unreachable | UnwindAction::Terminate(_), + .. + } = data.terminator().kind + { + assert!(!self.patch.is_term_patched(bb)); + + let loc = Location { block: bb, statement_index: data.statements.len() }; + let path = self.move_data().rev_lookup.find(destination.as_ref()); + on_lookup_result_bits(self.move_data(), path, |child| { + self.set_drop_flag(loc, child, DropFlagState::Present) + }); + } + } + } +} diff --git a/compiler/rustc_mir_transform/src/errors.rs b/compiler/rustc_mir_transform/src/errors.rs new file mode 100644 index 00000000000..cffa0183fa7 --- /dev/null +++ b/compiler/rustc_mir_transform/src/errors.rs @@ -0,0 +1,185 @@ +use rustc_errors::codes::*; +use rustc_errors::{Diag, LintDiagnostic}; +use rustc_macros::{Diagnostic, LintDiagnostic, Subdiagnostic}; +use rustc_middle::mir::AssertKind; +use rustc_middle::ty::TyCtxt; +use rustc_session::lint::{self, Lint}; +use rustc_span::def_id::DefId; +use rustc_span::{Ident, Span, Symbol}; + +use crate::fluent_generated as fluent; + +#[derive(LintDiagnostic)] +#[diag(mir_transform_unconditional_recursion)] +#[help] +pub(crate) struct UnconditionalRecursion { + #[label] + pub(crate) span: Span, + #[label(mir_transform_unconditional_recursion_call_site_label)] + pub(crate) call_sites: Vec<Span>, +} + +#[derive(Diagnostic)] +#[diag(mir_transform_force_inline_attr)] +#[note] +pub(crate) struct InvalidForceInline { + #[primary_span] + pub attr_span: Span, + #[label(mir_transform_callee)] + pub callee_span: Span, + pub callee: String, + pub reason: &'static str, +} + +#[derive(LintDiagnostic)] +pub(crate) enum ConstMutate { + #[diag(mir_transform_const_modify)] + #[note] + Modify { + #[note(mir_transform_const_defined_here)] + konst: Span, + }, + #[diag(mir_transform_const_mut_borrow)] + #[note] + #[note(mir_transform_note2)] + MutBorrow { + #[note(mir_transform_note3)] + method_call: Option<Span>, + #[note(mir_transform_const_defined_here)] + konst: Span, + }, +} + +#[derive(Diagnostic)] +#[diag(mir_transform_unaligned_packed_ref, code = E0793)] +#[note] +#[note(mir_transform_note_ub)] +#[help] +pub(crate) struct UnalignedPackedRef { + #[primary_span] + pub span: Span, +} + +#[derive(Diagnostic)] +#[diag(mir_transform_unknown_pass_name)] +pub(crate) struct UnknownPassName<'a> { + pub(crate) name: &'a str, +} + +pub(crate) struct AssertLint<P> { + pub span: Span, + pub assert_kind: AssertKind<P>, + pub lint_kind: AssertLintKind, +} + +pub(crate) enum AssertLintKind { + ArithmeticOverflow, + UnconditionalPanic, 
+} + +impl<'a, P: std::fmt::Debug> LintDiagnostic<'a, ()> for AssertLint<P> { + fn decorate_lint<'b>(self, diag: &'b mut Diag<'a, ()>) { + diag.primary_message(match self.lint_kind { + AssertLintKind::ArithmeticOverflow => fluent::mir_transform_arithmetic_overflow, + AssertLintKind::UnconditionalPanic => fluent::mir_transform_operation_will_panic, + }); + let label = self.assert_kind.diagnostic_message(); + self.assert_kind.add_args(&mut |name, value| { + diag.arg(name, value); + }); + diag.span_label(self.span, label); + } +} + +impl AssertLintKind { + pub(crate) fn lint(&self) -> &'static Lint { + match self { + AssertLintKind::ArithmeticOverflow => lint::builtin::ARITHMETIC_OVERFLOW, + AssertLintKind::UnconditionalPanic => lint::builtin::UNCONDITIONAL_PANIC, + } + } +} + +#[derive(LintDiagnostic)] +#[diag(mir_transform_ffi_unwind_call)] +pub(crate) struct FfiUnwindCall { + #[label(mir_transform_ffi_unwind_call)] + pub span: Span, + pub foreign: bool, +} + +#[derive(LintDiagnostic)] +#[diag(mir_transform_fn_item_ref)] +pub(crate) struct FnItemRef { + #[suggestion(code = "{sugg}", applicability = "unspecified")] + pub span: Span, + pub sugg: String, + pub ident: Ident, +} + +#[derive(Diagnostic)] +#[diag(mir_transform_exceeds_mcdc_test_vector_limit)] +pub(crate) struct MCDCExceedsTestVectorLimit { + #[primary_span] + pub(crate) span: Span, + pub(crate) max_num_test_vectors: usize, +} + +pub(crate) struct MustNotSupend<'a, 'tcx> { + pub tcx: TyCtxt<'tcx>, + pub yield_sp: Span, + pub reason: Option<MustNotSuspendReason>, + pub src_sp: Span, + pub pre: &'a str, + pub def_id: DefId, + pub post: &'a str, +} + +// Needed for def_path_str +impl<'a> LintDiagnostic<'a, ()> for MustNotSupend<'_, '_> { + fn decorate_lint<'b>(self, diag: &'b mut rustc_errors::Diag<'a, ()>) { + diag.primary_message(fluent::mir_transform_must_not_suspend); + diag.span_label(self.yield_sp, fluent::_subdiag::label); + if let Some(reason) = self.reason { + diag.subdiagnostic(reason); + } + diag.span_help(self.src_sp, fluent::_subdiag::help); + diag.arg("pre", self.pre); + diag.arg("def_path", self.tcx.def_path_str(self.def_id)); + diag.arg("post", self.post); + } +} + +#[derive(Subdiagnostic)] +#[note(mir_transform_note)] +pub(crate) struct MustNotSuspendReason { + #[primary_span] + pub span: Span, + pub reason: String, +} + +#[derive(Diagnostic)] +#[diag(mir_transform_force_inline)] +#[note] +pub(crate) struct ForceInlineFailure { + #[label(mir_transform_caller)] + pub caller_span: Span, + #[label(mir_transform_callee)] + pub callee_span: Span, + #[label(mir_transform_attr)] + pub attr_span: Span, + #[primary_span] + #[label(mir_transform_call)] + pub call_span: Span, + pub callee: String, + pub caller: String, + pub reason: &'static str, + #[subdiagnostic] + pub justification: Option<ForceInlineJustification>, +} + +#[derive(Subdiagnostic)] +#[note(mir_transform_force_inline_justification)] +pub(crate) struct ForceInlineJustification { + pub sym: Symbol, +} diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs new file mode 100644 index 00000000000..abbff1c48dd --- /dev/null +++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs @@ -0,0 +1,145 @@ +use rustc_abi::ExternAbi; +use rustc_hir::def_id::{LOCAL_CRATE, LocalDefId}; +use rustc_middle::mir::*; +use rustc_middle::query::{LocalCrate, Providers}; +use rustc_middle::ty::{self, TyCtxt, layout}; +use rustc_middle::{bug, span_bug}; +use rustc_session::lint::builtin::FFI_UNWIND_CALLS; +use 
rustc_target::spec::PanicStrategy; +use tracing::debug; + +use crate::errors; + +// Check if the body of this def_id can possibly leak a foreign unwind into Rust code. +fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { + debug!("has_ffi_unwind_calls({local_def_id:?})"); + + // Only perform check on functions because constants cannot call FFI functions. + let def_id = local_def_id.to_def_id(); + let kind = tcx.def_kind(def_id); + if !kind.is_fn_like() { + return false; + } + + let body = &*tcx.mir_built(local_def_id).borrow(); + + let body_ty = tcx.type_of(def_id).skip_binder(); + let body_abi = match body_ty.kind() { + ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), + ty::Closure(..) => ExternAbi::RustCall, + ty::CoroutineClosure(..) => ExternAbi::RustCall, + ty::Coroutine(..) => ExternAbi::Rust, + ty::Error(_) => return false, + _ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty), + }; + let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi); + + // Foreign unwinds cannot leak past functions that themselves cannot unwind. + if !body_can_unwind { + return false; + } + + let mut tainted = false; + + for block in body.basic_blocks.iter() { + if block.is_cleanup { + continue; + } + let Some(terminator) = &block.terminator else { continue }; + let TerminatorKind::Call { func, .. } = &terminator.kind else { continue }; + + let ty = func.ty(body, tcx); + let sig = ty.fn_sig(tcx); + + // Rust calls cannot themselves create foreign unwinds. + // We assume this is true for intrinsics as well. + if sig.abi().is_rustic_abi() { + continue; + }; + + let fn_def_id = match ty.kind() { + ty::FnPtr(..) => None, + &ty::FnDef(def_id, _) => { + // Rust calls cannot themselves create foreign unwinds (even if they use a non-Rust + // ABI). So the leak of the foreign unwind into Rust can only be elsewhere, not + // here. + if !tcx.is_foreign_item(def_id) { + continue; + } + Some(def_id) + } + _ => bug!("invalid callee of type {:?}", ty), + }; + + if layout::fn_can_unwind(tcx, fn_def_id, sig.abi()) { + // We have detected a call that can possibly leak foreign unwind. + // + // Because the function body itself can unwind, we are not aborting this function call + // upon unwind, so this call can possibly leak foreign unwind into Rust code if the + // panic runtime linked is panic-abort. + + let lint_root = body.source_scopes[terminator.source_info.scope] + .local_data + .as_ref() + .unwrap_crate_local() + .lint_root; + let span = terminator.source_info.span; + + let foreign = fn_def_id.is_some(); + tcx.emit_node_span_lint( + FFI_UNWIND_CALLS, + lint_root, + span, + errors::FfiUnwindCall { span, foreign }, + ); + + tainted = true; + } + } + + tainted +} + +fn required_panic_strategy(tcx: TyCtxt<'_>, _: LocalCrate) -> Option<PanicStrategy> { + if tcx.is_panic_runtime(LOCAL_CRATE) { + return Some(tcx.sess.panic_strategy()); + } + + if tcx.sess.panic_strategy() == PanicStrategy::Abort { + return Some(PanicStrategy::Abort); + } + + for def_id in tcx.hir_body_owners() { + if tcx.has_ffi_unwind_calls(def_id) { + // Given that this crate is compiled in `-C panic=unwind`, the `AbortUnwindingCalls` + // MIR pass will not be run on FFI-unwind call sites, therefore a foreign exception + // can enter Rust through these sites. + // + // On the other hand, crates compiled with `-C panic=abort` expects that all Rust + // functions cannot unwind (whether it's caused by Rust panic or foreign exception), + // and this expectation mismatch can cause unsoundness (#96926). 
+ // + // To address this issue, we enforce that if FFI-unwind calls are used in a crate + // compiled with `panic=unwind`, then the final panic strategy must be `panic=unwind`. + // This will ensure that no crates will have wrong unwindability assumption. + // + // It should be noted that it is okay to link `panic=unwind` into a `panic=abort` + // program if it contains no FFI-unwind calls. In such case foreign exception can only + // enter Rust in a `panic=abort` crate, which will lead to an abort. There will also + // be no exceptions generated from Rust, so the assumption which `panic=abort` crates + // make, that no Rust function can unwind, indeed holds for crates compiled with + // `panic=unwind` as well. In such case this function returns `None`, indicating that + // the crate does not require a particular final panic strategy, and can be freely + // linked to crates with either strategy (we need such ability for libstd and its + // dependencies). + return Some(PanicStrategy::Unwind); + } + } + + // This crate can be linked with either runtime. + None +} + +pub(crate) fn provide(providers: &mut Providers) { + *providers = Providers { has_ffi_unwind_calls, required_panic_strategy, ..*providers }; +} diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs new file mode 100644 index 00000000000..38b5ccdb32e --- /dev/null +++ b/compiler/rustc_mir_transform/src/function_item_references.rs @@ -0,0 +1,190 @@ +use itertools::Itertools; +use rustc_abi::ExternAbi; +use rustc_hir::def_id::DefId; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, EarlyBinder, GenericArgsRef, Ty, TyCtxt}; +use rustc_session::lint::builtin::FUNCTION_ITEM_REFERENCES; +use rustc_span::source_map::Spanned; +use rustc_span::{Span, sym}; + +use crate::errors; + +pub(super) struct FunctionItemReferences; + +impl<'tcx> crate::MirLint<'tcx> for FunctionItemReferences { + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + let mut checker = FunctionItemRefChecker { tcx, body }; + checker.visit_body(body); + } +} + +struct FunctionItemRefChecker<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + body: &'a Body<'tcx>, +} + +impl<'tcx> Visitor<'tcx> for FunctionItemRefChecker<'_, 'tcx> { + /// Emits a lint for function reference arguments bound by `fmt::Pointer` or passed to + /// `transmute`. This only handles arguments in calls outside macro expansions to avoid double + /// counting function references formatted as pointers by macros. 
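+    ///
+    /// An illustrative (simplified) example of what gets linted:
+    /// ```text
+    /// fn foo() {}
+    /// // `&foo` is a reference to the zero-sized function *item* type, so formatting it with
+    /// // `{:p}` prints the address of a zero-sized temporary, not the address of `foo`'s code:
+    /// println!("{:p}", &foo); // lint fires, suggesting `foo as fn()` instead
+    /// ```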
+ fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + if let TerminatorKind::Call { + func, + args, + destination: _, + target: _, + unwind: _, + call_source: _, + fn_span: _, + } = &terminator.kind + { + let source_info = *self.body.source_info(location); + let func_ty = func.ty(self.body, self.tcx); + if let ty::FnDef(def_id, args_ref) = *func_ty.kind() { + // Handle calls to `transmute` + if self.tcx.is_diagnostic_item(sym::transmute, def_id) { + let arg_ty = args[0].node.ty(self.body, self.tcx); + for inner_ty in arg_ty.walk().filter_map(|arg| arg.as_type()) { + if let Some((fn_id, fn_args)) = FunctionItemRefChecker::is_fn_ref(inner_ty) + { + let span = self.nth_arg_span(args, 0); + self.emit_lint(fn_id, fn_args, source_info, span); + } + } + } else { + self.check_bound_args(def_id, args_ref, args, source_info); + } + } + } + self.super_terminator(terminator, location); + } +} + +impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { + /// Emits a lint for function reference arguments bound by `fmt::Pointer` in calls to the + /// function defined by `def_id` with the generic parameters `args_ref`. + fn check_bound_args( + &self, + def_id: DefId, + args_ref: GenericArgsRef<'tcx>, + args: &[Spanned<Operand<'tcx>>], + source_info: SourceInfo, + ) { + let param_env = self.tcx.param_env(def_id); + let bounds = param_env.caller_bounds(); + for bound in bounds { + if let Some(bound_ty) = self.is_pointer_trait(bound) { + // Get the argument types as they appear in the function signature. + let arg_defs = + self.tcx.fn_sig(def_id).instantiate_identity().skip_binder().inputs(); + for (arg_num, arg_def) in arg_defs.iter().enumerate() { + // For all types reachable from the argument type in the fn sig + for inner_ty in arg_def.walk().filter_map(|arg| arg.as_type()) { + // If the inner type matches the type bound by `Pointer` + if inner_ty == bound_ty { + // Do an instantiation using the parameters from the callsite + let instantiated_ty = + EarlyBinder::bind(inner_ty).instantiate(self.tcx, args_ref); + if let Some((fn_id, fn_args)) = + FunctionItemRefChecker::is_fn_ref(instantiated_ty) + { + let mut span = self.nth_arg_span(args, arg_num); + if span.from_expansion() { + // The operand's ctxt wouldn't display the lint since it's + // inside a macro so we have to use the callsite's ctxt. + let callsite_ctxt = span.source_callsite().ctxt(); + span = span.with_ctxt(callsite_ctxt); + } + self.emit_lint(fn_id, fn_args, source_info, span); + } + } + } + } + } + } + } + + /// If the given predicate is the trait `fmt::Pointer`, returns the bound parameter type. + fn is_pointer_trait(&self, bound: ty::Clause<'tcx>) -> Option<Ty<'tcx>> { + if let ty::ClauseKind::Trait(predicate) = bound.kind().skip_binder() { + self.tcx + .is_diagnostic_item(sym::Pointer, predicate.def_id()) + .then(|| predicate.trait_ref.self_ty()) + } else { + None + } + } + + /// If a type is a reference or raw pointer to the anonymous type of a function definition, + /// returns that function's `DefId` and `GenericArgsRef`. 
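+    ///
+    /// For instance, `&foo` for a free function `foo` has type `&fn() {foo}`, a reference to
+    /// the anonymous `FnDef` type, so this returns `Some(..)`; a reference to a function
+    /// *pointer* such as `&fn()` yields `None`.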
+ fn is_fn_ref(ty: Ty<'tcx>) -> Option<(DefId, GenericArgsRef<'tcx>)> { + let referent_ty = match ty.kind() { + ty::Ref(_, referent_ty, _) => Some(referent_ty), + ty::RawPtr(referent_ty, _) => Some(referent_ty), + _ => None, + }; + referent_ty + .map(|ref_ty| { + if let ty::FnDef(def_id, args_ref) = *ref_ty.kind() { + Some((def_id, args_ref)) + } else { + None + } + }) + .unwrap_or(None) + } + + fn nth_arg_span(&self, args: &[Spanned<Operand<'tcx>>], n: usize) -> Span { + match &args[n].node { + Operand::Copy(place) | Operand::Move(place) => { + self.body.local_decls[place.local].source_info.span + } + Operand::Constant(constant) => constant.span, + } + } + + fn emit_lint( + &self, + fn_id: DefId, + fn_args: GenericArgsRef<'tcx>, + source_info: SourceInfo, + span: Span, + ) { + let lint_root = self.body.source_scopes[source_info.scope] + .local_data + .as_ref() + .unwrap_crate_local() + .lint_root; + // FIXME: use existing printing routines to print the function signature + let fn_sig = self.tcx.fn_sig(fn_id).instantiate(self.tcx, fn_args); + let unsafety = fn_sig.safety().prefix_str(); + let abi = match fn_sig.abi() { + ExternAbi::Rust => String::from(""), + other_abi => format!("extern {other_abi} "), + }; + let ident = self.tcx.item_ident(fn_id); + let ty_params = fn_args.types().map(|ty| format!("{ty}")); + let const_params = fn_args.consts().map(|c| format!("{c}")); + let params = ty_params.chain(const_params).join(", "); + let num_args = fn_sig.inputs().map_bound(|inputs| inputs.len()).skip_binder(); + let variadic = if fn_sig.c_variadic() { ", ..." } else { "" }; + let ret = if fn_sig.output().skip_binder().is_unit() { "" } else { " -> _" }; + let sugg = format!( + "{} as {}{}fn({}{}){}", + if params.is_empty() { ident.to_string() } else { format!("{ident}::<{params}>") }, + unsafety, + abi, + vec!["_"; num_args].join(", "), + variadic, + ret, + ); + + self.tcx.emit_node_span_lint( + FUNCTION_ITEM_REFERENCES, + lint_root, + span, + errors::FnItemRef { span, sugg, ident }, + ); + } +} diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs new file mode 100644 index 00000000000..b17b7f45000 --- /dev/null +++ b/compiler/rustc_mir_transform/src/gvn.rs @@ -0,0 +1,1854 @@ +//! Global value numbering. +//! +//! MIR may contain repeated and/or redundant computations. The objective of this pass is to detect +//! such redundancies and re-use the already-computed result when possible. +//! +//! From those assignments, we construct a mapping `VnIndex -> Vec<(Local, Location)>` of available +//! values, the locals in which they are stored, and the assignment location. +//! +//! We traverse all assignments `x = rvalue` and operands. +//! +//! For each SSA one, we compute a symbolic representation of values that are assigned to SSA +//! locals. This symbolic representation is defined by the `Value` enum. Each produced instance of +//! `Value` is interned as a `VnIndex`, which allows us to cheaply compute identical values. +//! +//! For each non-SSA +//! one, we compute the `VnIndex` of the rvalue. If this `VnIndex` is associated to a constant, we +//! replace the rvalue/operand by that constant. Otherwise, if there is an SSA local `y` +//! associated to this `VnIndex`, and if its definition location strictly dominates the assignment +//! to `x`, we replace the assignment by `x = y`. +//! +//! By opportunity, this pass simplifies some `Rvalue`s based on the accumulated knowledge. +//! +//! # Operational semantic +//! +//! 
Operationally, this pass attempts to prove bitwise equality between locals. Given this MIR: +//! ```ignore (MIR) +//! _a = some value // has VnIndex i +//! // some MIR +//! _b = some other value // also has VnIndex i +//! ``` +//! +//! We consider it to be replacable by: +//! ```ignore (MIR) +//! _a = some value // has VnIndex i +//! // some MIR +//! _c = some other value // also has VnIndex i +//! assume(_a bitwise equal to _c) // follows from having the same VnIndex +//! _b = _a // follows from the `assume` +//! ``` +//! +//! Which is simplifiable to: +//! ```ignore (MIR) +//! _a = some value // has VnIndex i +//! // some MIR +//! _b = _a +//! ``` +//! +//! # Handling of references +//! +//! We handle references by assigning a different "provenance" index to each Ref/RawPtr rvalue. +//! This ensure that we do not spuriously merge borrows that should not be merged. Meanwhile, we +//! consider all the derefs of an immutable reference to a freeze type to give the same value: +//! ```ignore (MIR) +//! _a = *_b // _b is &Freeze +//! _c = *_b // replaced by _c = _a +//! ``` +//! +//! # Determinism of constant propagation +//! +//! When registering a new `Value`, we attempt to opportunistically evaluate it as a constant. +//! The evaluated form is inserted in `evaluated` as an `OpTy` or `None` if evaluation failed. +//! +//! The difficulty is non-deterministic evaluation of MIR constants. Some `Const` can have +//! different runtime values each time they are evaluated. This is the case with +//! `Const::Slice` which have a new pointer each time they are evaluated, and constants that +//! contain a fn pointer (`AllocId` pointing to a `GlobalAlloc::Function`) pointing to a different +//! symbol in each codegen unit. +//! +//! Meanwhile, we want to be able to read indirect constants. For instance: +//! ``` +//! static A: &'static &'static u8 = &&63; +//! fn foo() -> u8 { +//! **A // We want to replace by 63. +//! } +//! fn bar() -> u8 { +//! b"abc"[1] // We want to replace by 'b'. +//! } +//! ``` +//! +//! The `Value::Constant` variant stores a possibly unevaluated constant. Evaluating that constant +//! may be non-deterministic. When that happens, we assign a disambiguator to ensure that we do not +//! merge the constants. See `duplicate_slice` test in `gvn.rs`. +//! +//! Second, when writing constants in MIR, we do not write `Const::Slice` or `Const` +//! that contain `AllocId`s. 
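+//!
+//! # Dominance and reuse
+//!
+//! An illustrative sketch (not actual output of this pass) of the reuse rule described above:
+//! an assignment is only rewritten to reuse an SSA local when that local's definition dominates
+//! the assignment being rewritten.
+//! ```ignore (MIR)
+//! bb0: _1 = some value            // has VnIndex i; bb0 dominates bb1 and bb2
+//! bb1: _2 = the same value        // also VnIndex i, so this can become `_2 = _1`
+//! bb2: _3 = the same value        // also VnIndex i, so this can become `_3 = _1`
+//! ```
+//! If `_1` were instead defined in a block that does not dominate the redundant assignment (for
+//! example, only on one arm of a `switchInt`), the computation is left untouched.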
+ +use std::borrow::Cow; + +use either::Either; +use rustc_abi::{self as abi, BackendRepr, FIRST_VARIANT, FieldIdx, Primitive, Size, VariantIdx}; +use rustc_const_eval::const_eval::DummyMachine; +use rustc_const_eval::interpret::{ + ImmTy, Immediate, InterpCx, MemPlaceMeta, MemoryKind, OpTy, Projectable, Scalar, + intern_const_alloc_for_constprop, +}; +use rustc_data_structures::fx::{FxIndexSet, MutableValues}; +use rustc_data_structures::graph::dominators::Dominators; +use rustc_hir::def::DefKind; +use rustc_index::bit_set::DenseBitSet; +use rustc_index::{IndexVec, newtype_index}; +use rustc_middle::bug; +use rustc_middle::mir::interpret::GlobalAlloc; +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::layout::HasTypingEnv; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_span::DUMMY_SP; +use rustc_span::def_id::DefId; +use smallvec::SmallVec; +use tracing::{debug, instrument, trace}; + +use crate::ssa::SsaLocals; + +pub(super) struct GVN; + +impl<'tcx> crate::MirPass<'tcx> for GVN { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 2 + } + + #[instrument(level = "trace", skip(self, tcx, body))] + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!(def_id = ?body.source.def_id()); + + let typing_env = body.typing_env(tcx); + let ssa = SsaLocals::new(tcx, body, typing_env); + // Clone dominators because we need them while mutating the body. + let dominators = body.basic_blocks.dominators().clone(); + + let mut state = VnState::new(tcx, body, typing_env, &ssa, dominators, &body.local_decls); + + for local in body.args_iter().filter(|&local| ssa.is_ssa(local)) { + let opaque = state.new_opaque(); + state.assign(local, opaque); + } + + let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec(); + for bb in reverse_postorder { + let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb]; + state.visit_basic_block_data(bb, data); + } + + // For each local that is reused (`y` above), we remove its storage statements do avoid any + // difficulty. Those locals are SSA, so should be easy to optimize by LLVM without storage + // statements. + StorageRemover { tcx, reused_locals: state.reused_locals }.visit_body_preserves_cfg(body); + } + + fn is_required(&self) -> bool { + false + } +} + +newtype_index! { + struct VnIndex {} +} + +/// Computing the aggregate's type can be quite slow, so we only keep the minimal amount of +/// information to reconstruct it when needed. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +enum AggregateTy<'tcx> { + /// Invariant: this must not be used for an empty array. + Array, + Tuple, + Def(DefId, ty::GenericArgsRef<'tcx>), + RawPtr { + /// Needed for cast propagation. + data_pointer_ty: Ty<'tcx>, + /// The data pointer can be anything thin, so doesn't determine the output. + output_pointer_ty: Ty<'tcx>, + }, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +enum AddressKind { + Ref(BorrowKind), + Address(RawPtrKind), +} + +#[derive(Debug, PartialEq, Eq, Hash)] +enum Value<'tcx> { + // Root values. + /// Used to represent values we know nothing about. + /// The `usize` is a counter incremented by `new_opaque`. + Opaque(usize), + /// Evaluated or unevaluated constant value. + Constant { + value: Const<'tcx>, + /// Some constants do not have a deterministic value. To avoid merging two instances of the + /// same `Const`, we assign them an additional integer index. + // `disambiguator` is 0 iff the constant is deterministic. 
+ disambiguator: usize, + }, + /// An aggregate value, either tuple/closure/struct/enum. + /// This does not contain unions, as we cannot reason with the value. + Aggregate(AggregateTy<'tcx>, VariantIdx, Vec<VnIndex>), + /// This corresponds to a `[value; count]` expression. + Repeat(VnIndex, ty::Const<'tcx>), + /// The address of a place. + Address { + place: Place<'tcx>, + kind: AddressKind, + /// Give each borrow and pointer a different provenance, so we don't merge them. + provenance: usize, + }, + + // Extractions. + /// This is the *value* obtained by projecting another value. + Projection(VnIndex, ProjectionElem<VnIndex, Ty<'tcx>>), + /// Discriminant of the given value. + Discriminant(VnIndex), + /// Length of an array or slice. + Len(VnIndex), + + // Operations. + NullaryOp(NullOp<'tcx>, Ty<'tcx>), + UnaryOp(UnOp, VnIndex), + BinaryOp(BinOp, VnIndex, VnIndex), + Cast { + kind: CastKind, + value: VnIndex, + from: Ty<'tcx>, + to: Ty<'tcx>, + }, +} + +struct VnState<'body, 'tcx> { + tcx: TyCtxt<'tcx>, + ecx: InterpCx<'tcx, DummyMachine>, + local_decls: &'body LocalDecls<'tcx>, + /// Value stored in each local. + locals: IndexVec<Local, Option<VnIndex>>, + /// Locals that are assigned that value. + // This vector does not hold all the values of `VnIndex` that we create. + rev_locals: IndexVec<VnIndex, SmallVec<[Local; 1]>>, + values: FxIndexSet<Value<'tcx>>, + /// Values evaluated as constants if possible. + evaluated: IndexVec<VnIndex, Option<OpTy<'tcx>>>, + /// Counter to generate different values. + next_opaque: usize, + /// Cache the deref values. + derefs: Vec<VnIndex>, + ssa: &'body SsaLocals, + dominators: Dominators<BasicBlock>, + reused_locals: DenseBitSet<Local>, +} + +impl<'body, 'tcx> VnState<'body, 'tcx> { + fn new( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + ssa: &'body SsaLocals, + dominators: Dominators<BasicBlock>, + local_decls: &'body LocalDecls<'tcx>, + ) -> Self { + // Compute a rough estimate of the number of values in the body from the number of + // statements. This is meant to reduce the number of allocations, but it's all right if + // we miss the exact amount. We estimate based on 2 values per statement (one in LHS and + // one in RHS) and 4 values per terminator (for call operands). + let num_values = + 2 * body.basic_blocks.iter().map(|bbdata| bbdata.statements.len()).sum::<usize>() + + 4 * body.basic_blocks.len(); + VnState { + tcx, + ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine), + local_decls, + locals: IndexVec::from_elem(None, local_decls), + rev_locals: IndexVec::with_capacity(num_values), + values: FxIndexSet::with_capacity_and_hasher(num_values, Default::default()), + evaluated: IndexVec::with_capacity(num_values), + next_opaque: 1, + derefs: Vec::new(), + ssa, + dominators, + reused_locals: DenseBitSet::new_empty(local_decls.len()), + } + } + + fn typing_env(&self) -> ty::TypingEnv<'tcx> { + self.ecx.typing_env() + } + + #[instrument(level = "trace", skip(self), ret)] + fn insert(&mut self, value: Value<'tcx>) -> VnIndex { + let (index, new) = self.values.insert_full(value); + let index = VnIndex::from_usize(index); + if new { + // Grow `evaluated` and `rev_locals` here to amortize the allocations. 
+ let evaluated = self.eval_to_const(index); + let _index = self.evaluated.push(evaluated); + debug_assert_eq!(index, _index); + let _index = self.rev_locals.push(SmallVec::new()); + debug_assert_eq!(index, _index); + } + index + } + + fn next_opaque(&mut self) -> usize { + let next_opaque = self.next_opaque; + self.next_opaque += 1; + next_opaque + } + + /// Create a new `Value` for which we have no information at all, except that it is distinct + /// from all the others. + #[instrument(level = "trace", skip(self), ret)] + fn new_opaque(&mut self) -> VnIndex { + let value = Value::Opaque(self.next_opaque()); + self.insert(value) + } + + /// Create a new `Value::Address` distinct from all the others. + #[instrument(level = "trace", skip(self), ret)] + fn new_pointer(&mut self, place: Place<'tcx>, kind: AddressKind) -> VnIndex { + let value = Value::Address { place, kind, provenance: self.next_opaque() }; + self.insert(value) + } + + fn get(&self, index: VnIndex) -> &Value<'tcx> { + self.values.get_index(index.as_usize()).unwrap() + } + + /// Record that `local` is assigned `value`. `local` must be SSA. + #[instrument(level = "trace", skip(self))] + fn assign(&mut self, local: Local, value: VnIndex) { + debug_assert!(self.ssa.is_ssa(local)); + self.locals[local] = Some(value); + self.rev_locals[value].push(local); + } + + fn insert_constant(&mut self, value: Const<'tcx>) -> VnIndex { + let disambiguator = if value.is_deterministic() { + // The constant is deterministic, no need to disambiguate. + 0 + } else { + // Multiple mentions of this constant will yield different values, + // so assign a different `disambiguator` to ensure they do not get the same `VnIndex`. + let disambiguator = self.next_opaque(); + // `disambiguator: 0` means deterministic. + debug_assert_ne!(disambiguator, 0); + disambiguator + }; + self.insert(Value::Constant { value, disambiguator }) + } + + fn insert_bool(&mut self, flag: bool) -> VnIndex { + // Booleans are deterministic. + let value = Const::from_bool(self.tcx, flag); + debug_assert!(value.is_deterministic()); + self.insert(Value::Constant { value, disambiguator: 0 }) + } + + fn insert_scalar(&mut self, scalar: Scalar, ty: Ty<'tcx>) -> VnIndex { + // Scalars are deterministic. + let value = Const::from_scalar(self.tcx, scalar, ty); + debug_assert!(value.is_deterministic()); + self.insert(Value::Constant { value, disambiguator: 0 }) + } + + fn insert_tuple(&mut self, values: Vec<VnIndex>) -> VnIndex { + self.insert(Value::Aggregate(AggregateTy::Tuple, VariantIdx::ZERO, values)) + } + + fn insert_deref(&mut self, value: VnIndex) -> VnIndex { + let value = self.insert(Value::Projection(value, ProjectionElem::Deref)); + self.derefs.push(value); + value + } + + fn invalidate_derefs(&mut self) { + for deref in std::mem::take(&mut self.derefs) { + let opaque = self.next_opaque(); + *self.values.get_index_mut2(deref.index()).unwrap() = Value::Opaque(opaque); + } + } + + #[instrument(level = "trace", skip(self), ret)] + fn eval_to_const(&mut self, value: VnIndex) -> Option<OpTy<'tcx>> { + use Value::*; + let op = match *self.get(value) { + Opaque(_) => return None, + // Do not bother evaluating repeat expressions. This would uselessly consume memory. + Repeat(..) => return None, + + Constant { ref value, disambiguator: _ } => { + self.ecx.eval_mir_constant(value, DUMMY_SP, None).discard_err()? 
+ } + Aggregate(kind, variant, ref fields) => { + let fields = fields + .iter() + .map(|&f| self.evaluated[f].as_ref()) + .collect::<Option<Vec<_>>>()?; + let ty = match kind { + AggregateTy::Array => { + assert!(fields.len() > 0); + Ty::new_array(self.tcx, fields[0].layout.ty, fields.len() as u64) + } + AggregateTy::Tuple => { + Ty::new_tup_from_iter(self.tcx, fields.iter().map(|f| f.layout.ty)) + } + AggregateTy::Def(def_id, args) => { + self.tcx.type_of(def_id).instantiate(self.tcx, args) + } + AggregateTy::RawPtr { output_pointer_ty, .. } => output_pointer_ty, + }; + let variant = if ty.is_enum() { Some(variant) } else { None }; + let ty = self.ecx.layout_of(ty).ok()?; + if ty.is_zst() { + ImmTy::uninit(ty).into() + } else if matches!(kind, AggregateTy::RawPtr { .. }) { + // Pointers don't have fields, so don't `project_field` them. + let data = self.ecx.read_pointer(fields[0]).discard_err()?; + let meta = if fields[1].layout.is_zst() { + MemPlaceMeta::None + } else { + MemPlaceMeta::Meta(self.ecx.read_scalar(fields[1]).discard_err()?) + }; + let ptr_imm = Immediate::new_pointer_with_meta(data, meta, &self.ecx); + ImmTy::from_immediate(ptr_imm, ty).into() + } else if matches!( + ty.backend_repr, + BackendRepr::Scalar(..) | BackendRepr::ScalarPair(..) + ) { + let dest = self.ecx.allocate(ty, MemoryKind::Stack).discard_err()?; + let variant_dest = if let Some(variant) = variant { + self.ecx.project_downcast(&dest, variant).discard_err()? + } else { + dest.clone() + }; + for (field_index, op) in fields.into_iter().enumerate() { + let field_dest = self + .ecx + .project_field(&variant_dest, FieldIdx::from_usize(field_index)) + .discard_err()?; + self.ecx.copy_op(op, &field_dest).discard_err()?; + } + self.ecx + .write_discriminant(variant.unwrap_or(FIRST_VARIANT), &dest) + .discard_err()?; + self.ecx + .alloc_mark_immutable(dest.ptr().provenance.unwrap().alloc_id()) + .discard_err()?; + dest.into() + } else { + return None; + } + } + + Projection(base, elem) => { + let value = self.evaluated[base].as_ref()?; + let elem = match elem { + ProjectionElem::Deref => ProjectionElem::Deref, + ProjectionElem::Downcast(name, read_variant) => { + ProjectionElem::Downcast(name, read_variant) + } + ProjectionElem::Field(f, ty) => ProjectionElem::Field(f, ty), + ProjectionElem::ConstantIndex { offset, min_length, from_end } => { + ProjectionElem::ConstantIndex { offset, min_length, from_end } + } + ProjectionElem::Subslice { from, to, from_end } => { + ProjectionElem::Subslice { from, to, from_end } + } + ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty), + ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty), + ProjectionElem::UnwrapUnsafeBinder(ty) => { + ProjectionElem::UnwrapUnsafeBinder(ty) + } + // This should have been replaced by a `ConstantIndex` earlier. + ProjectionElem::Index(_) => return None, + }; + self.ecx.project(value, elem).discard_err()? + } + Address { place, kind, provenance: _ } => { + if !place.is_indirect_first_projection() { + return None; + } + let local = self.locals[place.local]?; + let pointer = self.evaluated[local].as_ref()?; + let mut mplace = self.ecx.deref_pointer(pointer).discard_err()?; + for proj in place.projection.iter().skip(1) { + // We have no call stack to associate a local with a value, so we cannot + // interpret indexing. 
+ if matches!(proj, ProjectionElem::Index(_)) { + return None; + } + mplace = self.ecx.project(&mplace, proj).discard_err()?; + } + let pointer = mplace.to_ref(&self.ecx); + let ty = match kind { + AddressKind::Ref(bk) => Ty::new_ref( + self.tcx, + self.tcx.lifetimes.re_erased, + mplace.layout.ty, + bk.to_mutbl_lossy(), + ), + AddressKind::Address(mutbl) => { + Ty::new_ptr(self.tcx, mplace.layout.ty, mutbl.to_mutbl_lossy()) + } + }; + let layout = self.ecx.layout_of(ty).ok()?; + ImmTy::from_immediate(pointer, layout).into() + } + + Discriminant(base) => { + let base = self.evaluated[base].as_ref()?; + let variant = self.ecx.read_discriminant(base).discard_err()?; + let discr_value = + self.ecx.discriminant_for_variant(base.layout.ty, variant).discard_err()?; + discr_value.into() + } + Len(slice) => { + let slice = self.evaluated[slice].as_ref()?; + let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); + let len = slice.len(&self.ecx).discard_err()?; + let imm = ImmTy::from_uint(len, usize_layout); + imm.into() + } + NullaryOp(null_op, ty) => { + let layout = self.ecx.layout_of(ty).ok()?; + if let NullOp::SizeOf | NullOp::AlignOf = null_op + && layout.is_unsized() + { + return None; + } + let val = match null_op { + NullOp::SizeOf => layout.size.bytes(), + NullOp::AlignOf => layout.align.abi.bytes(), + NullOp::OffsetOf(fields) => self + .ecx + .tcx + .offset_of_subfield(self.typing_env(), layout, fields.iter()) + .bytes(), + NullOp::UbChecks => return None, + NullOp::ContractChecks => return None, + }; + let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); + let imm = ImmTy::from_uint(val, usize_layout); + imm.into() + } + UnaryOp(un_op, operand) => { + let operand = self.evaluated[operand].as_ref()?; + let operand = self.ecx.read_immediate(operand).discard_err()?; + let val = self.ecx.unary_op(un_op, &operand).discard_err()?; + val.into() + } + BinaryOp(bin_op, lhs, rhs) => { + let lhs = self.evaluated[lhs].as_ref()?; + let lhs = self.ecx.read_immediate(lhs).discard_err()?; + let rhs = self.evaluated[rhs].as_ref()?; + let rhs = self.ecx.read_immediate(rhs).discard_err()?; + let val = self.ecx.binary_op(bin_op, &lhs, &rhs).discard_err()?; + val.into() + } + Cast { kind, value, from: _, to } => match kind { + CastKind::IntToInt | CastKind::IntToFloat => { + let value = self.evaluated[value].as_ref()?; + let value = self.ecx.read_immediate(value).discard_err()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.int_to_int_or_float(&value, to).discard_err()?; + res.into() + } + CastKind::FloatToFloat | CastKind::FloatToInt => { + let value = self.evaluated[value].as_ref()?; + let value = self.ecx.read_immediate(value).discard_err()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.float_to_float_or_int(&value, to).discard_err()?; + res.into() + } + CastKind::Transmute => { + let value = self.evaluated[value].as_ref()?; + let to = self.ecx.layout_of(to).ok()?; + // `offset` for immediates generally only supports projections that match the + // type of the immediate. However, as a HACK, we exploit that it can also do + // limited transmutes: it only works between types with the same layout, and + // cannot transmute pointers to integers. 
+ if value.as_mplace_or_imm().is_right() { + let can_transmute = match (value.layout.backend_repr, to.backend_repr) { + (BackendRepr::Scalar(s1), BackendRepr::Scalar(s2)) => { + s1.size(&self.ecx) == s2.size(&self.ecx) + && !matches!(s1.primitive(), Primitive::Pointer(..)) + } + (BackendRepr::ScalarPair(a1, b1), BackendRepr::ScalarPair(a2, b2)) => { + a1.size(&self.ecx) == a2.size(&self.ecx) && + b1.size(&self.ecx) == b2.size(&self.ecx) && + // The alignment of the second component determines its offset, so that also needs to match. + b1.align(&self.ecx) == b2.align(&self.ecx) && + // None of the inputs may be a pointer. + !matches!(a1.primitive(), Primitive::Pointer(..)) + && !matches!(b1.primitive(), Primitive::Pointer(..)) + } + _ => false, + }; + if !can_transmute { + return None; + } + } + value.offset(Size::ZERO, to, &self.ecx).discard_err()? + } + CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize, _) => { + let src = self.evaluated[value].as_ref()?; + let to = self.ecx.layout_of(to).ok()?; + let dest = self.ecx.allocate(to, MemoryKind::Stack).discard_err()?; + self.ecx.unsize_into(src, to, &dest.clone().into()).discard_err()?; + self.ecx + .alloc_mark_immutable(dest.ptr().provenance.unwrap().alloc_id()) + .discard_err()?; + dest.into() + } + CastKind::FnPtrToPtr | CastKind::PtrToPtr => { + let src = self.evaluated[value].as_ref()?; + let src = self.ecx.read_immediate(src).discard_err()?; + let to = self.ecx.layout_of(to).ok()?; + let ret = self.ecx.ptr_to_ptr(&src, to).discard_err()?; + ret.into() + } + CastKind::PointerCoercion(ty::adjustment::PointerCoercion::UnsafeFnPointer, _) => { + let src = self.evaluated[value].as_ref()?; + let src = self.ecx.read_immediate(src).discard_err()?; + let to = self.ecx.layout_of(to).ok()?; + ImmTy::from_immediate(*src, to).into() + } + _ => return None, + }, + }; + Some(op) + } + + fn project( + &mut self, + place: PlaceRef<'tcx>, + value: VnIndex, + proj: PlaceElem<'tcx>, + from_non_ssa_index: &mut bool, + ) -> Option<VnIndex> { + let proj = match proj { + ProjectionElem::Deref => { + let ty = place.ty(self.local_decls, self.tcx).ty; + if let Some(Mutability::Not) = ty.ref_mutability() + && let Some(pointee_ty) = ty.builtin_deref(true) + && pointee_ty.is_freeze(self.tcx, self.typing_env()) + { + // An immutable borrow `_x` always points to the same value for the + // lifetime of the borrow, so we can merge all instances of `*_x`. + return Some(self.insert_deref(value)); + } else { + return None; + } + } + ProjectionElem::Downcast(name, index) => ProjectionElem::Downcast(name, index), + ProjectionElem::Field(f, ty) => { + if let Value::Aggregate(_, _, fields) = self.get(value) { + return Some(fields[f.as_usize()]); + } else if let Value::Projection(outer_value, ProjectionElem::Downcast(_, read_variant)) = self.get(value) + && let Value::Aggregate(_, written_variant, fields) = self.get(*outer_value) + // This pass is not aware of control-flow, so we do not know whether the + // replacement we are doing is actually reachable. We could be in any arm of + // ``` + // match Some(x) { + // Some(y) => /* stuff */, + // None => /* other */, + // } + // ``` + // + // In surface rust, the current statement would be unreachable. + // + // However, from the reference chapter on enums and RFC 2195, + // accessing the wrong variant is not UB if the enum has repr. + // So it's not impossible for a series of MIR opts to generate + // a downcast to an inactive variant. 
+ && written_variant == read_variant + { + return Some(fields[f.as_usize()]); + } + ProjectionElem::Field(f, ty) + } + ProjectionElem::Index(idx) => { + if let Value::Repeat(inner, _) = self.get(value) { + *from_non_ssa_index |= self.locals[idx].is_none(); + return Some(*inner); + } + let idx = self.locals[idx]?; + ProjectionElem::Index(idx) + } + ProjectionElem::ConstantIndex { offset, min_length, from_end } => { + match self.get(value) { + Value::Repeat(inner, _) => { + return Some(*inner); + } + Value::Aggregate(AggregateTy::Array, _, operands) => { + let offset = if from_end { + operands.len() - offset as usize + } else { + offset as usize + }; + return operands.get(offset).copied(); + } + _ => {} + }; + ProjectionElem::ConstantIndex { offset, min_length, from_end } + } + ProjectionElem::Subslice { from, to, from_end } => { + ProjectionElem::Subslice { from, to, from_end } + } + ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty), + ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty), + ProjectionElem::UnwrapUnsafeBinder(ty) => ProjectionElem::UnwrapUnsafeBinder(ty), + }; + + Some(self.insert(Value::Projection(value, proj))) + } + + /// Simplify the projection chain if we know better. + #[instrument(level = "trace", skip(self))] + fn simplify_place_projection(&mut self, place: &mut Place<'tcx>, location: Location) { + // If the projection is indirect, we treat the local as a value, so can replace it with + // another local. + if place.is_indirect_first_projection() + && let Some(base) = self.locals[place.local] + && let Some(new_local) = self.try_as_local(base, location) + && place.local != new_local + { + place.local = new_local; + self.reused_locals.insert(new_local); + } + + let mut projection = Cow::Borrowed(&place.projection[..]); + + for i in 0..projection.len() { + let elem = projection[i]; + if let ProjectionElem::Index(idx_local) = elem + && let Some(idx) = self.locals[idx_local] + { + if let Some(offset) = self.evaluated[idx].as_ref() + && let Some(offset) = self.ecx.read_target_usize(offset).discard_err() + && let Some(min_length) = offset.checked_add(1) + { + projection.to_mut()[i] = + ProjectionElem::ConstantIndex { offset, min_length, from_end: false }; + } else if let Some(new_idx_local) = self.try_as_local(idx, location) + && idx_local != new_idx_local + { + projection.to_mut()[i] = ProjectionElem::Index(new_idx_local); + self.reused_locals.insert(new_idx_local); + } + } + } + + if projection.is_owned() { + place.projection = self.tcx.mk_place_elems(&projection); + } + + trace!(?place); + } + + /// Represent the *value* which would be read from `place`, and point `place` to a preexisting + /// place with the same value (if that already exists). + #[instrument(level = "trace", skip(self), ret)] + fn simplify_place_value( + &mut self, + place: &mut Place<'tcx>, + location: Location, + ) -> Option<VnIndex> { + self.simplify_place_projection(place, location); + + // Invariant: `place` and `place_ref` point to the same value, even if they point to + // different memory locations. + let mut place_ref = place.as_ref(); + + // Invariant: `value` holds the value up-to the `index`th projection excluded. + let mut value = self.locals[place.local]?; + let mut from_non_ssa_index = false; + for (index, proj) in place.projection.iter().enumerate() { + if let Value::Projection(pointer, ProjectionElem::Deref) = *self.get(value) + && let Value::Address { place: mut pointee, kind, .. 
} = *self.get(pointer) + && let AddressKind::Ref(BorrowKind::Shared) = kind + && let Some(v) = self.simplify_place_value(&mut pointee, location) + { + value = v; + place_ref = pointee.project_deeper(&place.projection[index..], self.tcx).as_ref(); + } + if let Some(local) = self.try_as_local(value, location) { + // Both `local` and `Place { local: place.local, projection: projection[..index] }` + // hold the same value. Therefore, following place holds the value in the original + // `place`. + place_ref = PlaceRef { local, projection: &place.projection[index..] }; + } + + let base = PlaceRef { local: place.local, projection: &place.projection[..index] }; + value = self.project(base, value, proj, &mut from_non_ssa_index)?; + } + + if let Value::Projection(pointer, ProjectionElem::Deref) = *self.get(value) + && let Value::Address { place: mut pointee, kind, .. } = *self.get(pointer) + && let AddressKind::Ref(BorrowKind::Shared) = kind + && let Some(v) = self.simplify_place_value(&mut pointee, location) + { + value = v; + place_ref = pointee.project_deeper(&[], self.tcx).as_ref(); + } + if let Some(new_local) = self.try_as_local(value, location) { + place_ref = PlaceRef { local: new_local, projection: &[] }; + } else if from_non_ssa_index { + // If access to non-SSA locals is unavoidable, bail out. + return None; + } + + if place_ref.local != place.local || place_ref.projection.len() < place.projection.len() { + // By the invariant on `place_ref`. + *place = place_ref.project_deeper(&[], self.tcx); + self.reused_locals.insert(place_ref.local); + } + + Some(value) + } + + #[instrument(level = "trace", skip(self), ret)] + fn simplify_operand( + &mut self, + operand: &mut Operand<'tcx>, + location: Location, + ) -> Option<VnIndex> { + match *operand { + Operand::Constant(ref constant) => Some(self.insert_constant(constant.const_)), + Operand::Copy(ref mut place) | Operand::Move(ref mut place) => { + let value = self.simplify_place_value(place, location)?; + if let Some(const_) = self.try_as_constant(value) { + *operand = Operand::Constant(Box::new(const_)); + } + Some(value) + } + } + } + + #[instrument(level = "trace", skip(self), ret)] + fn simplify_rvalue( + &mut self, + lhs: &Place<'tcx>, + rvalue: &mut Rvalue<'tcx>, + location: Location, + ) -> Option<VnIndex> { + let value = match *rvalue { + // Forward values. + Rvalue::Use(ref mut operand) => return self.simplify_operand(operand, location), + Rvalue::CopyForDeref(place) => { + let mut operand = Operand::Copy(place); + let val = self.simplify_operand(&mut operand, location); + *rvalue = Rvalue::Use(operand); + return val; + } + + // Roots. + Rvalue::Repeat(ref mut op, amount) => { + let op = self.simplify_operand(op, location)?; + Value::Repeat(op, amount) + } + Rvalue::NullaryOp(op, ty) => Value::NullaryOp(op, ty), + Rvalue::Aggregate(..) => return self.simplify_aggregate(lhs, rvalue, location), + Rvalue::Ref(_, borrow_kind, ref mut place) => { + self.simplify_place_projection(place, location); + return Some(self.new_pointer(*place, AddressKind::Ref(borrow_kind))); + } + Rvalue::RawPtr(mutbl, ref mut place) => { + self.simplify_place_projection(place, location); + return Some(self.new_pointer(*place, AddressKind::Address(mutbl))); + } + Rvalue::WrapUnsafeBinder(ref mut op, ty) => { + let value = self.simplify_operand(op, location)?; + Value::Cast { + kind: CastKind::Transmute, + value, + from: op.ty(self.local_decls, self.tcx), + to: ty, + } + } + + // Operations. 
+ Rvalue::Len(ref mut place) => return self.simplify_len(place, location), + Rvalue::Cast(ref mut kind, ref mut value, to) => { + return self.simplify_cast(kind, value, to, location); + } + Rvalue::BinaryOp(op, box (ref mut lhs, ref mut rhs)) => { + return self.simplify_binary(op, lhs, rhs, location); + } + Rvalue::UnaryOp(op, ref mut arg_op) => { + return self.simplify_unary(op, arg_op, location); + } + Rvalue::Discriminant(ref mut place) => { + let place = self.simplify_place_value(place, location)?; + if let Some(discr) = self.simplify_discriminant(place) { + return Some(discr); + } + Value::Discriminant(place) + } + + // Unsupported values. + Rvalue::ThreadLocalRef(..) | Rvalue::ShallowInitBox(..) => return None, + }; + debug!(?value); + Some(self.insert(value)) + } + + fn simplify_discriminant(&mut self, place: VnIndex) -> Option<VnIndex> { + if let Value::Aggregate(enum_ty, variant, _) = *self.get(place) + && let AggregateTy::Def(enum_did, enum_args) = enum_ty + && let DefKind::Enum = self.tcx.def_kind(enum_did) + { + let enum_ty = self.tcx.type_of(enum_did).instantiate(self.tcx, enum_args); + let discr = self.ecx.discriminant_for_variant(enum_ty, variant).discard_err()?; + return Some(self.insert_scalar(discr.to_scalar(), discr.layout.ty)); + } + + None + } + + fn try_as_place_elem( + &mut self, + proj: ProjectionElem<VnIndex, Ty<'tcx>>, + loc: Location, + ) -> Option<PlaceElem<'tcx>> { + Some(match proj { + ProjectionElem::Deref => ProjectionElem::Deref, + ProjectionElem::Field(idx, ty) => ProjectionElem::Field(idx, ty), + ProjectionElem::Index(idx) => { + let Some(local) = self.try_as_local(idx, loc) else { + return None; + }; + self.reused_locals.insert(local); + ProjectionElem::Index(local) + } + ProjectionElem::ConstantIndex { offset, min_length, from_end } => { + ProjectionElem::ConstantIndex { offset, min_length, from_end } + } + ProjectionElem::Subslice { from, to, from_end } => { + ProjectionElem::Subslice { from, to, from_end } + } + ProjectionElem::Downcast(symbol, idx) => ProjectionElem::Downcast(symbol, idx), + ProjectionElem::OpaqueCast(idx) => ProjectionElem::OpaqueCast(idx), + ProjectionElem::Subtype(idx) => ProjectionElem::Subtype(idx), + ProjectionElem::UnwrapUnsafeBinder(ty) => ProjectionElem::UnwrapUnsafeBinder(ty), + }) + } + + fn simplify_aggregate_to_copy( + &mut self, + lhs: &Place<'tcx>, + rvalue: &mut Rvalue<'tcx>, + location: Location, + fields: &[VnIndex], + variant_index: VariantIdx, + ) -> Option<VnIndex> { + let Some(&first_field) = fields.first() else { + return None; + }; + let Value::Projection(copy_from_value, _) = *self.get(first_field) else { + return None; + }; + // All fields must correspond one-to-one and come from the same aggregate value. + if fields.iter().enumerate().any(|(index, &v)| { + if let Value::Projection(pointer, ProjectionElem::Field(from_index, _)) = *self.get(v) + && copy_from_value == pointer + && from_index.index() == index + { + return false; + } + true + }) { + return None; + } + + let mut copy_from_local_value = copy_from_value; + if let Value::Projection(pointer, proj) = *self.get(copy_from_value) + && let ProjectionElem::Downcast(_, read_variant) = proj + { + if variant_index == read_variant { + // When copying a variant, there is no need to downcast. + copy_from_local_value = pointer; + } else { + // The copied variant must be identical. + return None; + } + } + + // Allow introducing places with non-constant offsets, as those are still better than + // reconstructing an aggregate. 
+ if let Some(place) = self.try_as_place(copy_from_local_value, location, true) + && rvalue.ty(self.local_decls, self.tcx) == place.ty(self.local_decls, self.tcx).ty + { + // Avoid creating `*a = copy (*b)`, as they might be aliases resulting in overlapping assignments. + // FIXME: This also avoids any kind of projection, not just derefs. We can add allowed projections. + if lhs.as_local().is_some() { + self.reused_locals.insert(place.local); + *rvalue = Rvalue::Use(Operand::Copy(place)); + } + return Some(copy_from_local_value); + } + + None + } + + fn simplify_aggregate( + &mut self, + lhs: &Place<'tcx>, + rvalue: &mut Rvalue<'tcx>, + location: Location, + ) -> Option<VnIndex> { + let Rvalue::Aggregate(box ref kind, ref mut field_ops) = *rvalue else { bug!() }; + + let tcx = self.tcx; + if field_ops.is_empty() { + let is_zst = match *kind { + AggregateKind::Array(..) + | AggregateKind::Tuple + | AggregateKind::Closure(..) + | AggregateKind::CoroutineClosure(..) => true, + // Only enums can be non-ZST. + AggregateKind::Adt(did, ..) => tcx.def_kind(did) != DefKind::Enum, + // Coroutines are never ZST, as they at least contain the implicit states. + AggregateKind::Coroutine(..) => false, + AggregateKind::RawPtr(..) => bug!("MIR for RawPtr aggregate must have 2 fields"), + }; + + if is_zst { + let ty = rvalue.ty(self.local_decls, tcx); + return Some(self.insert_constant(Const::zero_sized(ty))); + } + } + + let (mut ty, variant_index) = match *kind { + AggregateKind::Array(..) => { + assert!(!field_ops.is_empty()); + (AggregateTy::Array, FIRST_VARIANT) + } + AggregateKind::Tuple => { + assert!(!field_ops.is_empty()); + (AggregateTy::Tuple, FIRST_VARIANT) + } + AggregateKind::Closure(did, args) + | AggregateKind::CoroutineClosure(did, args) + | AggregateKind::Coroutine(did, args) => (AggregateTy::Def(did, args), FIRST_VARIANT), + AggregateKind::Adt(did, variant_index, args, _, None) => { + (AggregateTy::Def(did, args), variant_index) + } + // Do not track unions. + AggregateKind::Adt(_, _, _, _, Some(_)) => return None, + AggregateKind::RawPtr(pointee_ty, mtbl) => { + assert_eq!(field_ops.len(), 2); + let data_pointer_ty = field_ops[FieldIdx::ZERO].ty(self.local_decls, self.tcx); + let output_pointer_ty = Ty::new_ptr(self.tcx, pointee_ty, mtbl); + (AggregateTy::RawPtr { data_pointer_ty, output_pointer_ty }, FIRST_VARIANT) + } + }; + + let mut fields: Vec<_> = field_ops + .iter_mut() + .map(|op| self.simplify_operand(op, location).unwrap_or_else(|| self.new_opaque())) + .collect(); + + if let AggregateTy::RawPtr { data_pointer_ty, output_pointer_ty } = &mut ty { + let mut was_updated = false; + + // Any thin pointer of matching mutability is fine as the data pointer. 
+ while let Value::Cast { + kind: CastKind::PtrToPtr, + value: cast_value, + from: cast_from, + to: _, + } = self.get(fields[0]) + && let ty::RawPtr(from_pointee_ty, from_mtbl) = cast_from.kind() + && let ty::RawPtr(_, output_mtbl) = output_pointer_ty.kind() + && from_mtbl == output_mtbl + && from_pointee_ty.is_sized(self.tcx, self.typing_env()) + { + fields[0] = *cast_value; + *data_pointer_ty = *cast_from; + was_updated = true; + } + + if was_updated && let Some(op) = self.try_as_operand(fields[0], location) { + field_ops[FieldIdx::ZERO] = op; + } + } + + if let AggregateTy::Array = ty + && fields.len() > 4 + { + let first = fields[0]; + if fields.iter().all(|&v| v == first) { + let len = ty::Const::from_target_usize(self.tcx, fields.len().try_into().unwrap()); + if let Some(op) = self.try_as_operand(first, location) { + *rvalue = Rvalue::Repeat(op, len); + } + return Some(self.insert(Value::Repeat(first, len))); + } + } + + if let AggregateTy::Def(_, _) = ty + && let Some(value) = + self.simplify_aggregate_to_copy(lhs, rvalue, location, &fields, variant_index) + { + return Some(value); + } + + Some(self.insert(Value::Aggregate(ty, variant_index, fields))) + } + + #[instrument(level = "trace", skip(self), ret)] + fn simplify_unary( + &mut self, + op: UnOp, + arg_op: &mut Operand<'tcx>, + location: Location, + ) -> Option<VnIndex> { + let mut arg_index = self.simplify_operand(arg_op, location)?; + + // PtrMetadata doesn't care about *const vs *mut vs & vs &mut, + // so start by removing those distinctions so we can update the `Operand` + if op == UnOp::PtrMetadata { + let mut was_updated = false; + loop { + match self.get(arg_index) { + // Pointer casts that preserve metadata, such as + // `*const [i32]` <-> `*mut [i32]` <-> `*mut [f32]`. + // It's critical that this not eliminate cases like + // `*const [T]` -> `*const T` which remove metadata. + // We run on potentially-generic MIR, though, so unlike codegen + // we can't always know exactly what the metadata are. + // To allow things like `*mut (?A, ?T)` <-> `*mut (?B, ?T)`, + // it's fine to get a projection as the type. + Value::Cast { kind: CastKind::PtrToPtr, value: inner, from, to } + if self.pointers_have_same_metadata(*from, *to) => + { + arg_index = *inner; + was_updated = true; + continue; + } + + // `&mut *p`, `&raw *p`, etc don't change metadata. + Value::Address { place, kind: _, provenance: _ } + if let PlaceRef { local, projection: [PlaceElem::Deref] } = + place.as_ref() + && let Some(local_index) = self.locals[local] => + { + arg_index = local_index; + was_updated = true; + continue; + } + + _ => { + if was_updated && let Some(op) = self.try_as_operand(arg_index, location) { + *arg_op = op; + } + break; + } + } + } + } + + let value = match (op, self.get(arg_index)) { + (UnOp::Not, Value::UnaryOp(UnOp::Not, inner)) => return Some(*inner), + (UnOp::Neg, Value::UnaryOp(UnOp::Neg, inner)) => return Some(*inner), + (UnOp::Not, Value::BinaryOp(BinOp::Eq, lhs, rhs)) => { + Value::BinaryOp(BinOp::Ne, *lhs, *rhs) + } + (UnOp::Not, Value::BinaryOp(BinOp::Ne, lhs, rhs)) => { + Value::BinaryOp(BinOp::Eq, *lhs, *rhs) + } + (UnOp::PtrMetadata, Value::Aggregate(AggregateTy::RawPtr { .. }, _, fields)) => { + return Some(fields[1]); + } + // We have an unsizing cast, which assigns the length to wide pointer metadata. + ( + UnOp::PtrMetadata, + Value::Cast { + kind: CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize, _), + from, + to, + .. + }, + ) if let ty::Slice(..) 
= to.builtin_deref(true).unwrap().kind() + && let ty::Array(_, len) = from.builtin_deref(true).unwrap().kind() => + { + return Some(self.insert_constant(Const::Ty(self.tcx.types.usize, *len))); + } + _ => Value::UnaryOp(op, arg_index), + }; + Some(self.insert(value)) + } + + #[instrument(level = "trace", skip(self), ret)] + fn simplify_binary( + &mut self, + op: BinOp, + lhs_operand: &mut Operand<'tcx>, + rhs_operand: &mut Operand<'tcx>, + location: Location, + ) -> Option<VnIndex> { + let lhs = self.simplify_operand(lhs_operand, location); + let rhs = self.simplify_operand(rhs_operand, location); + // Only short-circuit options after we called `simplify_operand` + // on both operands for side effect. + let mut lhs = lhs?; + let mut rhs = rhs?; + + let lhs_ty = lhs_operand.ty(self.local_decls, self.tcx); + + // If we're comparing pointers, remove `PtrToPtr` casts if the from + // types of both casts and the metadata all match. + if let BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge = op + && lhs_ty.is_any_ptr() + && let Value::Cast { + kind: CastKind::PtrToPtr, value: lhs_value, from: lhs_from, .. + } = self.get(lhs) + && let Value::Cast { + kind: CastKind::PtrToPtr, value: rhs_value, from: rhs_from, .. + } = self.get(rhs) + && lhs_from == rhs_from + && self.pointers_have_same_metadata(*lhs_from, lhs_ty) + { + lhs = *lhs_value; + rhs = *rhs_value; + if let Some(lhs_op) = self.try_as_operand(lhs, location) + && let Some(rhs_op) = self.try_as_operand(rhs, location) + { + *lhs_operand = lhs_op; + *rhs_operand = rhs_op; + } + } + + if let Some(value) = self.simplify_binary_inner(op, lhs_ty, lhs, rhs) { + return Some(value); + } + let value = Value::BinaryOp(op, lhs, rhs); + Some(self.insert(value)) + } + + fn simplify_binary_inner( + &mut self, + op: BinOp, + lhs_ty: Ty<'tcx>, + lhs: VnIndex, + rhs: VnIndex, + ) -> Option<VnIndex> { + // Floats are weird enough that none of the logic below applies. + let reasonable_ty = + lhs_ty.is_integral() || lhs_ty.is_bool() || lhs_ty.is_char() || lhs_ty.is_any_ptr(); + if !reasonable_ty { + return None; + } + + let layout = self.ecx.layout_of(lhs_ty).ok()?; + + let as_bits = |value: VnIndex| { + let constant = self.evaluated[value].as_ref()?; + if layout.backend_repr.is_scalar() { + let scalar = self.ecx.read_scalar(constant).discard_err()?; + scalar.to_bits(constant.layout.size).discard_err() + } else { + // `constant` is a wide pointer. Do not evaluate to bits. + None + } + }; + + // Represent the values as `Left(bits)` or `Right(VnIndex)`. + use Either::{Left, Right}; + let a = as_bits(lhs).map_or(Right(lhs), Left); + let b = as_bits(rhs).map_or(Right(rhs), Left); + + let result = match (op, a, b) { + // Neutral elements. + ( + BinOp::Add + | BinOp::AddWithOverflow + | BinOp::AddUnchecked + | BinOp::BitOr + | BinOp::BitXor, + Left(0), + Right(p), + ) + | ( + BinOp::Add + | BinOp::AddWithOverflow + | BinOp::AddUnchecked + | BinOp::BitOr + | BinOp::BitXor + | BinOp::Sub + | BinOp::SubWithOverflow + | BinOp::SubUnchecked + | BinOp::Offset + | BinOp::Shl + | BinOp::Shr, + Right(p), + Left(0), + ) + | (BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked, Left(1), Right(p)) + | ( + BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked | BinOp::Div, + Right(p), + Left(1), + ) => p, + // Attempt to simplify `x & ALL_ONES` to `x`, with `ALL_ONES` depending on type size. 
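// A standalone, hypothetical check (not from the pass above) of the algebraic
// identities that `simplify_unary` and `simplify_binary_inner` rely on around this
// point, written out for a concrete integer type.
fn main() {
    let x: u32 = 0xDEAD_BEEF;
    let (a, b) = (3u32, 4u32);

    assert_eq!(!!x, x);            // Not(Not(x)) == x
    assert_eq!(!(a == b), a != b); // Not(Eq(a, b)) == Ne(a, b)
    assert_eq!(x + 0, x);          // 0 is neutral for Add
    assert_eq!(x * 1, x);          // 1 is neutral for Mul
    assert_eq!(x & u32::MAX, x);   // x & ALL_ONES == x
    assert_eq!(x & 0, 0);          // 0 absorbs BitAnd
    assert_eq!(x ^ x, 0);          // Xor with itself is 0
}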
+ (BinOp::BitAnd, Right(p), Left(ones)) | (BinOp::BitAnd, Left(ones), Right(p)) + if ones == layout.size.truncate(u128::MAX) + || (layout.ty.is_bool() && ones == 1) => + { + p + } + // Absorbing elements. + ( + BinOp::Mul | BinOp::MulWithOverflow | BinOp::MulUnchecked | BinOp::BitAnd, + _, + Left(0), + ) + | (BinOp::Rem, _, Left(1)) + | ( + BinOp::Mul + | BinOp::MulWithOverflow + | BinOp::MulUnchecked + | BinOp::Div + | BinOp::Rem + | BinOp::BitAnd + | BinOp::Shl + | BinOp::Shr, + Left(0), + _, + ) => self.insert_scalar(Scalar::from_uint(0u128, layout.size), lhs_ty), + // Attempt to simplify `x | ALL_ONES` to `ALL_ONES`. + (BinOp::BitOr, _, Left(ones)) | (BinOp::BitOr, Left(ones), _) + if ones == layout.size.truncate(u128::MAX) + || (layout.ty.is_bool() && ones == 1) => + { + self.insert_scalar(Scalar::from_uint(ones, layout.size), lhs_ty) + } + // Sub/Xor with itself. + (BinOp::Sub | BinOp::SubWithOverflow | BinOp::SubUnchecked | BinOp::BitXor, a, b) + if a == b => + { + self.insert_scalar(Scalar::from_uint(0u128, layout.size), lhs_ty) + } + // Comparison: + // - if both operands can be computed as bits, just compare the bits; + // - if we proved that both operands have the same value, we can insert true/false; + // - otherwise, do nothing, as we do not try to prove inequality. + (BinOp::Eq, Left(a), Left(b)) => self.insert_bool(a == b), + (BinOp::Eq, a, b) if a == b => self.insert_bool(true), + (BinOp::Ne, Left(a), Left(b)) => self.insert_bool(a != b), + (BinOp::Ne, a, b) if a == b => self.insert_bool(false), + _ => return None, + }; + + if op.is_overflowing() { + let false_val = self.insert_bool(false); + Some(self.insert_tuple(vec![result, false_val])) + } else { + Some(result) + } + } + + fn simplify_cast( + &mut self, + initial_kind: &mut CastKind, + initial_operand: &mut Operand<'tcx>, + to: Ty<'tcx>, + location: Location, + ) -> Option<VnIndex> { + use CastKind::*; + use rustc_middle::ty::adjustment::PointerCoercion::*; + + let mut from = initial_operand.ty(self.local_decls, self.tcx); + let mut kind = *initial_kind; + let mut value = self.simplify_operand(initial_operand, location)?; + if from == to { + return Some(value); + } + + if let CastKind::PointerCoercion(ReifyFnPointer | ClosureFnPointer(_), _) = kind { + // Each reification of a generic fn may get a different pointer. + // Do not try to merge them. + return Some(self.new_opaque()); + } + + let mut was_ever_updated = false; + loop { + let mut was_updated_this_iteration = false; + + // Transmuting between raw pointers is just a pointer cast so long as + // they have the same metadata type (like `*const i32` <=> `*mut u64` + // or `*mut [i32]` <=> `*const [u64]`), including the common special + // case of `*const T` <=> `*mut T`. + if let Transmute = kind + && from.is_raw_ptr() + && to.is_raw_ptr() + && self.pointers_have_same_metadata(from, to) + { + kind = PtrToPtr; + was_updated_this_iteration = true; + } + + // If a cast just casts away the metadata again, then we can get it by + // casting the original thin pointer passed to `from_raw_parts` + if let PtrToPtr = kind + && let Value::Aggregate(AggregateTy::RawPtr { data_pointer_ty, .. 
}, _, fields) = + self.get(value) + && let ty::RawPtr(to_pointee, _) = to.kind() + && to_pointee.is_sized(self.tcx, self.typing_env()) + { + from = *data_pointer_ty; + value = fields[0]; + was_updated_this_iteration = true; + if *data_pointer_ty == to { + return Some(fields[0]); + } + } + + // Aggregate-then-Transmute can just transmute the original field value, + // so long as the bytes of a value from only from a single field. + if let Transmute = kind + && let Value::Aggregate(_aggregate_ty, variant_idx, field_values) = self.get(value) + && let Some((field_idx, field_ty)) = + self.value_is_all_in_one_field(from, *variant_idx) + { + from = field_ty; + value = field_values[field_idx.as_usize()]; + was_updated_this_iteration = true; + if field_ty == to { + return Some(value); + } + } + + // Various cast-then-cast cases can be simplified. + if let Value::Cast { + kind: inner_kind, + value: inner_value, + from: inner_from, + to: inner_to, + } = *self.get(value) + { + let new_kind = match (inner_kind, kind) { + // Even if there's a narrowing cast in here that's fine, because + // things like `*mut [i32] -> *mut i32 -> *const i32` and + // `*mut [i32] -> *const [i32] -> *const i32` can skip the middle in MIR. + (PtrToPtr, PtrToPtr) => Some(PtrToPtr), + // PtrToPtr-then-Transmute is fine so long as the pointer cast is identity: + // `*const T -> *mut T -> NonNull<T>` is fine, but we need to check for narrowing + // to skip things like `*const [i32] -> *const i32 -> NonNull<T>`. + (PtrToPtr, Transmute) + if self.pointers_have_same_metadata(inner_from, inner_to) => + { + Some(Transmute) + } + // Similarly, for Transmute-then-PtrToPtr. Note that we need to check different + // variables for their metadata, and thus this can't merge with the previous arm. + (Transmute, PtrToPtr) if self.pointers_have_same_metadata(from, to) => { + Some(Transmute) + } + // If would be legal to always do this, but we don't want to hide information + // from the backend that it'd otherwise be able to use for optimizations. + (Transmute, Transmute) + if !self.type_may_have_niche_of_interest_to_backend(inner_to) => + { + Some(Transmute) + } + _ => None, + }; + if let Some(new_kind) = new_kind { + kind = new_kind; + from = inner_from; + value = inner_value; + was_updated_this_iteration = true; + if inner_from == to { + return Some(inner_value); + } + } + } + + if was_updated_this_iteration { + was_ever_updated = true; + } else { + break; + } + } + + if was_ever_updated && let Some(op) = self.try_as_operand(value, location) { + *initial_operand = op; + *initial_kind = kind; + } + + Some(self.insert(Value::Cast { kind, value, from, to })) + } + + fn simplify_len(&mut self, place: &mut Place<'tcx>, location: Location) -> Option<VnIndex> { + // Trivial case: we are fetching a statically known length. + let place_ty = place.ty(self.local_decls, self.tcx).ty; + if let ty::Array(_, len) = place_ty.kind() { + return Some(self.insert_constant(Const::Ty(self.tcx.types.usize, *len))); + } + + let mut inner = self.simplify_place_value(place, location)?; + + // The length information is stored in the wide pointer. + // Reborrowing copies length information from one pointer to the other. + while let Value::Address { place: borrowed, .. } = self.get(inner) + && let [PlaceElem::Deref] = borrowed.projection[..] + && let Some(borrowed) = self.locals[borrowed.local] + { + inner = borrowed; + } + + // We have an unsizing cast, which assigns the length to wide pointer metadata. + if let Value::Cast { kind, from, to, .. 
} = self.get(inner) + && let CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize, _) = kind + && let Some(from) = from.builtin_deref(true) + && let ty::Array(_, len) = from.kind() + && let Some(to) = to.builtin_deref(true) + && let ty::Slice(..) = to.kind() + { + return Some(self.insert_constant(Const::Ty(self.tcx.types.usize, *len))); + } + + // Fallback: a symbolic `Len`. + Some(self.insert(Value::Len(inner))) + } + + fn pointers_have_same_metadata(&self, left_ptr_ty: Ty<'tcx>, right_ptr_ty: Ty<'tcx>) -> bool { + let left_meta_ty = left_ptr_ty.pointee_metadata_ty_or_projection(self.tcx); + let right_meta_ty = right_ptr_ty.pointee_metadata_ty_or_projection(self.tcx); + if left_meta_ty == right_meta_ty { + true + } else if let Ok(left) = + self.tcx.try_normalize_erasing_regions(self.typing_env(), left_meta_ty) + && let Ok(right) = + self.tcx.try_normalize_erasing_regions(self.typing_env(), right_meta_ty) + { + left == right + } else { + false + } + } + + /// Returns `false` if we know for sure that this type has no interesting niche, + /// and thus we can skip transmuting through it without worrying. + /// + /// The backend will emit `assume`s when transmuting between types with niches, + /// so we want to preserve `i32 -> char -> u32` so that that data is around, + /// but it's fine to skip whole-range-is-value steps like `A -> u32 -> B`. + fn type_may_have_niche_of_interest_to_backend(&self, ty: Ty<'tcx>) -> bool { + let Ok(layout) = self.ecx.layout_of(ty) else { + // If it's too generic or something, then assume it might be interesting later. + return true; + }; + + if layout.uninhabited { + return true; + } + + match layout.backend_repr { + BackendRepr::Scalar(a) => !a.is_always_valid(&self.ecx), + BackendRepr::ScalarPair(a, b) => { + !a.is_always_valid(&self.ecx) || !b.is_always_valid(&self.ecx) + } + BackendRepr::SimdVector { .. } | BackendRepr::Memory { .. } => false, + } + } + + fn value_is_all_in_one_field( + &self, + ty: Ty<'tcx>, + variant: VariantIdx, + ) -> Option<(FieldIdx, Ty<'tcx>)> { + if let Ok(layout) = self.ecx.layout_of(ty) + && let abi::Variants::Single { index } = layout.variants + && index == variant + && let Some((field_idx, field_layout)) = layout.non_1zst_field(&self.ecx) + && layout.size == field_layout.size + { + // We needed to check the variant to avoid trying to read the tag + // field from an enum where no fields have variants, since that tag + // field isn't in the `Aggregate` from which we're getting values. + Some((field_idx, field_layout.ty)) + } else if let ty::Adt(adt, args) = ty.kind() + && adt.is_struct() + && adt.repr().transparent() + && let [single_field] = adt.non_enum_variant().fields.raw.as_slice() + { + Some((FieldIdx::ZERO, single_field.ty(self.tcx, args))) + } else { + None + } + } +} + +fn op_to_prop_const<'tcx>( + ecx: &mut InterpCx<'tcx, DummyMachine>, + op: &OpTy<'tcx>, +) -> Option<ConstValue<'tcx>> { + // Do not attempt to propagate unsized locals. + if op.layout.is_unsized() { + return None; + } + + // This constant is a ZST, just return an empty value. + if op.layout.is_zst() { + return Some(ConstValue::ZeroSized); + } + + // Do not synthetize too large constants. Codegen will just memcpy them, which we'd like to + // avoid. + if !matches!(op.layout.backend_repr, BackendRepr::Scalar(..) | BackendRepr::ScalarPair(..)) { + return None; + } + + // If this constant has scalar ABI, return it as a `ConstValue::Scalar`. + if let BackendRepr::Scalar(abi::Scalar::Initialized { .. 
}) = op.layout.backend_repr
+ && let Some(scalar) = ecx.read_scalar(op).discard_err()
+ {
+ if !scalar.try_to_scalar_int().is_ok() {
+ // Check that we do not leak a pointer.
+ // Those pointers may lose part of their identity in codegen.
+ // FIXME: remove this hack once https://github.com/rust-lang/rust/issues/79738 is fixed.
+ return None;
+ }
+ return Some(ConstValue::Scalar(scalar));
+ }
+
+ // If this constant is already represented as an `Allocation`,
+ // try putting it into global memory to return it.
+ if let Either::Left(mplace) = op.as_mplace_or_imm() {
+ let (size, _align) = ecx.size_and_align_of_val(&mplace).discard_err()??;
+
+ // Do not try interning a value that contains provenance.
+ // Due to https://github.com/rust-lang/rust/issues/79738, doing so could lead to bugs.
+ // FIXME: remove this hack once that issue is fixed.
+ let alloc_ref = ecx.get_ptr_alloc(mplace.ptr(), size).discard_err()??;
+ if alloc_ref.has_provenance() {
+ return None;
+ }
+
+ let pointer = mplace.ptr().into_pointer_or_addr().ok()?;
+ let (prov, offset) = pointer.into_parts();
+ let alloc_id = prov.alloc_id();
+ intern_const_alloc_for_constprop(ecx, alloc_id).discard_err()?;
+
+ // `alloc_id` may point to a static. Codegen will choke on an `Indirect` with anything
+ // but `GlobalAlloc::Memory`, so do fall through to copying if needed.
+ // FIXME: find a way to treat this more uniformly (probably by fixing codegen)
+ if let GlobalAlloc::Memory(alloc) = ecx.tcx.global_alloc(alloc_id)
+ // Transmuting a constant is just an offset in the allocation. If the alignment of the
+ // allocation is not enough, fall back to copying into a properly aligned value.
+ && alloc.inner().align >= op.layout.align.abi
+ {
+ return Some(ConstValue::Indirect { alloc_id, offset });
+ }
+ }
+
+ // Everything failed: create a new allocation to hold the data.
+ let alloc_id =
+ ecx.intern_with_temp_alloc(op.layout, |ecx, dest| ecx.copy_op(op, dest)).discard_err()?;
+ let value = ConstValue::Indirect { alloc_id, offset: Size::ZERO };
+
+ // Check that we do not leak a pointer.
+ // Those pointers may lose part of their identity in codegen.
+ // FIXME: remove this hack once https://github.com/rust-lang/rust/issues/79738 is fixed.
+ if ecx.tcx.global_alloc(alloc_id).unwrap_memory().inner().provenance().ptrs().is_empty() {
+ return Some(value);
+ }
+
+ None
+}
+
+impl<'tcx> VnState<'_, 'tcx> {
+ /// If either [`Self::try_as_constant`] or [`Self::try_as_place`] succeeds,
+ /// returns that result as an [`Operand`].
+ fn try_as_operand(&mut self, index: VnIndex, location: Location) -> Option<Operand<'tcx>> {
+ if let Some(const_) = self.try_as_constant(index) {
+ Some(Operand::Constant(Box::new(const_)))
+ } else if let Some(place) = self.try_as_place(index, location, false) {
+ self.reused_locals.insert(place.local);
+ Some(Operand::Copy(place))
+ } else {
+ None
+ }
+ }
+
+ /// If `index` is a `Value::Constant`, return the `Constant` to be put in the MIR.
+ fn try_as_constant(&mut self, index: VnIndex) -> Option<ConstOperand<'tcx>> {
+ // This was already constant in MIR, do not change it. If the constant is not
+ // deterministic, adding an additional mention of it in MIR will not give the same value as
+ // the former mention.
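// A simplified, hypothetical stand-in (not the real interpreter code) for the decision
// ladder in `op_to_prop_const` above: zero-sized values need no data, small scalar-like
// values are propagated as immediates, and oversized constants are refused because
// codegen would just memcpy them. The 8- and 16-byte cutoffs are made up here as a
// crude stand-in for the real scalar / scalar-pair layout checks.
#[derive(Debug, PartialEq)]
enum PropConst {
    ZeroSized,
    Scalar(u128),
    Indirect(Vec<u8>),
}

fn pick_representation(size_in_bytes: usize, bytes: &[u8]) -> Option<PropConst> {
    // Assumes `bytes.len() == size_in_bytes`.
    if size_in_bytes == 0 {
        return Some(PropConst::ZeroSized);
    }
    // Refuse to synthesize large constants.
    if size_in_bytes > 16 {
        return None;
    }
    if size_in_bytes <= 8 {
        // Small immediates are cheap to rematerialize at every use site.
        let mut buf = [0u8; 16];
        buf[..size_in_bytes].copy_from_slice(bytes);
        return Some(PropConst::Scalar(u128::from_le_bytes(buf)));
    }
    // Otherwise keep the value, but backed by memory rather than as an immediate.
    Some(PropConst::Indirect(bytes.to_vec()))
}

fn main() {
    assert_eq!(pick_representation(0, &[]), Some(PropConst::ZeroSized));
    assert_eq!(pick_representation(4, &7u32.to_le_bytes()), Some(PropConst::Scalar(7)));
    assert_eq!(pick_representation(64, &[0u8; 64]), None);
    assert!(matches!(pick_representation(12, &[1u8; 12]), Some(PropConst::Indirect(_))));
}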
+ if let Value::Constant { value, disambiguator: 0 } = *self.get(index) { + debug_assert!(value.is_deterministic()); + return Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_: value }); + } + + let op = self.evaluated[index].as_ref()?; + if op.layout.is_unsized() { + // Do not attempt to propagate unsized locals. + return None; + } + + let value = op_to_prop_const(&mut self.ecx, op)?; + + // Check that we do not leak a pointer. + // Those pointers may lose part of their identity in codegen. + // FIXME: remove this hack once https://github.com/rust-lang/rust/issues/79738 is fixed. + assert!(!value.may_have_provenance(self.tcx, op.layout.size)); + + let const_ = Const::Val(value, op.layout.ty); + Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_ }) + } + + /// Construct a place which holds the same value as `index` and for which all locals strictly + /// dominate `loc`. If you used this place, add its base local to `reused_locals` to remove + /// storage statements. + #[instrument(level = "trace", skip(self), ret)] + fn try_as_place( + &mut self, + mut index: VnIndex, + loc: Location, + allow_complex_projection: bool, + ) -> Option<Place<'tcx>> { + let mut projection = SmallVec::<[PlaceElem<'tcx>; 1]>::new(); + loop { + if let Some(local) = self.try_as_local(index, loc) { + projection.reverse(); + let place = + Place { local, projection: self.tcx.mk_place_elems(projection.as_slice()) }; + return Some(place); + } else if let Value::Projection(pointer, proj) = *self.get(index) + && (allow_complex_projection || proj.is_stable_offset()) + && let Some(proj) = self.try_as_place_elem(proj, loc) + { + projection.push(proj); + index = pointer; + } else { + return None; + } + } + } + + /// If there is a local which is assigned `index`, and its assignment strictly dominates `loc`, + /// return it. If you used this local, add it to `reused_locals` to remove storage statements. + fn try_as_local(&mut self, index: VnIndex, loc: Location) -> Option<Local> { + let other = self.rev_locals.get(index)?; + other + .iter() + .find(|&&other| self.ssa.assignment_dominates(&self.dominators, other, loc)) + .copied() + } +} + +impl<'tcx> MutVisitor<'tcx> for VnState<'_, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { + self.simplify_place_projection(place, location); + if context.is_mutating_use() && !place.projection.is_empty() { + // Non-local mutation maybe invalidate deref. + self.invalidate_derefs(); + } + self.super_place(place, context, location); + } + + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) { + self.simplify_operand(operand, location); + self.super_operand(operand, location); + } + + fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) { + if let StatementKind::Assign(box (ref mut lhs, ref mut rvalue)) = stmt.kind { + self.simplify_place_projection(lhs, location); + + let value = self.simplify_rvalue(lhs, rvalue, location); + let value = if let Some(local) = lhs.as_local() + && self.ssa.is_ssa(local) + // FIXME(#112651) `rvalue` may have a subtype to `local`. We can only mark + // `local` as reusable if we have an exact type match. 
+ && self.local_decls[local].ty == rvalue.ty(self.local_decls, self.tcx)
+ {
+ let value = value.unwrap_or_else(|| self.new_opaque());
+ self.assign(local, value);
+ Some(value)
+ } else {
+ value
+ };
+ if let Some(value) = value {
+ if let Some(const_) = self.try_as_constant(value) {
+ *rvalue = Rvalue::Use(Operand::Constant(Box::new(const_)));
+ } else if let Some(place) = self.try_as_place(value, location, false)
+ && *rvalue != Rvalue::Use(Operand::Move(place))
+ && *rvalue != Rvalue::Use(Operand::Copy(place))
+ {
+ *rvalue = Rvalue::Use(Operand::Copy(place));
+ self.reused_locals.insert(place.local);
+ }
+ }
+ }
+ self.super_statement(stmt, location);
+ }
+
+ fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
+ if let Terminator { kind: TerminatorKind::Call { destination, .. }, .. } = terminator {
+ if let Some(local) = destination.as_local()
+ && self.ssa.is_ssa(local)
+ {
+ let opaque = self.new_opaque();
+ self.assign(local, opaque);
+ }
+ }
+ // Function calls and ASM may invalidate (nested) derefs. We must handle them carefully.
+ // Currently, we only preserve derefs for trivial terminators like SwitchInt and Goto.
+ let safe_to_preserve_derefs = matches!(
+ terminator.kind,
+ TerminatorKind::SwitchInt { .. } | TerminatorKind::Goto { .. }
+ );
+ if !safe_to_preserve_derefs {
+ self.invalidate_derefs();
+ }
+ self.super_terminator(terminator, location);
+ }
+}
+
+struct StorageRemover<'tcx> {
+ tcx: TyCtxt<'tcx>,
+ reused_locals: DenseBitSet<Local>,
+}
+
+impl<'tcx> MutVisitor<'tcx> for StorageRemover<'tcx> {
+ fn tcx(&self) -> TyCtxt<'tcx> {
+ self.tcx
+ }
+
+ fn visit_operand(&mut self, operand: &mut Operand<'tcx>, _: Location) {
+ if let Operand::Move(place) = *operand
+ && !place.is_indirect_first_projection()
+ && self.reused_locals.contains(place.local)
+ {
+ *operand = Operand::Copy(place);
+ }
+ }
+
+ fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, loc: Location) {
+ match stmt.kind {
+ // When removing storage statements, we need to remove both (#107511).
+ StatementKind::StorageLive(l) | StatementKind::StorageDead(l)
+ if self.reused_locals.contains(l) =>
+ {
+ stmt.make_nop()
+ }
+ _ => self.super_statement(stmt, loc),
+ }
+ }
+}
diff --git a/compiler/rustc_mir_transform/src/impossible_predicates.rs b/compiler/rustc_mir_transform/src/impossible_predicates.rs
new file mode 100644
index 00000000000..86e2bf6cb3c
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/impossible_predicates.rs
@@ -0,0 +1,60 @@
+//! Check if it's even possible to satisfy the 'where' clauses
+//! for this item.
+//!
+//! It's possible to use `#![feature(trivial_bounds)]` to write
+//! a function with impossible-to-satisfy clauses, e.g.:
+//! `fn foo() where String: Copy {}`.
+//!
+//! We don't usually need to worry about this kind of case,
+//! since we would get a compilation error if the user tried
+//! to call it. However, since we optimize even without any
+//! calls to the function, we need to make sure that it even
+//! makes sense to try to evaluate the body.
+//!
+//! If there are unsatisfiable where clauses, then all bets are
+//! off, and we just give up.
+//!
+//! We manually filter the predicates, skipping anything that's not
+//! "global". We are in a potentially generic context
+//! (e.g. we are evaluating a function without instantiating generic
+//! parameters), so this filtering serves two purposes:
+//!
+//! 1. We skip evaluating any predicates that we would
+//! never be able to prove are unsatisfiable (e.g. `<T as Foo>`)
+//! 2.
We avoid trying to normalize predicates involving generic +//! parameters (e.g. `<T as Foo>::MyItem`). This can confuse +//! the normalization code (leading to cycle errors), since +//! it's usually never invoked in this way. + +use rustc_middle::mir::{Body, START_BLOCK, TerminatorKind}; +use rustc_middle::ty::{TyCtxt, TypeVisitableExt}; +use rustc_trait_selection::traits; +use tracing::trace; + +use crate::pass_manager::MirPass; + +pub(crate) struct ImpossiblePredicates; + +impl<'tcx> MirPass<'tcx> for ImpossiblePredicates { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let predicates = tcx + .predicates_of(body.source.def_id()) + .predicates + .iter() + .filter_map(|(p, _)| if p.is_global() { Some(*p) } else { None }); + if traits::impossible_predicates(tcx, traits::elaborate(tcx, predicates).collect()) { + trace!("found unsatisfiable predicates for {:?}", body.source); + // Clear the body to only contain a single `unreachable` statement. + let bbs = body.basic_blocks.as_mut(); + bbs.raw.truncate(1); + bbs[START_BLOCK].statements.clear(); + bbs[START_BLOCK].terminator_mut().kind = TerminatorKind::Unreachable; + body.var_debug_info.clear(); + body.local_decls.raw.truncate(body.arg_count + 1); + } + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs new file mode 100644 index 00000000000..f48dba9663a --- /dev/null +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -0,0 +1,1416 @@ +//! Inlining pass for MIR functions. + +use std::assert_matches::debug_assert_matches; +use std::iter; +use std::ops::{Range, RangeFrom}; + +use rustc_abi::{ExternAbi, FieldIdx}; +use rustc_attr_data_structures::{InlineAttr, OptimizeAttr}; +use rustc_hir::def::DefKind; +use rustc_hir::def_id::DefId; +use rustc_index::Idx; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::bug; +use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Instance, InstanceKind, Ty, TyCtxt, TypeFlags, TypeVisitableExt}; +use rustc_session::config::{DebugInfo, OptLevel}; +use rustc_span::source_map::Spanned; +use tracing::{debug, instrument, trace, trace_span}; + +use crate::cost_checker::{CostChecker, is_call_like}; +use crate::deref_separator::deref_finder; +use crate::simplify::simplify_cfg; +use crate::validate::validate_types; +use crate::{check_inline, util}; + +pub(crate) mod cycle; + +const HISTORY_DEPTH_LIMIT: usize = 20; +const TOP_DOWN_DEPTH_LIMIT: usize = 5; + +#[derive(Clone, Debug)] +struct CallSite<'tcx> { + callee: Instance<'tcx>, + fn_sig: ty::PolyFnSig<'tcx>, + block: BasicBlock, + source_info: SourceInfo, +} + +// Made public so that `mir_drops_elaborated_and_const_checked` can be overridden +// by custom rustc drivers, running all the steps by themselves. See #114628. 
+pub struct Inline; + +impl<'tcx> crate::MirPass<'tcx> for Inline { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + if let Some(enabled) = sess.opts.unstable_opts.inline_mir { + return enabled; + } + + match sess.mir_opt_level() { + 0 | 1 => false, + 2 => { + (sess.opts.optimize == OptLevel::More || sess.opts.optimize == OptLevel::Aggressive) + && sess.opts.incremental == None + } + _ => true, + } + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let span = trace_span!("inline", body = %tcx.def_path_str(body.source.def_id())); + let _guard = span.enter(); + if inline::<NormalInliner<'tcx>>(tcx, body) { + debug!("running simplify cfg on {:?}", body.source); + simplify_cfg(tcx, body); + deref_finder(tcx, body); + } + } + + fn is_required(&self) -> bool { + false + } +} + +pub struct ForceInline; + +impl ForceInline { + pub fn should_run_pass_for_callee<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> bool { + matches!(tcx.codegen_fn_attrs(def_id).inline, InlineAttr::Force { .. }) + } +} + +impl<'tcx> crate::MirPass<'tcx> for ForceInline { + fn is_enabled(&self, _: &rustc_session::Session) -> bool { + true + } + + fn can_be_overridden(&self) -> bool { + false + } + + fn is_required(&self) -> bool { + true + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let span = trace_span!("force_inline", body = %tcx.def_path_str(body.source.def_id())); + let _guard = span.enter(); + if inline::<ForceInliner<'tcx>>(tcx, body) { + debug!("running simplify cfg on {:?}", body.source); + simplify_cfg(tcx, body); + deref_finder(tcx, body); + } + } +} + +trait Inliner<'tcx> { + fn new(tcx: TyCtxt<'tcx>, def_id: DefId, body: &Body<'tcx>) -> Self; + + fn tcx(&self) -> TyCtxt<'tcx>; + fn typing_env(&self) -> ty::TypingEnv<'tcx>; + fn history(&self) -> &[DefId]; + fn caller_def_id(&self) -> DefId; + + /// Has the caller body been changed? + fn changed(self) -> bool; + + /// Should inlining happen for a given callee? + fn should_inline_for_callee(&self, def_id: DefId) -> bool; + + fn check_codegen_attributes_extra( + &self, + callee_attrs: &CodegenFnAttrs, + ) -> Result<(), &'static str>; + + fn check_caller_mir_body(&self, body: &Body<'tcx>) -> bool; + + /// Returns inlining decision that is based on the examination of callee MIR body. + /// Assumes that codegen attributes have been checked for compatibility already. + fn check_callee_mir_body( + &self, + callsite: &CallSite<'tcx>, + callee_body: &Body<'tcx>, + callee_attrs: &CodegenFnAttrs, + ) -> Result<(), &'static str>; + + /// Called when inlining succeeds. + fn on_inline_success( + &mut self, + callsite: &CallSite<'tcx>, + caller_body: &mut Body<'tcx>, + new_blocks: std::ops::Range<BasicBlock>, + ); + + /// Called when inlining failed or was not performed. + fn on_inline_failure(&self, callsite: &CallSite<'tcx>, reason: &'static str); +} + +struct ForceInliner<'tcx> { + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + /// `DefId` of caller. + def_id: DefId, + /// Stack of inlined instances. + /// We only check the `DefId` and not the args because we want to + /// avoid inlining cases of polymorphic recursion. + /// The number of `DefId`s is finite, so checking history is enough + /// to ensure that we do not loop endlessly while inlining. + history: Vec<DefId>, + /// Indicates that the caller body has been modified. 
+ changed: bool, +} + +impl<'tcx> Inliner<'tcx> for ForceInliner<'tcx> { + fn new(tcx: TyCtxt<'tcx>, def_id: DefId, body: &Body<'tcx>) -> Self { + Self { tcx, typing_env: body.typing_env(tcx), def_id, history: Vec::new(), changed: false } + } + + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn typing_env(&self) -> ty::TypingEnv<'tcx> { + self.typing_env + } + + fn history(&self) -> &[DefId] { + &self.history + } + + fn caller_def_id(&self) -> DefId { + self.def_id + } + + fn changed(self) -> bool { + self.changed + } + + fn should_inline_for_callee(&self, def_id: DefId) -> bool { + ForceInline::should_run_pass_for_callee(self.tcx(), def_id) + } + + fn check_codegen_attributes_extra( + &self, + callee_attrs: &CodegenFnAttrs, + ) -> Result<(), &'static str> { + debug_assert_matches!(callee_attrs.inline, InlineAttr::Force { .. }); + Ok(()) + } + + fn check_caller_mir_body(&self, _: &Body<'tcx>) -> bool { + true + } + + #[instrument(level = "debug", skip(self, callee_body))] + fn check_callee_mir_body( + &self, + _: &CallSite<'tcx>, + callee_body: &Body<'tcx>, + callee_attrs: &CodegenFnAttrs, + ) -> Result<(), &'static str> { + if callee_body.tainted_by_errors.is_some() { + return Err("body has errors"); + } + + let caller_attrs = self.tcx().codegen_fn_attrs(self.caller_def_id()); + if callee_attrs.instruction_set != caller_attrs.instruction_set + && callee_body + .basic_blocks + .iter() + .any(|bb| matches!(bb.terminator().kind, TerminatorKind::InlineAsm { .. })) + { + // During the attribute checking stage we allow a callee with no + // instruction_set assigned to count as compatible with a function that does + // assign one. However, during this stage we require an exact match when any + // inline-asm is detected. LLVM will still possibly do an inline later on + // if the no-attribute function ends up with the same instruction set anyway. + Err("cannot move inline-asm across instruction sets") + } else { + Ok(()) + } + } + + fn on_inline_success( + &mut self, + callsite: &CallSite<'tcx>, + caller_body: &mut Body<'tcx>, + new_blocks: std::ops::Range<BasicBlock>, + ) { + self.changed = true; + + self.history.push(callsite.callee.def_id()); + process_blocks(self, caller_body, new_blocks); + self.history.pop(); + } + + fn on_inline_failure(&self, callsite: &CallSite<'tcx>, reason: &'static str) { + let tcx = self.tcx(); + let InlineAttr::Force { attr_span, reason: justification } = + tcx.codegen_fn_attrs(callsite.callee.def_id()).inline + else { + bug!("called on item without required inlining"); + }; + + let call_span = callsite.source_info.span; + tcx.dcx().emit_err(crate::errors::ForceInlineFailure { + call_span, + attr_span, + caller_span: tcx.def_span(self.def_id), + caller: tcx.def_path_str(self.def_id), + callee_span: tcx.def_span(callsite.callee.def_id()), + callee: tcx.def_path_str(callsite.callee.def_id()), + reason, + justification: justification.map(|sym| crate::errors::ForceInlineJustification { sym }), + }); + } +} + +struct NormalInliner<'tcx> { + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + /// `DefId` of caller. + def_id: DefId, + /// Stack of inlined instances. + /// We only check the `DefId` and not the args because we want to + /// avoid inlining cases of polymorphic recursion. + /// The number of `DefId`s is finite, so checking history is enough + /// to ensure that we do not loop endlessly while inlining. + history: Vec<DefId>, + /// How many (multi-call) callsites have we inlined for the top-level call? 
+ /// + /// We need to limit this in order to prevent super-linear growth in MIR size. + top_down_counter: usize, + /// Indicates that the caller body has been modified. + changed: bool, + /// Indicates that the caller is #[inline] and just calls another function, + /// and thus we can inline less into it as it'll be inlined itself. + caller_is_inline_forwarder: bool, +} + +impl<'tcx> NormalInliner<'tcx> { + fn past_depth_limit(&self) -> bool { + self.history.len() > HISTORY_DEPTH_LIMIT || self.top_down_counter > TOP_DOWN_DEPTH_LIMIT + } +} + +impl<'tcx> Inliner<'tcx> for NormalInliner<'tcx> { + fn new(tcx: TyCtxt<'tcx>, def_id: DefId, body: &Body<'tcx>) -> Self { + let typing_env = body.typing_env(tcx); + let codegen_fn_attrs = tcx.codegen_fn_attrs(def_id); + + Self { + tcx, + typing_env, + def_id, + history: Vec::new(), + top_down_counter: 0, + changed: false, + caller_is_inline_forwarder: matches!( + codegen_fn_attrs.inline, + InlineAttr::Hint | InlineAttr::Always | InlineAttr::Force { .. } + ) && body_is_forwarder(body), + } + } + + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn caller_def_id(&self) -> DefId { + self.def_id + } + + fn typing_env(&self) -> ty::TypingEnv<'tcx> { + self.typing_env + } + + fn history(&self) -> &[DefId] { + &self.history + } + + fn changed(self) -> bool { + self.changed + } + + fn should_inline_for_callee(&self, _: DefId) -> bool { + true + } + + fn check_codegen_attributes_extra( + &self, + callee_attrs: &CodegenFnAttrs, + ) -> Result<(), &'static str> { + if self.past_depth_limit() && matches!(callee_attrs.inline, InlineAttr::None) { + Err("Past depth limit so not inspecting unmarked callee") + } else { + Ok(()) + } + } + + fn check_caller_mir_body(&self, body: &Body<'tcx>) -> bool { + // Avoid inlining into coroutines, since their `optimized_mir` is used for layout computation, + // which can create a cycle, even when no attempt is made to inline the function in the other + // direction. + if body.coroutine.is_some() { + return false; + } + + true + } + + #[instrument(level = "debug", skip(self, callee_body))] + fn check_callee_mir_body( + &self, + callsite: &CallSite<'tcx>, + callee_body: &Body<'tcx>, + callee_attrs: &CodegenFnAttrs, + ) -> Result<(), &'static str> { + let tcx = self.tcx(); + + if let Some(_) = callee_body.tainted_by_errors { + return Err("body has errors"); + } + + if self.past_depth_limit() && callee_body.basic_blocks.len() > 1 { + return Err("Not inlining multi-block body as we're past a depth limit"); + } + + let mut threshold = if self.caller_is_inline_forwarder || self.past_depth_limit() { + tcx.sess.opts.unstable_opts.inline_mir_forwarder_threshold.unwrap_or(30) + } else if tcx.cross_crate_inlinable(callsite.callee.def_id()) { + tcx.sess.opts.unstable_opts.inline_mir_hint_threshold.unwrap_or(100) + } else { + tcx.sess.opts.unstable_opts.inline_mir_threshold.unwrap_or(50) + }; + + // Give a bonus functions with a small number of blocks, + // We normally have two or three blocks for even + // very small functions. + if callee_body.basic_blocks.len() <= 3 { + threshold += threshold / 4; + } + debug!(" final inline threshold = {}", threshold); + + // FIXME: Give a bonus to functions with only a single caller + + let mut checker = + CostChecker::new(tcx, self.typing_env(), Some(callsite.callee), callee_body); + + checker.add_function_level_costs(); + + // Traverse the MIR manually so we can account for the effects of inlining on the CFG. 
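// A hypothetical, simplified model (plain functions, not the real pass state) of the
// threshold selection in `check_callee_mir_body` just above. The fallback numbers
// (30 / 100 / 50) and the bonus for bodies with at most three blocks come from the
// code above; the function shape and names are made up for illustration.
fn inline_threshold(
    forwarder_or_past_depth: bool,
    cross_crate_inlinable: bool,
    num_blocks: usize,
) -> usize {
    let mut threshold = if forwarder_or_past_depth {
        30
    } else if cross_crate_inlinable {
        100
    } else {
        50
    };
    // Small bodies get a 25% bonus: even very small functions usually have 2-3 blocks.
    if num_blocks <= 3 {
        threshold += threshold / 4;
    }
    threshold
}

fn main() {
    assert_eq!(inline_threshold(false, false, 2), 62);
    assert_eq!(inline_threshold(false, true, 10), 100);
    assert_eq!(inline_threshold(true, false, 1), 37);
}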
+ let mut work_list = vec![START_BLOCK]; + let mut visited = DenseBitSet::new_empty(callee_body.basic_blocks.len()); + while let Some(bb) = work_list.pop() { + if !visited.insert(bb.index()) { + continue; + } + + let blk = &callee_body.basic_blocks[bb]; + checker.visit_basic_block_data(bb, blk); + + let term = blk.terminator(); + let caller_attrs = tcx.codegen_fn_attrs(self.caller_def_id()); + if let TerminatorKind::Drop { + ref place, + target, + unwind, + replace: _, + drop: _, + async_fut: _, + } = term.kind + { + work_list.push(target); + + // If the place doesn't actually need dropping, treat it like a regular goto. + let ty = callsite + .callee + .instantiate_mir(tcx, ty::EarlyBinder::bind(&place.ty(callee_body, tcx).ty)); + if ty.needs_drop(tcx, self.typing_env()) + && let UnwindAction::Cleanup(unwind) = unwind + { + work_list.push(unwind); + } + } else if callee_attrs.instruction_set != caller_attrs.instruction_set + && matches!(term.kind, TerminatorKind::InlineAsm { .. }) + { + // During the attribute checking stage we allow a callee with no + // instruction_set assigned to count as compatible with a function that does + // assign one. However, during this stage we require an exact match when any + // inline-asm is detected. LLVM will still possibly do an inline later on + // if the no-attribute function ends up with the same instruction set anyway. + return Err("cannot move inline-asm across instruction sets"); + } else if let TerminatorKind::TailCall { .. } = term.kind { + // FIXME(explicit_tail_calls): figure out how exactly functions containing tail + // calls can be inlined (and if they even should) + return Err("can't inline functions with tail calls"); + } else { + work_list.extend(term.successors()) + } + } + + // N.B. We still apply our cost threshold to #[inline(always)] functions. + // That attribute is often applied to very large functions that exceed LLVM's (very + // generous) inlining threshold. Such functions are very poor MIR inlining candidates. + // Always inlining #[inline(always)] functions in MIR, on net, slows down the compiler. + let cost = checker.cost(); + if cost <= threshold { + debug!("INLINING {:?} [cost={} <= threshold={}]", callsite, cost, threshold); + Ok(()) + } else { + debug!("NOT inlining {:?} [cost={} > threshold={}]", callsite, cost, threshold); + Err("cost above threshold") + } + } + + fn on_inline_success( + &mut self, + callsite: &CallSite<'tcx>, + caller_body: &mut Body<'tcx>, + new_blocks: std::ops::Range<BasicBlock>, + ) { + self.changed = true; + + let new_calls_count = new_blocks + .clone() + .filter(|&bb| is_call_like(caller_body.basic_blocks[bb].terminator())) + .count(); + if new_calls_count > 1 { + self.top_down_counter += 1; + } + + self.history.push(callsite.callee.def_id()); + process_blocks(self, caller_body, new_blocks); + self.history.pop(); + + if self.history.is_empty() { + self.top_down_counter = 0; + } + } + + fn on_inline_failure(&self, _: &CallSite<'tcx>, _: &'static str) {} +} + +fn inline<'tcx, T: Inliner<'tcx>>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { + let def_id = body.source.def_id(); + + // Only do inlining into fn bodies. 
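// A self-contained, hypothetical sketch (not the real pass) of the work-list traversal
// used above to walk the callee CFG: pop a block, skip it if already seen, account for
// it, then push its successors. The real loop additionally special-cases `Drop` unwind
// edges, inline-asm and tail calls; this shows only the basic pattern.
fn reachable_blocks(successors: &[Vec<usize>], start: usize) -> Vec<usize> {
    let mut visited = vec![false; successors.len()];
    let mut work_list = vec![start];
    let mut order = Vec::new();
    while let Some(bb) = work_list.pop() {
        if std::mem::replace(&mut visited[bb], true) {
            continue; // already visited
        }
        order.push(bb);
        work_list.extend(successors[bb].iter().copied());
    }
    order
}

fn main() {
    // 0 -> {1, 2}, 1 -> {3}, 2 -> {3}, 3 -> {}, 4 is unreachable.
    let cfg = vec![vec![1, 2], vec![3], vec![3], vec![], vec![]];
    let order = reachable_blocks(&cfg, 0);
    assert!(order.contains(&3) && !order.contains(&4));
}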
+ if !tcx.hir_body_owner_kind(def_id).is_fn_or_closure() { + return false; + } + + let mut inliner = T::new(tcx, def_id, body); + if !inliner.check_caller_mir_body(body) { + return false; + } + + let blocks = START_BLOCK..body.basic_blocks.next_index(); + process_blocks(&mut inliner, body, blocks); + inliner.changed() +} + +fn process_blocks<'tcx, I: Inliner<'tcx>>( + inliner: &mut I, + caller_body: &mut Body<'tcx>, + blocks: Range<BasicBlock>, +) { + for bb in blocks { + let bb_data = &caller_body[bb]; + if bb_data.is_cleanup { + continue; + } + + let Some(callsite) = resolve_callsite(inliner, caller_body, bb, bb_data) else { + continue; + }; + + let span = trace_span!("process_blocks", %callsite.callee, ?bb); + let _guard = span.enter(); + + match try_inlining(inliner, caller_body, &callsite) { + Err(reason) => { + debug!("not-inlined {} [{}]", callsite.callee, reason); + inliner.on_inline_failure(&callsite, reason); + } + Ok(new_blocks) => { + debug!("inlined {}", callsite.callee); + inliner.on_inline_success(&callsite, caller_body, new_blocks); + } + } + } +} + +fn resolve_callsite<'tcx, I: Inliner<'tcx>>( + inliner: &I, + caller_body: &Body<'tcx>, + bb: BasicBlock, + bb_data: &BasicBlockData<'tcx>, +) -> Option<CallSite<'tcx>> { + let tcx = inliner.tcx(); + // Only consider direct calls to functions + let terminator = bb_data.terminator(); + + // FIXME(explicit_tail_calls): figure out if we can inline tail calls + if let TerminatorKind::Call { ref func, fn_span, .. } = terminator.kind { + let func_ty = func.ty(caller_body, tcx); + if let ty::FnDef(def_id, args) = *func_ty.kind() { + if !inliner.should_inline_for_callee(def_id) { + debug!("not enabled"); + return None; + } + + // To resolve an instance its args have to be fully normalized. + let args = tcx.try_normalize_erasing_regions(inliner.typing_env(), args).ok()?; + let callee = + Instance::try_resolve(tcx, inliner.typing_env(), def_id, args).ok().flatten()?; + + if let InstanceKind::Virtual(..) | InstanceKind::Intrinsic(_) = callee.def { + return None; + } + + if inliner.history().contains(&callee.def_id()) { + return None; + } + + let fn_sig = tcx.fn_sig(def_id).instantiate(tcx, args); + + // Additionally, check that the body that we're inlining actually agrees + // with the ABI of the trait that the item comes from. + if let InstanceKind::Item(instance_def_id) = callee.def + && tcx.def_kind(instance_def_id) == DefKind::AssocFn + && let instance_fn_sig = tcx.fn_sig(instance_def_id).skip_binder() + && instance_fn_sig.abi() != fn_sig.abi() + { + return None; + } + + let source_info = SourceInfo { span: fn_span, ..terminator.source_info }; + + return Some(CallSite { callee, fn_sig, block: bb, source_info }); + } + } + + None +} + +/// Attempts to inline a callsite into the caller body. When successful returns basic blocks +/// containing the inlined body. Otherwise returns an error describing why inlining didn't take +/// place. 
+fn try_inlining<'tcx, I: Inliner<'tcx>>( + inliner: &I, + caller_body: &mut Body<'tcx>, + callsite: &CallSite<'tcx>, +) -> Result<std::ops::Range<BasicBlock>, &'static str> { + let tcx = inliner.tcx(); + check_mir_is_available(inliner, caller_body, callsite.callee)?; + + let callee_attrs = tcx.codegen_fn_attrs(callsite.callee.def_id()); + check_inline::is_inline_valid_on_fn(tcx, callsite.callee.def_id())?; + check_codegen_attributes(inliner, callsite, callee_attrs)?; + inliner.check_codegen_attributes_extra(callee_attrs)?; + + let terminator = caller_body[callsite.block].terminator.as_ref().unwrap(); + let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() }; + let destination_ty = destination.ty(&caller_body.local_decls, tcx).ty; + for arg in args { + if !arg.node.ty(&caller_body.local_decls, tcx).is_sized(tcx, inliner.typing_env()) { + // We do not allow inlining functions with unsized params. Inlining these functions + // could create unsized locals, which are unsound and being phased out. + return Err("call has unsized argument"); + } + } + + let callee_body = try_instance_mir(tcx, callsite.callee.def)?; + check_inline::is_inline_valid_on_body(tcx, callee_body)?; + inliner.check_callee_mir_body(callsite, callee_body, callee_attrs)?; + + let Ok(callee_body) = callsite.callee.try_instantiate_mir_and_normalize_erasing_regions( + tcx, + inliner.typing_env(), + ty::EarlyBinder::bind(callee_body.clone()), + ) else { + debug!("failed to normalize callee body"); + return Err("implementation limitation -- could not normalize callee body"); + }; + + // Normally, this shouldn't be required, but trait normalization failure can create a + // validation ICE. + if !validate_types(tcx, inliner.typing_env(), &callee_body, &caller_body).is_empty() { + debug!("failed to validate callee body"); + return Err("implementation limitation -- callee body failed validation"); + } + + // Check call signature compatibility. + // Normally, this shouldn't be required, but trait normalization failure can create a + // validation ICE. + let output_type = callee_body.return_ty(); + if !util::sub_types(tcx, inliner.typing_env(), output_type, destination_ty) { + trace!(?output_type, ?destination_ty); + return Err("implementation limitation -- return type mismatch"); + } + if callsite.fn_sig.abi() == ExternAbi::RustCall { + let (self_arg, arg_tuple) = match &args[..] 
{ + [arg_tuple] => (None, arg_tuple), + [self_arg, arg_tuple] => (Some(self_arg), arg_tuple), + _ => bug!("Expected `rust-call` to have 1 or 2 args"), + }; + + let self_arg_ty = self_arg.map(|self_arg| self_arg.node.ty(&caller_body.local_decls, tcx)); + + let arg_tuple_ty = arg_tuple.node.ty(&caller_body.local_decls, tcx); + let arg_tys = if callee_body.spread_arg.is_some() { + std::slice::from_ref(&arg_tuple_ty) + } else { + let ty::Tuple(arg_tuple_tys) = *arg_tuple_ty.kind() else { + bug!("Closure arguments are not passed as a tuple"); + }; + arg_tuple_tys.as_slice() + }; + + for (arg_ty, input) in + self_arg_ty.into_iter().chain(arg_tys.iter().copied()).zip(callee_body.args_iter()) + { + let input_type = callee_body.local_decls[input].ty; + if !util::sub_types(tcx, inliner.typing_env(), input_type, arg_ty) { + trace!(?arg_ty, ?input_type); + debug!("failed to normalize tuple argument type"); + return Err("implementation limitation"); + } + } + } else { + for (arg, input) in args.iter().zip(callee_body.args_iter()) { + let input_type = callee_body.local_decls[input].ty; + let arg_ty = arg.node.ty(&caller_body.local_decls, tcx); + if !util::sub_types(tcx, inliner.typing_env(), input_type, arg_ty) { + trace!(?arg_ty, ?input_type); + debug!("failed to normalize argument type"); + return Err("implementation limitation -- arg mismatch"); + } + } + } + + let old_blocks = caller_body.basic_blocks.next_index(); + inline_call(inliner, caller_body, callsite, callee_body); + let new_blocks = old_blocks..caller_body.basic_blocks.next_index(); + + Ok(new_blocks) +} + +fn check_mir_is_available<'tcx, I: Inliner<'tcx>>( + inliner: &I, + caller_body: &Body<'tcx>, + callee: Instance<'tcx>, +) -> Result<(), &'static str> { + let caller_def_id = caller_body.source.def_id(); + let callee_def_id = callee.def_id(); + if callee_def_id == caller_def_id { + return Err("self-recursion"); + } + + match callee.def { + InstanceKind::Item(_) => { + // If there is no MIR available (either because it was not in metadata or + // because it has no MIR because it's an extern function), then the inliner + // won't cause cycles on this. + if !inliner.tcx().is_mir_available(callee_def_id) { + debug!("item MIR unavailable"); + return Err("implementation limitation -- MIR unavailable"); + } + } + // These have no own callable MIR. + InstanceKind::Intrinsic(_) | InstanceKind::Virtual(..) => { + debug!("instance without MIR (intrinsic / virtual)"); + return Err("implementation limitation -- cannot inline intrinsic"); + } + + // FIXME(#127030): `ConstParamHasTy` has bad interactions with + // the drop shim builder, which does not evaluate predicates in + // the correct param-env for types being dropped. Stall resolving + // the MIR for this instance until all of its const params are + // substituted. + InstanceKind::DropGlue(_, Some(ty)) if ty.has_type_flags(TypeFlags::HAS_CT_PARAM) => { + debug!("still needs substitution"); + return Err("implementation limitation -- HACK for dropping polymorphic type"); + } + InstanceKind::AsyncDropGlue(_, ty) | InstanceKind::AsyncDropGlueCtorShim(_, ty) => { + return if ty.still_further_specializable() { + Err("still needs substitution") + } else { + Ok(()) + }; + } + InstanceKind::FutureDropPollShim(_, ty, ty2) => { + return if ty.still_further_specializable() || ty2.still_further_specializable() { + Err("still needs substitution") + } else { + Ok(()) + }; + } + + // This cannot result in an immediate cycle since the callee MIR is a shim, which does + // not get any optimizations run on it. 
Any subsequent inlining may cause cycles, but we + // do not need to catch this here, we can wait until the inliner decides to continue + // inlining a second time. + InstanceKind::VTableShim(_) + | InstanceKind::ReifyShim(..) + | InstanceKind::FnPtrShim(..) + | InstanceKind::ClosureOnceShim { .. } + | InstanceKind::ConstructCoroutineInClosureShim { .. } + | InstanceKind::DropGlue(..) + | InstanceKind::CloneShim(..) + | InstanceKind::ThreadLocalShim(..) + | InstanceKind::FnPtrAddrShim(..) => return Ok(()), + } + + if inliner.tcx().is_constructor(callee_def_id) { + trace!("constructors always have MIR"); + // Constructor functions cannot cause a query cycle. + return Ok(()); + } + + if callee_def_id.is_local() + && !inliner + .tcx() + .is_lang_item(inliner.tcx().parent(caller_def_id), rustc_hir::LangItem::FnOnce) + { + // If we know for sure that the function we're calling will itself try to + // call us, then we avoid inlining that function. + if inliner.tcx().mir_callgraph_reachable((callee, caller_def_id.expect_local())) { + debug!("query cycle avoidance"); + return Err("caller might be reachable from callee"); + } + + Ok(()) + } else { + // This cannot result in an immediate cycle since the callee MIR is from another crate + // and is already optimized. Any subsequent inlining may cause cycles, but we do + // not need to catch this here, we can wait until the inliner decides to continue + // inlining a second time. + trace!("functions from other crates always have MIR"); + Ok(()) + } +} + +/// Returns an error if inlining is not possible based on codegen attributes alone. A success +/// indicates that inlining decision should be based on other criteria. +fn check_codegen_attributes<'tcx, I: Inliner<'tcx>>( + inliner: &I, + callsite: &CallSite<'tcx>, + callee_attrs: &CodegenFnAttrs, +) -> Result<(), &'static str> { + let tcx = inliner.tcx(); + if let InlineAttr::Never = callee_attrs.inline { + return Err("never inline attribute"); + } + + if let OptimizeAttr::DoNotOptimize = callee_attrs.optimize { + return Err("has DoNotOptimize attribute"); + } + + inliner.check_codegen_attributes_extra(callee_attrs)?; + + // Reachability pass defines which functions are eligible for inlining. Generally inlining + // other functions is incorrect because they could reference symbols that aren't exported. + let is_generic = callsite.callee.args.non_erasable_generics().next().is_some(); + if !is_generic && !tcx.cross_crate_inlinable(callsite.callee.def_id()) { + return Err("not exported"); + } + + let codegen_fn_attrs = tcx.codegen_fn_attrs(inliner.caller_def_id()); + if callee_attrs.no_sanitize != codegen_fn_attrs.no_sanitize { + return Err("incompatible sanitizer set"); + } + + // Two functions are compatible if the callee has no attribute (meaning + // that it's codegen agnostic), or sets an attribute that is identical + // to this function's attribute. + if callee_attrs.instruction_set.is_some() + && callee_attrs.instruction_set != codegen_fn_attrs.instruction_set + { + return Err("incompatible instruction set"); + } + + let callee_feature_names = callee_attrs.target_features.iter().map(|f| f.name); + let this_feature_names = codegen_fn_attrs.target_features.iter().map(|f| f.name); + if callee_feature_names.ne(this_feature_names) { + // In general it is not correct to inline a callee with target features that are a + // subset of the caller. This is because the callee might contain calls, and the ABI of + // those calls depends on the target features of the surrounding function. 
By moving a + // `Call` terminator from one MIR body to another with more target features, we might + // change the ABI of that call! + return Err("incompatible target features"); + } + + Ok(()) +} + +fn inline_call<'tcx, I: Inliner<'tcx>>( + inliner: &I, + caller_body: &mut Body<'tcx>, + callsite: &CallSite<'tcx>, + mut callee_body: Body<'tcx>, +) { + let tcx = inliner.tcx(); + let terminator = caller_body[callsite.block].terminator.take().unwrap(); + let TerminatorKind::Call { func, args, destination, unwind, target, .. } = terminator.kind + else { + bug!("unexpected terminator kind {:?}", terminator.kind); + }; + + let return_block = if let Some(block) = target { + // Prepare a new block for code that should execute when call returns. We don't use + // target block directly since it might have other predecessors. + let data = BasicBlockData::new( + Some(Terminator { + source_info: terminator.source_info, + kind: TerminatorKind::Goto { target: block }, + }), + caller_body[block].is_cleanup, + ); + Some(caller_body.basic_blocks_mut().push(data)) + } else { + None + }; + + // If the call is something like `a[*i] = f(i)`, where + // `i : &mut usize`, then just duplicating the `a[*i]` + // Place could result in two different locations if `f` + // writes to `i`. To prevent this we need to create a temporary + // borrow of the place and pass the destination as `*temp` instead. + fn dest_needs_borrow(place: Place<'_>) -> bool { + for elem in place.projection.iter() { + match elem { + ProjectionElem::Deref | ProjectionElem::Index(_) => return true, + _ => {} + } + } + + false + } + + let dest = if dest_needs_borrow(destination) { + trace!("creating temp for return destination"); + let dest = Rvalue::Ref( + tcx.lifetimes.re_erased, + BorrowKind::Mut { kind: MutBorrowKind::Default }, + destination, + ); + let dest_ty = dest.ty(caller_body, tcx); + let temp = Place::from(new_call_temp(caller_body, callsite, dest_ty, return_block)); + caller_body[callsite.block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::Assign(Box::new((temp, dest))), + }); + tcx.mk_place_deref(temp) + } else { + destination + }; + + // Always create a local to hold the destination, as `RETURN_PLACE` may appear + // where a full `Place` is not allowed. + let (remap_destination, destination_local) = if let Some(d) = dest.as_local() { + (false, d) + } else { + ( + true, + new_call_temp(caller_body, callsite, destination.ty(caller_body, tcx).ty, return_block), + ) + }; + + // Copy the arguments if needed. + let args = make_call_args(inliner, args, callsite, caller_body, &callee_body, return_block); + + let mut integrator = Integrator { + args: &args, + new_locals: caller_body.local_decls.next_index().., + new_scopes: caller_body.source_scopes.next_index().., + new_blocks: caller_body.basic_blocks.next_index().., + destination: destination_local, + callsite_scope: caller_body.source_scopes[callsite.source_info.scope].clone(), + callsite, + cleanup_block: unwind, + in_cleanup_block: false, + return_block, + tcx, + always_live_locals: DenseBitSet::new_filled(callee_body.local_decls.len()), + }; + + // Map all `Local`s, `SourceScope`s and `BasicBlock`s to new ones + // (or existing ones, in a few special cases) in the caller. + integrator.visit_body(&mut callee_body); + + // If there are any locals without storage markers, give them storage only for the + // duration of the call. 
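    // (`always_live_locals` starts out full; `Integrator::visit_statement` removed every local
    // that already has explicit `StorageLive`/`StorageDead` markers, so only the remaining
    // locals get a `StorageLive` here, and, when there is a return block, a matching
    // `StorageDead` there.)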
+ for local in callee_body.vars_and_temps_iter() { + if integrator.always_live_locals.contains(local) { + let new_local = integrator.map_local(local); + caller_body[callsite.block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::StorageLive(new_local), + }); + } + } + if let Some(block) = return_block { + // To avoid repeated O(n) insert, push any new statements to the end and rotate + // the slice once. + let mut n = 0; + if remap_destination { + caller_body[block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::Assign(Box::new(( + dest, + Rvalue::Use(Operand::Move(destination_local.into())), + ))), + }); + n += 1; + } + for local in callee_body.vars_and_temps_iter().rev() { + if integrator.always_live_locals.contains(local) { + let new_local = integrator.map_local(local); + caller_body[block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::StorageDead(new_local), + }); + n += 1; + } + } + caller_body[block].statements.rotate_right(n); + } + + // Insert all of the (mapped) parts of the callee body into the caller. + caller_body.local_decls.extend(callee_body.drain_vars_and_temps()); + caller_body.source_scopes.append(&mut callee_body.source_scopes); + if tcx + .sess + .opts + .unstable_opts + .inline_mir_preserve_debug + .unwrap_or(tcx.sess.opts.debuginfo != DebugInfo::None) + { + // Note that we need to preserve these in the standard library so that + // people working on rust can build with or without debuginfo while + // still getting consistent results from the mir-opt tests. + caller_body.var_debug_info.append(&mut callee_body.var_debug_info); + } + caller_body.basic_blocks_mut().append(callee_body.basic_blocks_mut()); + + caller_body[callsite.block].terminator = Some(Terminator { + source_info: callsite.source_info, + kind: TerminatorKind::Goto { target: integrator.map_block(START_BLOCK) }, + }); + + // Copy required constants from the callee_body into the caller_body. Although we are only + // pushing unevaluated consts to `required_consts`, here they may have been evaluated + // because we are calling `instantiate_and_normalize_erasing_regions` -- so we filter again. + caller_body.required_consts.as_mut().unwrap().extend( + callee_body.required_consts().into_iter().filter(|ct| ct.const_.is_required_const()), + ); + // Now that we incorporated the callee's `required_consts`, we can remove the callee from + // `mentioned_items` -- but we have to take their `mentioned_items` in return. This does + // some extra work here to save the monomorphization collector work later. It helps a lot, + // since monomorphization can avoid a lot of work when the "mentioned items" are similar to + // the actually used items. By doing this we can entirely avoid visiting the callee! + // We need to reconstruct the `required_item` for the callee so that we can find and + // remove it. + let callee_item = MentionedItem::Fn(func.ty(caller_body, tcx)); + let caller_mentioned_items = caller_body.mentioned_items.as_mut().unwrap(); + if let Some(idx) = caller_mentioned_items.iter().position(|item| item.node == callee_item) { + // We found the callee, so remove it and add its items instead. + caller_mentioned_items.remove(idx); + caller_mentioned_items.extend(callee_body.mentioned_items()); + } else { + // If we can't find the callee, there's no point in adding its items. Probably it + // already got removed by being inlined elsewhere in the same function, so we already + // took its items. 
+ } +} + +fn make_call_args<'tcx, I: Inliner<'tcx>>( + inliner: &I, + args: Box<[Spanned<Operand<'tcx>>]>, + callsite: &CallSite<'tcx>, + caller_body: &mut Body<'tcx>, + callee_body: &Body<'tcx>, + return_block: Option<BasicBlock>, +) -> Box<[Local]> { + let tcx = inliner.tcx(); + + // There is a bit of a mismatch between the *caller* of a closure and the *callee*. + // The caller provides the arguments wrapped up in a tuple: + // + // tuple_tmp = (a, b, c) + // Fn::call(closure_ref, tuple_tmp) + // + // meanwhile the closure body expects the arguments (here, `a`, `b`, and `c`) + // as distinct arguments. (This is the "rust-call" ABI hack.) Normally, codegen has + // the job of unpacking this tuple. But here, we are codegen. =) So we want to create + // a vector like + // + // [closure_ref, tuple_tmp.0, tuple_tmp.1, tuple_tmp.2] + // + // Except for one tiny wrinkle: we don't actually want `tuple_tmp.0`. It's more convenient + // if we "spill" that into *another* temporary, so that we can map the argument + // variable in the callee MIR directly to an argument variable on our side. + // So we introduce temporaries like: + // + // tmp0 = tuple_tmp.0 + // tmp1 = tuple_tmp.1 + // tmp2 = tuple_tmp.2 + // + // and the vector is `[closure_ref, tmp0, tmp1, tmp2]`. + if callsite.fn_sig.abi() == ExternAbi::RustCall && callee_body.spread_arg.is_none() { + // FIXME(edition_2024): switch back to a normal method call. + let mut args = <_>::into_iter(args); + let self_ = create_temp_if_necessary( + inliner, + args.next().unwrap().node, + callsite, + caller_body, + return_block, + ); + let tuple = create_temp_if_necessary( + inliner, + args.next().unwrap().node, + callsite, + caller_body, + return_block, + ); + assert!(args.next().is_none()); + + let tuple = Place::from(tuple); + let ty::Tuple(tuple_tys) = tuple.ty(caller_body, tcx).ty.kind() else { + bug!("Closure arguments are not passed as a tuple"); + }; + + // The `closure_ref` in our example above. + let closure_ref_arg = iter::once(self_); + + // The `tmp0`, `tmp1`, and `tmp2` in our example above. + let tuple_tmp_args = tuple_tys.iter().enumerate().map(|(i, ty)| { + // This is e.g., `tuple_tmp.0` in our example above. + let tuple_field = Operand::Move(tcx.mk_place_field(tuple, FieldIdx::new(i), ty)); + + // Spill to a local to make e.g., `tmp0`. + create_temp_if_necessary(inliner, tuple_field, callsite, caller_body, return_block) + }); + + closure_ref_arg.chain(tuple_tmp_args).collect() + } else { + args.into_iter() + .map(|a| create_temp_if_necessary(inliner, a.node, callsite, caller_body, return_block)) + .collect() + } +} + +/// If `arg` is already a temporary, returns it. Otherwise, introduces a fresh temporary `T` and an +/// instruction `T = arg`, and returns `T`. +fn create_temp_if_necessary<'tcx, I: Inliner<'tcx>>( + inliner: &I, + arg: Operand<'tcx>, + callsite: &CallSite<'tcx>, + caller_body: &mut Body<'tcx>, + return_block: Option<BasicBlock>, +) -> Local { + // Reuse the operand if it is a moved temporary. + if let Operand::Move(place) = &arg + && let Some(local) = place.as_local() + && caller_body.local_kind(local) == LocalKind::Temp + { + return local; + } + + // Otherwise, create a temporary for the argument. 
+ trace!("creating temp for argument {:?}", arg); + let arg_ty = arg.ty(caller_body, inliner.tcx()); + let local = new_call_temp(caller_body, callsite, arg_ty, return_block); + caller_body[callsite.block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::Assign(Box::new((Place::from(local), Rvalue::Use(arg)))), + }); + local +} + +/// Introduces a new temporary into the caller body that is live for the duration of the call. +fn new_call_temp<'tcx>( + caller_body: &mut Body<'tcx>, + callsite: &CallSite<'tcx>, + ty: Ty<'tcx>, + return_block: Option<BasicBlock>, +) -> Local { + let local = caller_body.local_decls.push(LocalDecl::new(ty, callsite.source_info.span)); + + caller_body[callsite.block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::StorageLive(local), + }); + + if let Some(block) = return_block { + caller_body[block].statements.insert( + 0, + Statement { + source_info: callsite.source_info, + kind: StatementKind::StorageDead(local), + }, + ); + } + + local +} + +/** + * Integrator. + * + * Integrates blocks from the callee function into the calling function. + * Updates block indices, references to locals and other control flow + * stuff. +*/ +struct Integrator<'a, 'tcx> { + args: &'a [Local], + new_locals: RangeFrom<Local>, + new_scopes: RangeFrom<SourceScope>, + new_blocks: RangeFrom<BasicBlock>, + destination: Local, + callsite_scope: SourceScopeData<'tcx>, + callsite: &'a CallSite<'tcx>, + cleanup_block: UnwindAction, + in_cleanup_block: bool, + return_block: Option<BasicBlock>, + tcx: TyCtxt<'tcx>, + always_live_locals: DenseBitSet<Local>, +} + +impl Integrator<'_, '_> { + fn map_local(&self, local: Local) -> Local { + let new = if local == RETURN_PLACE { + self.destination + } else { + let idx = local.index() - 1; + if idx < self.args.len() { + self.args[idx] + } else { + self.new_locals.start + (idx - self.args.len()) + } + }; + trace!("mapping local `{:?}` to `{:?}`", local, new); + new + } + + fn map_scope(&self, scope: SourceScope) -> SourceScope { + let new = self.new_scopes.start + scope.index(); + trace!("mapping scope `{:?}` to `{:?}`", scope, new); + new + } + + fn map_block(&self, block: BasicBlock) -> BasicBlock { + let new = self.new_blocks.start + block.index(); + trace!("mapping block `{:?}` to `{:?}`", block, new); + new + } + + fn map_unwind(&self, unwind: UnwindAction) -> UnwindAction { + if self.in_cleanup_block { + match unwind { + UnwindAction::Cleanup(_) | UnwindAction::Continue => { + bug!("cleanup on cleanup block"); + } + UnwindAction::Unreachable | UnwindAction::Terminate(_) => return unwind, + } + } + + match unwind { + UnwindAction::Unreachable | UnwindAction::Terminate(_) => unwind, + UnwindAction::Cleanup(target) => UnwindAction::Cleanup(self.map_block(target)), + // Add an unwind edge to the original call's cleanup block + UnwindAction::Continue => self.cleanup_block, + } + } +} + +impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, _ctxt: PlaceContext, _location: Location) { + *local = self.map_local(*local); + } + + fn visit_source_scope_data(&mut self, scope_data: &mut SourceScopeData<'tcx>) { + self.super_source_scope_data(scope_data); + if scope_data.parent_scope.is_none() { + // Attach the outermost callee scope as a child of the callsite + // scope, via the `parent_scope` and `inlined_parent_scope` chains. 
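            // (Editorial note on the scope handling below: this inlined-at chain is, as far as
            // I understand it, what later debuginfo generation uses to attribute the inlined
            // statements to both the callee and the original callsite.)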
+ scope_data.parent_scope = Some(self.callsite.source_info.scope); + assert_eq!(scope_data.inlined_parent_scope, None); + scope_data.inlined_parent_scope = if self.callsite_scope.inlined.is_some() { + Some(self.callsite.source_info.scope) + } else { + self.callsite_scope.inlined_parent_scope + }; + + // Mark the outermost callee scope as an inlined one. + assert_eq!(scope_data.inlined, None); + scope_data.inlined = Some((self.callsite.callee, self.callsite.source_info.span)); + } else if scope_data.inlined_parent_scope.is_none() { + // Make it easy to find the scope with `inlined` set above. + scope_data.inlined_parent_scope = Some(self.map_scope(OUTERMOST_SOURCE_SCOPE)); + } + } + + fn visit_source_scope(&mut self, scope: &mut SourceScope) { + *scope = self.map_scope(*scope); + } + + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) { + self.in_cleanup_block = data.is_cleanup; + self.super_basic_block_data(block, data); + self.in_cleanup_block = false; + } + + fn visit_retag(&mut self, kind: &mut RetagKind, place: &mut Place<'tcx>, loc: Location) { + self.super_retag(kind, place, loc); + + // We have to patch all inlined retags to be aware that they are no longer + // happening on function entry. + if *kind == RetagKind::FnEntry { + *kind = RetagKind::Default; + } + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + if let StatementKind::StorageLive(local) | StatementKind::StorageDead(local) = + statement.kind + { + self.always_live_locals.remove(local); + } + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, loc: Location) { + // Don't try to modify the implicit `_0` access on return (`return` terminators are + // replaced down below anyways). + if !matches!(terminator.kind, TerminatorKind::Return) { + self.super_terminator(terminator, loc); + } else { + self.visit_source_info(&mut terminator.source_info); + } + + match terminator.kind { + TerminatorKind::CoroutineDrop | TerminatorKind::Yield { .. } => bug!(), + TerminatorKind::Goto { ref mut target } => { + *target = self.map_block(*target); + } + TerminatorKind::SwitchInt { ref mut targets, .. } => { + for tgt in targets.all_targets_mut() { + *tgt = self.map_block(*tgt); + } + } + TerminatorKind::Drop { ref mut target, ref mut unwind, .. } => { + *target = self.map_block(*target); + *unwind = self.map_unwind(*unwind); + } + TerminatorKind::TailCall { .. } => { + // check_mir_body forbids tail calls + unreachable!() + } + TerminatorKind::Call { ref mut target, ref mut unwind, .. } => { + if let Some(ref mut tgt) = *target { + *tgt = self.map_block(*tgt); + } + *unwind = self.map_unwind(*unwind); + } + TerminatorKind::Assert { ref mut target, ref mut unwind, .. 
} => { + *target = self.map_block(*target); + *unwind = self.map_unwind(*unwind); + } + TerminatorKind::Return => { + terminator.kind = if let Some(tgt) = self.return_block { + TerminatorKind::Goto { target: tgt } + } else { + TerminatorKind::Unreachable + } + } + TerminatorKind::UnwindResume => { + terminator.kind = match self.cleanup_block { + UnwindAction::Cleanup(tgt) => TerminatorKind::Goto { target: tgt }, + UnwindAction::Continue => TerminatorKind::UnwindResume, + UnwindAction::Unreachable => TerminatorKind::Unreachable, + UnwindAction::Terminate(reason) => TerminatorKind::UnwindTerminate(reason), + }; + } + TerminatorKind::UnwindTerminate(_) => {} + TerminatorKind::Unreachable => {} + TerminatorKind::FalseEdge { ref mut real_target, ref mut imaginary_target } => { + *real_target = self.map_block(*real_target); + *imaginary_target = self.map_block(*imaginary_target); + } + TerminatorKind::FalseUnwind { real_target: _, unwind: _ } => + // see the ordering of passes in the optimized_mir query. + { + bug!("False unwinds should have been removed before inlining") + } + TerminatorKind::InlineAsm { ref mut targets, ref mut unwind, .. } => { + for tgt in targets.iter_mut() { + *tgt = self.map_block(*tgt); + } + *unwind = self.map_unwind(*unwind); + } + } + } +} + +#[instrument(skip(tcx), level = "debug")] +fn try_instance_mir<'tcx>( + tcx: TyCtxt<'tcx>, + instance: InstanceKind<'tcx>, +) -> Result<&'tcx Body<'tcx>, &'static str> { + if let ty::InstanceKind::DropGlue(_, Some(ty)) | ty::InstanceKind::AsyncDropGlueCtorShim(_, ty) = + instance + && let ty::Adt(def, args) = ty.kind() + { + let fields = def.all_fields(); + for field in fields { + let field_ty = field.ty(tcx, args); + if field_ty.has_param() && field_ty.has_aliases() { + return Err("cannot build drop shim for polymorphic type"); + } + } + } + Ok(tcx.instance_mir(instance)) +} + +fn body_is_forwarder(body: &Body<'_>) -> bool { + let TerminatorKind::Call { target, .. } = body.basic_blocks[START_BLOCK].terminator().kind + else { + return false; + }; + if let Some(target) = target { + let TerminatorKind::Return = body.basic_blocks[target].terminator().kind else { + return false; + }; + } + + let max_blocks = if !body.is_polymorphic { + 2 + } else if target.is_none() { + 3 + } else { + 4 + }; + if body.basic_blocks.len() > max_blocks { + return false; + } + + body.basic_blocks.iter_enumerated().all(|(bb, bb_data)| { + bb == START_BLOCK + || matches!( + bb_data.terminator().kind, + TerminatorKind::Return + | TerminatorKind::Drop { .. } + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + ) + }) +} diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs new file mode 100644 index 00000000000..a944960ce4a --- /dev/null +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -0,0 +1,198 @@ +use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet}; +use rustc_data_structures::stack::ensure_sufficient_stack; +use rustc_hir::def_id::{DefId, LocalDefId}; +use rustc_middle::mir::TerminatorKind; +use rustc_middle::ty::{self, GenericArgsRef, InstanceKind, TyCtxt, TypeVisitableExt}; +use rustc_session::Limit; +use rustc_span::sym; +use tracing::{instrument, trace}; + +// FIXME: check whether it is cheaper to precompute the entire call graph instead of invoking +// this query ridiculously often. 
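// `mir_callgraph_reachable((root, target))` answers: starting from `root`'s MIR, can the
// resolved call graph reach `target`? The inliner's `check_mir_is_available` uses this to
// reject local callees whose inlining could otherwise create a query cycle back into the
// caller currently being optimized.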
+#[instrument(level = "debug", skip(tcx, root, target))] +pub(crate) fn mir_callgraph_reachable<'tcx>( + tcx: TyCtxt<'tcx>, + (root, target): (ty::Instance<'tcx>, LocalDefId), +) -> bool { + trace!(%root, target = %tcx.def_path_str(target)); + assert_ne!( + root.def_id().expect_local(), + target, + "you should not call `mir_callgraph_reachable` on immediate self recursion" + ); + assert!( + matches!(root.def, InstanceKind::Item(_)), + "you should not call `mir_callgraph_reachable` on shims" + ); + assert!( + !tcx.is_constructor(root.def_id()), + "you should not call `mir_callgraph_reachable` on enum/struct constructor functions" + ); + #[instrument( + level = "debug", + skip(tcx, typing_env, target, stack, seen, recursion_limiter, caller, recursion_limit) + )] + fn process<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + caller: ty::Instance<'tcx>, + target: LocalDefId, + stack: &mut Vec<ty::Instance<'tcx>>, + seen: &mut FxHashSet<ty::Instance<'tcx>>, + recursion_limiter: &mut FxHashMap<DefId, usize>, + recursion_limit: Limit, + ) -> bool { + trace!(%caller); + for &(callee, args) in tcx.mir_inliner_callees(caller.def) { + let Ok(args) = caller.try_instantiate_mir_and_normalize_erasing_regions( + tcx, + typing_env, + ty::EarlyBinder::bind(args), + ) else { + trace!(?caller, ?typing_env, ?args, "cannot normalize, skipping"); + continue; + }; + let Ok(Some(callee)) = ty::Instance::try_resolve(tcx, typing_env, callee, args) else { + trace!(?callee, "cannot resolve, skipping"); + continue; + }; + + // Found a path. + if callee.def_id() == target.to_def_id() { + return true; + } + + if tcx.is_constructor(callee.def_id()) { + trace!("constructors always have MIR"); + // Constructor functions cannot cause a query cycle. + continue; + } + + match callee.def { + InstanceKind::Item(_) => { + // If there is no MIR available (either because it was not in metadata or + // because it has no MIR because it's an extern function), then the inliner + // won't cause cycles on this. + if !tcx.is_mir_available(callee.def_id()) { + trace!(?callee, "no mir available, skipping"); + continue; + } + } + // These have no own callable MIR. + InstanceKind::Intrinsic(_) | InstanceKind::Virtual(..) => continue, + // These have MIR and if that MIR is inlined, instantiated and then inlining is run + // again, a function item can end up getting inlined. Thus we'll be able to cause + // a cycle that way + InstanceKind::VTableShim(_) + | InstanceKind::ReifyShim(..) + | InstanceKind::FnPtrShim(..) + | InstanceKind::ClosureOnceShim { .. } + | InstanceKind::ConstructCoroutineInClosureShim { .. } + | InstanceKind::ThreadLocalShim { .. } + | InstanceKind::CloneShim(..) => {} + + // This shim does not call any other functions, thus there can be no recursion. + InstanceKind::FnPtrAddrShim(..) => { + continue; + } + InstanceKind::DropGlue(..) + | InstanceKind::FutureDropPollShim(..) + | InstanceKind::AsyncDropGlue(..) + | InstanceKind::AsyncDropGlueCtorShim(..) => { + // FIXME: A not fully instantiated drop shim can cause ICEs if one attempts to + // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this + // needs some more analysis. 
+ if callee.has_param() { + continue; + } + } + } + + if seen.insert(callee) { + let recursion = recursion_limiter.entry(callee.def_id()).or_default(); + trace!(?callee, recursion = *recursion); + if recursion_limit.value_within_limit(*recursion) { + *recursion += 1; + stack.push(callee); + let found_recursion = ensure_sufficient_stack(|| { + process( + tcx, + typing_env, + callee, + target, + stack, + seen, + recursion_limiter, + recursion_limit, + ) + }); + if found_recursion { + return true; + } + stack.pop(); + } else { + // Pessimistically assume that there could be recursion. + return true; + } + } + } + false + } + // FIXME(-Znext-solver=no): Remove this hack when trait solver overflow can return an error. + // In code like that pointed out in #128887, the type complexity we ask the solver to deal with + // grows as we recurse into the call graph. If we use the same recursion limit here and in the + // solver, the solver hits the limit first and emits a fatal error. But if we use a reduced + // limit, we will hit the limit first and give up on looking for inlining. And in any case, + // the default recursion limits are quite generous for us. If we need to recurse 64 times + // into the call graph, we're probably not going to find any useful MIR inlining. + let recursion_limit = tcx.recursion_limit() / 2; + process( + tcx, + ty::TypingEnv::post_analysis(tcx, target), + root, + target, + &mut Vec::new(), + &mut FxHashSet::default(), + &mut FxHashMap::default(), + recursion_limit, + ) +} + +pub(crate) fn mir_inliner_callees<'tcx>( + tcx: TyCtxt<'tcx>, + instance: ty::InstanceKind<'tcx>, +) -> &'tcx [(DefId, GenericArgsRef<'tcx>)] { + let steal; + let guard; + let body = match (instance, instance.def_id().as_local()) { + (InstanceKind::Item(_), Some(def_id)) => { + steal = tcx.mir_promoted(def_id).0; + guard = steal.borrow(); + &*guard + } + // Functions from other crates and MIR shims + _ => tcx.instance_mir(instance), + }; + let mut calls = FxIndexSet::default(); + for bb_data in body.basic_blocks.iter() { + let terminator = bb_data.terminator(); + if let TerminatorKind::Call { func, args: call_args, .. } = &terminator.kind { + let ty = func.ty(&body.local_decls, tcx); + let ty::FnDef(def_id, generic_args) = ty.kind() else { + continue; + }; + let call = if tcx.is_intrinsic(*def_id, sym::const_eval_select) { + let func = &call_args[2].node; + let ty = func.ty(&body.local_decls, tcx); + let ty::FnDef(def_id, generic_args) = ty.kind() else { + continue; + }; + (*def_id, *generic_args) + } else { + (*def_id, *generic_args) + }; + calls.insert(call); + } + } + tcx.arena.alloc_from_iter(calls.iter().copied()) +} diff --git a/compiler/rustc_mir_transform/src/instsimplify.rs b/compiler/rustc_mir_transform/src/instsimplify.rs new file mode 100644 index 00000000000..5f0c55ddc09 --- /dev/null +++ b/compiler/rustc_mir_transform/src/instsimplify.rs @@ -0,0 +1,325 @@ +//! Performs various peephole optimizations. 
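// Illustrative sketch only (assumed example, not part of this pass): source-level shapes
// whose MIR the peepholes below rewrite. The pass itself operates on `Rvalue`s in MIR, not
// on surface syntax; the function name and bindings here are made up for illustration.
#[allow(dead_code, clippy::bool_comparison)]
fn _instsimplify_shapes(a: bool, r: &u32) -> (bool, u32, [u8; 5]) {
    let bool_cmp = a == true; // simplify_bool_cmp: `Eq(a, true)` becomes just `a`
    let ref_deref = *&*r; // simplify_ref_deref: `&(*r)` followed by a deref becomes a use of `r`
    let repeated = [0u8, 0, 0, 0, 0]; // simplify_repeated_aggregate: becomes `[0u8; 5]`
    (bool_cmp, ref_deref, repeated)
}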
+ +use rustc_abi::ExternAbi; +use rustc_ast::attr; +use rustc_hir::LangItem; +use rustc_middle::bug; +use rustc_middle::mir::*; +use rustc_middle::ty::layout::ValidityRequirement; +use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, layout}; +use rustc_span::{DUMMY_SP, Symbol, sym}; + +use crate::simplify::simplify_duplicate_switch_targets; + +pub(super) enum InstSimplify { + BeforeInline, + AfterSimplifyCfg, +} + +impl<'tcx> crate::MirPass<'tcx> for InstSimplify { + fn name(&self) -> &'static str { + match self { + InstSimplify::BeforeInline => "InstSimplify-before-inline", + InstSimplify::AfterSimplifyCfg => "InstSimplify-after-simplifycfg", + } + } + + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let ctx = InstSimplifyContext { + tcx, + local_decls: &body.local_decls, + typing_env: body.typing_env(tcx), + }; + let preserve_ub_checks = + attr::contains_name(tcx.hir_krate_attrs(), sym::rustc_preserve_ub_checks); + for block in body.basic_blocks.as_mut() { + for statement in block.statements.iter_mut() { + let StatementKind::Assign(box (.., rvalue)) = &mut statement.kind else { + continue; + }; + + if !preserve_ub_checks { + ctx.simplify_ub_check(rvalue); + } + ctx.simplify_bool_cmp(rvalue); + ctx.simplify_ref_deref(rvalue); + ctx.simplify_ptr_aggregate(rvalue); + ctx.simplify_cast(rvalue); + ctx.simplify_repeated_aggregate(rvalue); + ctx.simplify_repeat_once(rvalue); + } + + let terminator = block.terminator.as_mut().unwrap(); + ctx.simplify_primitive_clone(terminator, &mut block.statements); + ctx.simplify_intrinsic_assert(terminator); + ctx.simplify_nounwind_call(terminator); + simplify_duplicate_switch_targets(terminator); + } + } + + fn is_required(&self) -> bool { + false + } +} + +struct InstSimplifyContext<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + local_decls: &'a LocalDecls<'tcx>, + typing_env: ty::TypingEnv<'tcx>, +} + +impl<'tcx> InstSimplifyContext<'_, 'tcx> { + /// Transform aggregates like [0, 0, 0, 0, 0] into [0; 5]. + /// GVN can also do this optimization, but GVN is only run at mir-opt-level 2 so having this in + /// InstSimplify helps unoptimized builds. + fn simplify_repeated_aggregate(&self, rvalue: &mut Rvalue<'tcx>) { + let Rvalue::Aggregate(box AggregateKind::Array(_), fields) = &*rvalue else { + return; + }; + if fields.len() < 5 { + return; + } + let (first, rest) = fields[..].split_first().unwrap(); + let Operand::Constant(first) = first else { + return; + }; + let Ok(first_val) = first.const_.eval(self.tcx, self.typing_env, first.span) else { + return; + }; + if rest.iter().all(|field| { + let Operand::Constant(field) = field else { + return false; + }; + let field = field.const_.eval(self.tcx, self.typing_env, field.span); + field == Ok(first_val) + }) { + let len = ty::Const::from_target_usize(self.tcx, fields.len().try_into().unwrap()); + *rvalue = Rvalue::Repeat(Operand::Constant(first.clone()), len); + } + } + + /// Transform boolean comparisons into logical operations. 
+ fn simplify_bool_cmp(&self, rvalue: &mut Rvalue<'tcx>) { + let Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), box (a, b)) = &*rvalue else { return }; + *rvalue = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) { + // Transform "Eq(a, true)" ==> "a" + (BinOp::Eq, _, Some(true)) => Rvalue::Use(a.clone()), + + // Transform "Ne(a, false)" ==> "a" + (BinOp::Ne, _, Some(false)) => Rvalue::Use(a.clone()), + + // Transform "Eq(true, b)" ==> "b" + (BinOp::Eq, Some(true), _) => Rvalue::Use(b.clone()), + + // Transform "Ne(false, b)" ==> "b" + (BinOp::Ne, Some(false), _) => Rvalue::Use(b.clone()), + + // Transform "Eq(false, b)" ==> "Not(b)" + (BinOp::Eq, Some(false), _) => Rvalue::UnaryOp(UnOp::Not, b.clone()), + + // Transform "Ne(true, b)" ==> "Not(b)" + (BinOp::Ne, Some(true), _) => Rvalue::UnaryOp(UnOp::Not, b.clone()), + + // Transform "Eq(a, false)" ==> "Not(a)" + (BinOp::Eq, _, Some(false)) => Rvalue::UnaryOp(UnOp::Not, a.clone()), + + // Transform "Ne(a, true)" ==> "Not(a)" + (BinOp::Ne, _, Some(true)) => Rvalue::UnaryOp(UnOp::Not, a.clone()), + + _ => return, + }; + } + + fn try_eval_bool(&self, a: &Operand<'_>) -> Option<bool> { + let a = a.constant()?; + if a.const_.ty().is_bool() { a.const_.try_to_bool() } else { None } + } + + /// Transform `&(*a)` ==> `a`. + fn simplify_ref_deref(&self, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Ref(_, _, place) | Rvalue::RawPtr(_, place) = rvalue + && let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() + && rvalue.ty(self.local_decls, self.tcx) == base.ty(self.local_decls, self.tcx).ty + { + *rvalue = Rvalue::Use(Operand::Copy(Place { + local: base.local, + projection: self.tcx.mk_place_elems(base.projection), + })); + } + } + + /// Transform `Aggregate(RawPtr, [p, ()])` ==> `Cast(PtrToPtr, p)`. + fn simplify_ptr_aggregate(&self, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Aggregate(box AggregateKind::RawPtr(pointee_ty, mutability), fields) = rvalue + && let meta_ty = fields.raw[1].ty(self.local_decls, self.tcx) + && meta_ty.is_unit() + { + // The mutable borrows we're holding prevent printing `rvalue` here + let mut fields = std::mem::take(fields); + let _meta = fields.pop().unwrap(); + let data = fields.pop().unwrap(); + let ptr_ty = Ty::new_ptr(self.tcx, *pointee_ty, *mutability); + *rvalue = Rvalue::Cast(CastKind::PtrToPtr, data, ptr_ty); + } + } + + fn simplify_ub_check(&self, rvalue: &mut Rvalue<'tcx>) { + let Rvalue::NullaryOp(NullOp::UbChecks, _) = *rvalue else { return }; + + let const_ = Const::from_bool(self.tcx, self.tcx.sess.ub_checks()); + let constant = ConstOperand { span: DUMMY_SP, const_, user_ty: None }; + *rvalue = Rvalue::Use(Operand::Constant(Box::new(constant))); + } + + fn simplify_cast(&self, rvalue: &mut Rvalue<'tcx>) { + let Rvalue::Cast(kind, operand, cast_ty) = rvalue else { return }; + + let operand_ty = operand.ty(self.local_decls, self.tcx); + if operand_ty == *cast_ty { + *rvalue = Rvalue::Use(operand.clone()); + } else if *kind == CastKind::Transmute + // Transmuting an integer to another integer is just a signedness cast + && let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) = + (operand_ty.kind(), cast_ty.kind()) + && int.bit_width() == uint.bit_width() + { + // The width check isn't strictly necessary, as different widths + // are UB and thus we'd be allowed to turn it into a cast anyway. + // But let's keep the UB around for codegen to exploit later. 
+ // (If `CastKind::Transmute` ever becomes *not* UB for mismatched sizes, + // then the width check is necessary for big-endian correctness.) + *kind = CastKind::IntToInt; + } + } + + /// Simplify `[x; 1]` to just `[x]`. + fn simplify_repeat_once(&self, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Repeat(operand, count) = rvalue + && let Some(1) = count.try_to_target_usize(self.tcx) + { + *rvalue = Rvalue::Aggregate( + Box::new(AggregateKind::Array(operand.ty(self.local_decls, self.tcx))), + [operand.clone()].into(), + ); + } + } + + fn simplify_primitive_clone( + &self, + terminator: &mut Terminator<'tcx>, + statements: &mut Vec<Statement<'tcx>>, + ) { + let TerminatorKind::Call { + func, args, destination, target: Some(destination_block), .. + } = &terminator.kind + else { + return; + }; + + // It's definitely not a clone if there are multiple arguments + let [arg] = &args[..] else { return }; + + // Only bother looking more if it's easy to know what we're calling + let Some((fn_def_id, ..)) = func.const_fn_def() else { return }; + + // These types are easily available from locals, so check that before + // doing DefId lookups to figure out what we're actually calling. + let arg_ty = arg.node.ty(self.local_decls, self.tcx); + + let ty::Ref(_region, inner_ty, Mutability::Not) = *arg_ty.kind() else { return }; + + if !self.tcx.is_lang_item(fn_def_id, LangItem::CloneFn) + || !inner_ty.is_trivially_pure_clone_copy() + { + return; + } + + let Some(arg_place) = arg.node.place() else { return }; + + statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Use(Operand::Copy( + arg_place.project_deeper(&[ProjectionElem::Deref], self.tcx), + )), + ))), + }); + terminator.kind = TerminatorKind::Goto { target: *destination_block }; + } + + fn simplify_nounwind_call(&self, terminator: &mut Terminator<'tcx>) { + let TerminatorKind::Call { ref func, ref mut unwind, .. } = terminator.kind else { + return; + }; + + let Some((def_id, _)) = func.const_fn_def() else { + return; + }; + + let body_ty = self.tcx.type_of(def_id).skip_binder(); + let body_abi = match body_ty.kind() { + ty::FnDef(..) => body_ty.fn_sig(self.tcx).abi(), + ty::Closure(..) => ExternAbi::RustCall, + ty::Coroutine(..) => ExternAbi::Rust, + _ => bug!("unexpected body ty: {body_ty:?}"), + }; + + if !layout::fn_can_unwind(self.tcx, Some(def_id), body_abi) { + *unwind = UnwindAction::Unreachable; + } + } + + fn simplify_intrinsic_assert(&self, terminator: &mut Terminator<'tcx>) { + let TerminatorKind::Call { ref func, target: ref mut target @ Some(target_block), .. } = + terminator.kind + else { + return; + }; + let func_ty = func.ty(self.local_decls, self.tcx); + let Some((intrinsic_name, args)) = resolve_rust_intrinsic(self.tcx, func_ty) else { + return; + }; + // The intrinsics we are interested in have one generic parameter + let [arg, ..] = args[..] 
else { return }; + + let known_is_valid = + intrinsic_assert_panics(self.tcx, self.typing_env, arg, intrinsic_name); + match known_is_valid { + // We don't know the layout or it's not validity assertion at all, don't touch it + None => {} + Some(true) => { + // If we know the assert panics, indicate to later opts that the call diverges + *target = None; + } + Some(false) => { + // If we know the assert does not panic, turn the call into a Goto + terminator.kind = TerminatorKind::Goto { target: target_block }; + } + } + } +} + +fn intrinsic_assert_panics<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + arg: ty::GenericArg<'tcx>, + intrinsic_name: Symbol, +) -> Option<bool> { + let requirement = ValidityRequirement::from_intrinsic(intrinsic_name)?; + let ty = arg.expect_ty(); + Some(!tcx.check_validity_requirement((requirement, typing_env.as_query_input(ty))).ok()?) +} + +fn resolve_rust_intrinsic<'tcx>( + tcx: TyCtxt<'tcx>, + func_ty: Ty<'tcx>, +) -> Option<(Symbol, GenericArgsRef<'tcx>)> { + let ty::FnDef(def_id, args) = *func_ty.kind() else { return None }; + let intrinsic = tcx.intrinsic(def_id)?; + Some((intrinsic.name, args)) +} diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs new file mode 100644 index 00000000000..f9e642e28eb --- /dev/null +++ b/compiler/rustc_mir_transform/src/jump_threading.rs @@ -0,0 +1,852 @@ +//! A jump threading optimization. +//! +//! This optimization seeks to replace join-then-switch control flow patterns by straight jumps +//! X = 0 X = 0 +//! ------------\ /-------- ------------ +//! X = 1 X----X SwitchInt(X) => X = 1 +//! ------------/ \-------- ------------ +//! +//! +//! We proceed by walking the cfg backwards starting from each `SwitchInt` terminator, +//! looking for assignments that will turn the `SwitchInt` into a simple `Goto`. +//! +//! The algorithm maintains a set of replacement conditions: +//! - `conditions[place]` contains `Condition { value, polarity: Eq, target }` +//! if assigning `value` to `place` turns the `SwitchInt` into `Goto { target }`. +//! - `conditions[place]` contains `Condition { value, polarity: Ne, target }` +//! if assigning anything different from `value` to `place` turns the `SwitchInt` +//! into `Goto { target }`. +//! +//! In this file, we denote as `place ?= value` the existence of a replacement condition +//! on `place` with given `value`, irrespective of the polarity and target of that +//! replacement condition. +//! +//! We then walk the CFG backwards transforming the set of conditions. +//! When we find a fulfilling assignment, we record a `ThreadingOpportunity`. +//! All `ThreadingOpportunity`s are applied to the body, by duplicating blocks if required. +//! +//! The optimization search can be very heavy, as it performs a DFS on MIR starting from +//! each `SwitchInt` terminator. To manage the complexity, we: +//! - bound the maximum depth by a constant `MAX_BACKTRACK`; +//! - we only traverse `Goto` terminators. +//! +//! We try to avoid creating irreducible control-flow by not threading through a loop header. +//! +//! Likewise, applying the optimisation can create a lot of new MIR, so we bound the instruction +//! cost by `MAX_COST`. 
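// Illustrative sketch only (assumed example, not part of this pass): a source-level shape
// whose MIR matches the diagram above. Two predecessors each assign a constant to `x`, and a
// `SwitchInt` on `x` then re-tests it; threading lets each predecessor jump straight to its
// `match` arm instead of re-testing `x`.
#[allow(dead_code)]
fn _jump_threading_shape(c: bool) -> u32 {
    let x = if c { 0u8 } else { 1u8 };
    match x {
        0 => 10,
        _ => 20,
    }
}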
+ +use rustc_arena::DroplessArena; +use rustc_const_eval::const_eval::DummyMachine; +use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable}; +use rustc_data_structures::fx::FxHashSet; +use rustc_index::IndexVec; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::bug; +use rustc_middle::mir::interpret::Scalar; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, ScalarInt, TyCtxt}; +use rustc_mir_dataflow::lattice::HasBottom; +use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem}; +use rustc_span::DUMMY_SP; +use tracing::{debug, instrument, trace}; + +use crate::cost_checker::CostChecker; + +pub(super) struct JumpThreading; + +const MAX_BACKTRACK: usize = 5; +const MAX_COST: usize = 100; +const MAX_PLACES: usize = 100; + +impl<'tcx> crate::MirPass<'tcx> for JumpThreading { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 2 + } + + #[instrument(skip_all level = "debug")] + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let def_id = body.source.def_id(); + debug!(?def_id); + + // Optimizing coroutines creates query cycles. + if tcx.is_coroutine(def_id) { + trace!("Skipped for coroutine {:?}", def_id); + return; + } + + let typing_env = body.typing_env(tcx); + let arena = &DroplessArena::default(); + let mut finder = TOFinder { + tcx, + typing_env, + ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine), + body, + arena, + map: Map::new(tcx, body, Some(MAX_PLACES)), + loop_headers: loop_headers(body), + opportunities: Vec::new(), + }; + + for (bb, _) in traversal::preorder(body) { + finder.start_from_switch(bb); + } + + let opportunities = finder.opportunities; + debug!(?opportunities); + if opportunities.is_empty() { + return; + } + + // Verify that we do not thread through a loop header. + for to in opportunities.iter() { + assert!(to.chain.iter().all(|&block| !finder.loop_headers.contains(block))); + } + OpportunitySet::new(body, opportunities).apply(body); + } + + fn is_required(&self) -> bool { + false + } +} + +#[derive(Debug)] +struct ThreadingOpportunity { + /// The list of `BasicBlock`s from the one that found the opportunity to the `SwitchInt`. + chain: Vec<BasicBlock>, + /// The `SwitchInt` will be replaced by `Goto { target }`. + target: BasicBlock, +} + +struct TOFinder<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + ecx: InterpCx<'tcx, DummyMachine>, + body: &'a Body<'tcx>, + map: Map<'tcx>, + loop_headers: DenseBitSet<BasicBlock>, + /// We use an arena to avoid cloning the slices when cloning `state`. + arena: &'a DroplessArena, + opportunities: Vec<ThreadingOpportunity>, +} + +/// Represent the following statement. If we can prove that the current local is equal/not-equal +/// to `value`, jump to `target`. 
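// For example, `Condition { value: 0, polarity: Eq, target: bb3 }` attached to a place `p`
// means: if we can prove `p == 0` at this point, the final `SwitchInt` can be replaced by
// `Goto { target: bb3 }`; `Ne` is the "any value other than 0" counterpart. The names `p`
// and `bb3` are made up for illustration.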
+#[derive(Copy, Clone, Debug)] +struct Condition { + value: ScalarInt, + polarity: Polarity, + target: BasicBlock, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +enum Polarity { + Ne, + Eq, +} + +impl Condition { + fn matches(&self, value: ScalarInt) -> bool { + (self.value == value) == (self.polarity == Polarity::Eq) + } +} + +#[derive(Copy, Clone, Debug)] +struct ConditionSet<'a>(&'a [Condition]); + +impl HasBottom for ConditionSet<'_> { + const BOTTOM: Self = ConditionSet(&[]); + + fn is_bottom(&self) -> bool { + self.0.is_empty() + } +} + +impl<'a> ConditionSet<'a> { + fn iter(self) -> impl Iterator<Item = Condition> { + self.0.iter().copied() + } + + fn iter_matches(self, value: ScalarInt) -> impl Iterator<Item = Condition> { + self.iter().filter(move |c| c.matches(value)) + } + + fn map( + self, + arena: &'a DroplessArena, + f: impl Fn(Condition) -> Option<Condition>, + ) -> Option<ConditionSet<'a>> { + let set = arena.try_alloc_from_iter(self.iter().map(|c| f(c).ok_or(()))).ok()?; + Some(ConditionSet(set)) + } +} + +impl<'a, 'tcx> TOFinder<'a, 'tcx> { + fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool { + state.all_bottom() + } + + /// Recursion entry point to find threading opportunities. + #[instrument(level = "trace", skip(self))] + fn start_from_switch(&mut self, bb: BasicBlock) { + let bbdata = &self.body[bb]; + if bbdata.is_cleanup || self.loop_headers.contains(bb) { + return; + } + let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return }; + let Some(discr) = discr.place() else { return }; + debug!(?discr, ?bb); + + let discr_ty = discr.ty(self.body, self.tcx).ty; + let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return }; + + let Some(discr) = self.map.find(discr.as_ref()) else { return }; + debug!(?discr); + + let cost = CostChecker::new(self.tcx, self.typing_env, None, self.body); + let mut state = State::new_reachable(); + + let conds = if let Some((value, then, else_)) = targets.as_static_if() { + let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return }; + self.arena.alloc_from_iter([ + Condition { value, polarity: Polarity::Eq, target: then }, + Condition { value, polarity: Polarity::Ne, target: else_ }, + ]) + } else { + self.arena.alloc_from_iter(targets.iter().filter_map(|(value, target)| { + let value = ScalarInt::try_from_uint(value, discr_layout.size)?; + Some(Condition { value, polarity: Polarity::Eq, target }) + })) + }; + let conds = ConditionSet(conds); + state.insert_value_idx(discr, conds, &self.map); + + self.find_opportunity(bb, state, cost, 0) + } + + /// Recursively walk statements backwards from this bb's terminator to find threading + /// opportunities. + #[instrument(level = "trace", skip(self, cost), ret)] + fn find_opportunity( + &mut self, + bb: BasicBlock, + mut state: State<ConditionSet<'a>>, + mut cost: CostChecker<'_, 'tcx>, + depth: usize, + ) { + // Do not thread through loop headers. + if self.loop_headers.contains(bb) { + return; + } + + debug!(cost = ?cost.cost()); + for (statement_index, stmt) in + self.body.basic_blocks[bb].statements.iter().enumerate().rev() + { + if self.is_empty(&state) { + return; + } + + cost.visit_statement(stmt, Location { block: bb, statement_index }); + if cost.cost() > MAX_COST { + return; + } + + // Attempt to turn the `current_condition` on `lhs` into a condition on another place. 
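            // (For example, seeing `lhs = Not(rhs)` rewrites each condition on `lhs` into the
            // inverted condition on `rhs`; see `process_assign` below.)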
+ self.process_statement(bb, stmt, &mut state); + + // When a statement mutates a place, assignments to that place that happen + // above the mutation cannot fulfill a condition. + // _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`. + // _1 = 6 + if let Some((lhs, tail)) = self.mutated_statement(stmt) { + state.flood_with_tail_elem(lhs.as_ref(), tail, &self.map, ConditionSet::BOTTOM); + } + } + + if self.is_empty(&state) || depth >= MAX_BACKTRACK { + return; + } + + let last_non_rec = self.opportunities.len(); + + let predecessors = &self.body.basic_blocks.predecessors()[bb]; + if let &[pred] = &predecessors[..] + && bb != START_BLOCK + { + let term = self.body.basic_blocks[pred].terminator(); + match term.kind { + TerminatorKind::SwitchInt { ref discr, ref targets } => { + self.process_switch_int(discr, targets, bb, &mut state); + self.find_opportunity(pred, state, cost, depth + 1); + } + _ => self.recurse_through_terminator(pred, || state, &cost, depth), + } + } else if let &[ref predecessors @ .., last_pred] = &predecessors[..] { + for &pred in predecessors { + self.recurse_through_terminator(pred, || state.clone(), &cost, depth); + } + self.recurse_through_terminator(last_pred, || state, &cost, depth); + } + + let new_tos = &mut self.opportunities[last_non_rec..]; + debug!(?new_tos); + + // Try to deduplicate threading opportunities. + if new_tos.len() > 1 + && new_tos.len() == predecessors.len() + && predecessors + .iter() + .zip(new_tos.iter()) + .all(|(&pred, to)| to.chain == &[pred] && to.target == new_tos[0].target) + { + // All predecessors have a threading opportunity, and they all point to the same block. + debug!(?new_tos, "dedup"); + let first = &mut new_tos[0]; + *first = ThreadingOpportunity { chain: vec![bb], target: first.target }; + self.opportunities.truncate(last_non_rec + 1); + return; + } + + for op in self.opportunities[last_non_rec..].iter_mut() { + op.chain.push(bb); + } + } + + /// Extract the mutated place from a statement. + /// + /// This method returns the `Place` so we can flood the state in case of a partial assignment. + /// (_1 as Ok).0 = _5; + /// (_1 as Err).0 = _6; + /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as + /// the value may have been mangled by the second assignment. + /// + /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we can + /// stop at flooding the discriminant, and preserve the variant fields. + /// (_1 as Some).0 = _6; + /// SetDiscriminant(_1, 1); + /// switchInt((_1 as Some).0) + #[instrument(level = "trace", skip(self), ret)] + fn mutated_statement( + &self, + stmt: &Statement<'tcx>, + ) -> Option<(Place<'tcx>, Option<TrackElem>)> { + match stmt.kind { + StatementKind::Assign(box (place, _)) + | StatementKind::Deinit(box place) => Some((place, None)), + StatementKind::SetDiscriminant { box place, variant_index: _ } => { + Some((place, Some(TrackElem::Discriminant))) + } + StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + Some((Place::from(local), None)) + } + StatementKind::Retag(..) + | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..)) + // copy_nonoverlapping takes pointers and mutated the pointed-to value. + | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..)) + | StatementKind::AscribeUserType(..) + | StatementKind::Coverage(..) + | StatementKind::FakeRead(..) + | StatementKind::ConstEvalCounter + | StatementKind::PlaceMention(..) 
+ | StatementKind::BackwardIncompatibleDropHint { .. } + | StatementKind::Nop => None, + } + } + + #[instrument(level = "trace", skip(self))] + fn process_immediate( + &mut self, + bb: BasicBlock, + lhs: PlaceIndex, + rhs: ImmTy<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) { + let register_opportunity = |c: Condition| { + debug!(?bb, ?c.target, "register"); + self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target }) + }; + + if let Some(conditions) = state.try_get_idx(lhs, &self.map) + && let Immediate::Scalar(Scalar::Int(int)) = *rhs + { + conditions.iter_matches(int).for_each(register_opportunity); + } + } + + /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. + #[instrument(level = "trace", skip(self))] + fn process_constant( + &mut self, + bb: BasicBlock, + lhs: PlaceIndex, + constant: OpTy<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) { + self.map.for_each_projection_value( + lhs, + constant, + &mut |elem, op| match elem { + TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(), + TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(), + TrackElem::Discriminant => { + let variant = self.ecx.read_discriminant(op).discard_err()?; + let discr_value = + self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?; + Some(discr_value.into()) + } + TrackElem::DerefLen => { + let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into(); + let len_usize = op.len(&self.ecx).discard_err()?; + let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); + Some(ImmTy::from_uint(len_usize, layout).into()) + } + }, + &mut |place, op| { + if let Some(conditions) = state.try_get_idx(place, &self.map) + && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err() + && let Some(imm) = imm.right() + && let Immediate::Scalar(Scalar::Int(int)) = *imm + { + conditions.iter_matches(int).for_each(|c: Condition| { + self.opportunities + .push(ThreadingOpportunity { chain: vec![bb], target: c.target }) + }) + } + }, + ); + } + + #[instrument(level = "trace", skip(self))] + fn process_operand( + &mut self, + bb: BasicBlock, + lhs: PlaceIndex, + rhs: &Operand<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) { + match rhs { + // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. + Operand::Constant(constant) => { + let Some(constant) = + self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err() + else { + return; + }; + self.process_constant(bb, lhs, constant, state); + } + // Transfer the conditions on the copied rhs. + Operand::Move(rhs) | Operand::Copy(rhs) => { + let Some(rhs) = self.map.find(rhs.as_ref()) else { return }; + state.insert_place_idx(rhs, lhs, &self.map); + } + } + } + + #[instrument(level = "trace", skip(self))] + fn process_assign( + &mut self, + bb: BasicBlock, + lhs_place: &Place<'tcx>, + rhs: &Rvalue<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) { + let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return }; + match rhs { + Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state), + // Transfer the conditions on the copy rhs. + Rvalue::CopyForDeref(rhs) => self.process_operand(bb, lhs, &Operand::Copy(*rhs), state), + Rvalue::Discriminant(rhs) => { + let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return }; + state.insert_place_idx(rhs, lhs, &self.map); + } + // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`. 
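            // For aggregates, distribute conditions on the whole `lhs` onto its individual
            // fields (and, for enum variants, record the known discriminant as well).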
+ Rvalue::Aggregate(box kind, operands) => { + let agg_ty = lhs_place.ty(self.body, self.tcx).ty; + let lhs = match kind { + // Do not support unions. + AggregateKind::Adt(.., Some(_)) => return, + AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => { + if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant) + && let Some(discr_value) = self + .ecx + .discriminant_for_variant(agg_ty, *variant_index) + .discard_err() + { + self.process_immediate(bb, discr_target, discr_value, state); + } + if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) { + idx + } else { + return; + } + } + _ => lhs, + }; + for (field_index, operand) in operands.iter_enumerated() { + if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) { + self.process_operand(bb, field, operand, state); + } + } + } + // Transfer the conditions on the copy rhs, after inverting the value of the condition. + Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => { + let layout = self.ecx.layout_of(place.ty(self.body, self.tcx).ty).unwrap(); + let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return }; + let Some(place) = self.map.find(place.as_ref()) else { return }; + let Some(conds) = conditions.map(self.arena, |mut cond| { + cond.value = self + .ecx + .unary_op(UnOp::Not, &ImmTy::from_scalar_int(cond.value, layout)) + .discard_err()? + .to_scalar_int() + .discard_err()?; + Some(cond) + }) else { + return; + }; + state.insert_value_idx(place, conds, &self.map); + } + // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`. + // Create a condition on `rhs ?= B`. + Rvalue::BinaryOp( + op, + box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value)) + | box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)), + ) => { + let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return }; + let Some(place) = self.map.find(place.as_ref()) else { return }; + let equals = match op { + BinOp::Eq => ScalarInt::TRUE, + BinOp::Ne => ScalarInt::FALSE, + _ => return, + }; + if value.const_.ty().is_floating_point() { + // Floating point equality does not follow bit-patterns. + // -0.0 and NaN both have special rules for equality, + // and therefore we cannot use integer comparisons for them. + // Avoid handling them, though this could be extended in the future. + return; + } + let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env) + else { + return; + }; + let Some(conds) = conditions.map(self.arena, |c| { + Some(Condition { + value, + polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne }, + ..c + }) + }) else { + return; + }; + state.insert_value_idx(place, conds, &self.map); + } + + _ => {} + } + } + + #[instrument(level = "trace", skip(self))] + fn process_statement( + &mut self, + bb: BasicBlock, + stmt: &Statement<'tcx>, + state: &mut State<ConditionSet<'a>>, + ) { + let register_opportunity = |c: Condition| { + debug!(?bb, ?c.target, "register"); + self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target }) + }; + + // Below, `lhs` is the return value of `mutated_statement`, + // the place to which `conditions` apply. + + match &stmt.kind { + // If we expect `discriminant(place) ?= A`, + // we have an opportunity if `variant_index ?= A`. 
+ StatementKind::SetDiscriminant { box place, variant_index } => { + let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return }; + let enum_ty = place.ty(self.body, self.tcx).ty; + // `SetDiscriminant` guarantees that the discriminant is now `variant_index`. + // Even if the discriminant write does nothing due to niches, it is UB to set the + // discriminant when the data does not encode the desired discriminant. + let Some(discr) = + self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err() + else { + return; + }; + self.process_immediate(bb, discr_target, discr, state) + } + // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`. + StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume( + Operand::Copy(place) | Operand::Move(place), + )) => { + let Some(conditions) = state.try_get(place.as_ref(), &self.map) else { return }; + conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity) + } + StatementKind::Assign(box (lhs_place, rhs)) => { + self.process_assign(bb, lhs_place, rhs, state) + } + _ => {} + } + } + + #[instrument(level = "trace", skip(self, state, cost))] + fn recurse_through_terminator( + &mut self, + bb: BasicBlock, + // Pass a closure that may clone the state, as we don't want to do it each time. + state: impl FnOnce() -> State<ConditionSet<'a>>, + cost: &CostChecker<'_, 'tcx>, + depth: usize, + ) { + let term = self.body.basic_blocks[bb].terminator(); + let place_to_flood = match term.kind { + // We come from a target, so those are not possible. + TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::TailCall { .. } + | TerminatorKind::Unreachable + | TerminatorKind::CoroutineDrop => bug!("{term:?} has no terminators"), + // Disallowed during optimizations. + TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"), + // Cannot reason about inline asm. + TerminatorKind::InlineAsm { .. } => return, + // `SwitchInt` is handled specially. + TerminatorKind::SwitchInt { .. } => return, + // We can recurse, no thing particular to do. + TerminatorKind::Goto { .. } => None, + // Flood the overwritten place, and progress through. + TerminatorKind::Drop { place: destination, .. } + | TerminatorKind::Call { destination, .. } => Some(destination), + // Ignore, as this can be a no-op at codegen time. + TerminatorKind::Assert { .. } => None, + }; + + // We can recurse through this terminator. 
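+        // Note (added, hedged): the state is only cloned at this point, once we know the
+        // terminator is one we can walk through; for a `Call` or `Drop`, the overwritten
+        // destination is flooded below so no stale condition on it survives the backward walk.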
+ let mut state = state(); + if let Some(place_to_flood) = place_to_flood { + state.flood_with(place_to_flood.as_ref(), &self.map, ConditionSet::BOTTOM); + } + self.find_opportunity(bb, state, cost.clone(), depth + 1) + } + + #[instrument(level = "trace", skip(self))] + fn process_switch_int( + &mut self, + discr: &Operand<'tcx>, + targets: &SwitchTargets, + target_bb: BasicBlock, + state: &mut State<ConditionSet<'a>>, + ) { + debug_assert_ne!(target_bb, START_BLOCK); + debug_assert_eq!(self.body.basic_blocks.predecessors()[target_bb].len(), 1); + + let Some(discr) = discr.place() else { return }; + let discr_ty = discr.ty(self.body, self.tcx).ty; + let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { + return; + }; + let Some(conditions) = state.try_get(discr.as_ref(), &self.map) else { return }; + + if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) { + let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return }; + debug_assert_eq!(targets.iter().filter(|&(_, target)| target == target_bb).count(), 1); + + // We are inside `target_bb`. Since we have a single predecessor, we know we passed + // through the `SwitchInt` before arriving here. Therefore, we know that + // `discr == value`. If one condition can be fulfilled by `discr == value`, + // that's an opportunity. + for c in conditions.iter_matches(value) { + debug!(?target_bb, ?c.target, "register"); + self.opportunities.push(ThreadingOpportunity { chain: vec![], target: c.target }); + } + } else if let Some((value, _, else_bb)) = targets.as_static_if() + && target_bb == else_bb + { + let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return }; + + // We only know that `discr != value`. That's much weaker information than + // the equality we had in the previous arm. All we can conclude is that + // the replacement condition `discr != value` can be threaded, and nothing else. + for c in conditions.iter() { + if c.value == value && c.polarity == Polarity::Ne { + debug!(?target_bb, ?c.target, "register"); + self.opportunities + .push(ThreadingOpportunity { chain: vec![], target: c.target }); + } + } + } + } +} + +struct OpportunitySet { + opportunities: Vec<ThreadingOpportunity>, + /// For each bb, give the TOs in which it appears. The pair corresponds to the index + /// in `opportunities` and the index in `ThreadingOpportunity::chain`. + involving_tos: IndexVec<BasicBlock, Vec<(usize, usize)>>, + /// Cache the number of predecessors for each block, as we clear the basic block cache.. + predecessors: IndexVec<BasicBlock, usize>, +} + +impl OpportunitySet { + fn new(body: &Body<'_>, opportunities: Vec<ThreadingOpportunity>) -> OpportunitySet { + let mut involving_tos = IndexVec::from_elem(Vec::new(), &body.basic_blocks); + for (index, to) in opportunities.iter().enumerate() { + for (ibb, &bb) in to.chain.iter().enumerate() { + involving_tos[bb].push((index, ibb)); + } + involving_tos[to.target].push((index, to.chain.len())); + } + let predecessors = predecessor_count(body); + OpportunitySet { opportunities, involving_tos, predecessors } + } + + /// Apply the opportunities on the graph. + fn apply(&mut self, body: &mut Body<'_>) { + for i in 0..self.opportunities.len() { + self.apply_once(i, body); + } + } + + #[instrument(level = "trace", skip(self, body))] + fn apply_once(&mut self, index: usize, body: &mut Body<'_>) { + debug!(?self.predecessors); + debug!(?self.involving_tos); + + // Check that `predecessors` satisfies its invariant. 
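+        // Note (added, hedged): `predecessors` is maintained by hand because this pass mutates
+        // `basic_blocks` through `as_mut()`, which invalidates the cached predecessor graph;
+        // the assertion below recomputes the counts from scratch in debug builds.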
+        debug_assert_eq!(self.predecessors, predecessor_count(body));
+
+        // Remove the TO from the vector to allow modifying the other ones later.
+        let op = &mut self.opportunities[index];
+        debug!(?op);
+        let op_chain = std::mem::take(&mut op.chain);
+        let op_target = op.target;
+        debug_assert_eq!(op_chain.len(), op_chain.iter().collect::<FxHashSet<_>>().len());
+
+        let Some((current, chain)) = op_chain.split_first() else { return };
+        let basic_blocks = body.basic_blocks.as_mut();
+
+        // Invariant: the control-flow is well-formed at the end of each iteration.
+        let mut current = *current;
+        for &succ in chain {
+            debug!(?current, ?succ);
+
+            // `succ` must be a successor of `current`. If it is not, this means this TO is not
+            // satisfiable and a previous TO erased this edge, so we bail out.
+            if !basic_blocks[current].terminator().successors().any(|s| s == succ) {
+                debug!("impossible");
+                return;
+            }
+
+            // Fast path: `succ` is only used once, so we can reuse it directly.
+            if self.predecessors[succ] == 1 {
+                debug!("single");
+                current = succ;
+                continue;
+            }
+
+            let new_succ = basic_blocks.push(basic_blocks[succ].clone());
+            debug!(?new_succ);
+
+            // Replace `succ` by `new_succ` where it appears.
+            let mut num_edges = 0;
+            basic_blocks[current].terminator_mut().successors_mut(|s| {
+                if *s == succ {
+                    *s = new_succ;
+                    num_edges += 1;
+                }
+            });
+
+            // Update predecessors with the new block.
+            let _new_succ = self.predecessors.push(num_edges);
+            debug_assert_eq!(new_succ, _new_succ);
+            self.predecessors[succ] -= num_edges;
+            self.update_predecessor_count(basic_blocks[new_succ].terminator(), Update::Incr);
+
+            // Replace the `current -> succ` edge by `current -> new_succ` in all the following
+            // TOs. This is necessary to avoid trying to thread through a non-existing edge. We
+            // use `involving_tos` here to avoid traversing the full set of TOs on each iteration.
+            let mut new_involved = Vec::new();
+            for &(to_index, in_to_index) in &self.involving_tos[current] {
+                // That TO has already been applied, do nothing.
+                if to_index <= index {
+                    continue;
+                }
+
+                let other_to = &mut self.opportunities[to_index];
+                if other_to.chain.get(in_to_index) != Some(&current) {
+                    continue;
+                }
+                let s = other_to.chain.get_mut(in_to_index + 1).unwrap_or(&mut other_to.target);
+                if *s == succ {
+                    // `other_to` references the `current -> succ` edge, so replace `succ`.
+                    *s = new_succ;
+                    new_involved.push((to_index, in_to_index + 1));
+                }
+            }
+
+            // The TOs that we just updated now reference `new_succ`. Update `involving_tos`
+            // in case we need to duplicate an edge starting at `new_succ` later.
+            let _new_succ = self.involving_tos.push(new_involved);
+            debug_assert_eq!(new_succ, _new_succ);
+
+            current = new_succ;
+        }
+
+        let current = &mut basic_blocks[current];
+        self.update_predecessor_count(current.terminator(), Update::Decr);
+        current.terminator_mut().kind = TerminatorKind::Goto { target: op_target };
+        self.predecessors[op_target] += 1;
+    }
+
+    fn update_predecessor_count(&mut self, terminator: &Terminator<'_>, incr: Update) {
+        match incr {
+            Update::Incr => {
+                for s in terminator.successors() {
+                    self.predecessors[s] += 1;
+                }
+            }
+            Update::Decr => {
+                for s in terminator.successors() {
+                    self.predecessors[s] -= 1;
+                }
+            }
+        }
+    }
+}
+
+fn predecessor_count(body: &Body<'_>) -> IndexVec<BasicBlock, usize> {
+    let mut predecessors: IndexVec<_, _> =
+        body.basic_blocks.predecessors().iter().map(|ps| ps.len()).collect();
+    predecessors[START_BLOCK] += 1; // Account for the implicit entry edge.
+    predecessors
+}
+
+enum Update {
+    Incr,
+    Decr,
+}
+
+/// Compute the set of loop headers in the given body. We define a loop header as a block which has
+/// at least one predecessor which it dominates. This definition is only correct for reducible
+/// CFGs. But if the CFG is already irreducible, there is no point in trying much harder.
+fn loop_headers(body: &Body<'_>) -> DenseBitSet<BasicBlock> {
+    let mut loop_headers = DenseBitSet::new_empty(body.basic_blocks.len());
+    let dominators = body.basic_blocks.dominators();
+    // Only visit reachable blocks.
+    for (bb, bbdata) in traversal::preorder(body) {
+        for succ in bbdata.terminator().successors() {
+            if dominators.dominates(succ, bb) {
+                loop_headers.insert(succ);
+            }
+        }
+    }
+    loop_headers
+}
diff --git a/compiler/rustc_mir_transform/src/known_panics_lint.rs b/compiler/rustc_mir_transform/src/known_panics_lint.rs
new file mode 100644
index 00000000000..481c7941909
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/known_panics_lint.rs
@@ -0,0 +1,992 @@
+//! A lint that checks for known panics like overflows, division by zero,
+//! out-of-bound access etc. Uses const propagation to determine the values of
+//! operands during checks.
+
+use std::fmt::Debug;
+
+use rustc_abi::{BackendRepr, FieldIdx, HasDataLayout, Size, TargetDataLayout, VariantIdx};
+use rustc_const_eval::const_eval::DummyMachine;
+use rustc_const_eval::interpret::{
+    ImmTy, InterpCx, InterpResult, Projectable, Scalar, format_interp_error, interp_ok,
+};
+use rustc_data_structures::fx::FxHashSet;
+use rustc_hir::HirId;
+use rustc_hir::def::DefKind;
+use rustc_index::IndexVec;
+use rustc_index::bit_set::DenseBitSet;
+use rustc_middle::bug;
+use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor};
+use rustc_middle::mir::*;
+use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout};
+use rustc_middle::ty::{self, ConstInt, ScalarInt, Ty, TyCtxt, TypeVisitableExt};
+use rustc_span::Span;
+use tracing::{debug, instrument, trace};
+
+use crate::errors::{AssertLint, AssertLintKind};
+
+pub(super) struct KnownPanicsLint;
+
+impl<'tcx> crate::MirLint<'tcx> for KnownPanicsLint {
+    fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
+        if body.tainted_by_errors.is_some() {
+            return;
+        }
+
+        let def_id = body.source.def_id().expect_local();
+        let def_kind = tcx.def_kind(def_id);
+        let is_fn_like = def_kind.is_fn_like();
+        let is_assoc_const = def_kind == DefKind::AssocConst;
+
+        // Only run const prop on functions, methods, closures and associated constants
+        if !is_fn_like && !is_assoc_const {
+            // skip anon_const/statics/consts because they'll be evaluated by miri anyway
+            trace!("KnownPanicsLint skipped for {:?}", def_id);
+            return;
+        }
+
+        // FIXME(welseywiser) const prop doesn't work on coroutines because of query cycles
+        // computing their layout.
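+        // Note (added, hedged): coroutine bodies are therefore skipped wholesale below instead
+        // of risking a `layout_of` query cycle while interpreting their locals.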
+ if tcx.is_coroutine(def_id.to_def_id()) { + trace!("KnownPanicsLint skipped for coroutine {:?}", def_id); + return; + } + + trace!("KnownPanicsLint starting for {:?}", def_id); + + let mut linter = ConstPropagator::new(body, tcx); + linter.visit_body(body); + + trace!("KnownPanicsLint done for {:?}", def_id); + } +} + +/// Visits MIR nodes, performs const propagation +/// and runs lint checks as it goes +struct ConstPropagator<'mir, 'tcx> { + ecx: InterpCx<'tcx, DummyMachine>, + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + worklist: Vec<BasicBlock>, + visited_blocks: DenseBitSet<BasicBlock>, + locals: IndexVec<Local, Value<'tcx>>, + body: &'mir Body<'tcx>, + written_only_inside_own_block_locals: FxHashSet<Local>, + can_const_prop: IndexVec<Local, ConstPropMode>, +} + +#[derive(Debug, Clone)] +enum Value<'tcx> { + Immediate(ImmTy<'tcx>), + Aggregate { variant: VariantIdx, fields: IndexVec<FieldIdx, Value<'tcx>> }, + Uninit, +} + +impl<'tcx> From<ImmTy<'tcx>> for Value<'tcx> { + fn from(v: ImmTy<'tcx>) -> Self { + Self::Immediate(v) + } +} + +impl<'tcx> Value<'tcx> { + fn project( + &self, + proj: &[PlaceElem<'tcx>], + prop: &ConstPropagator<'_, 'tcx>, + ) -> Option<&Value<'tcx>> { + let mut this = self; + for proj in proj { + this = match (*proj, this) { + (PlaceElem::Field(idx, _), Value::Aggregate { fields, .. }) => { + fields.get(idx).unwrap_or(&Value::Uninit) + } + (PlaceElem::Index(idx), Value::Aggregate { fields, .. }) => { + let idx = prop.get_const(idx.into())?.immediate()?; + let idx = prop.ecx.read_target_usize(idx).discard_err()?.try_into().ok()?; + if idx <= FieldIdx::MAX_AS_U32 { + fields.get(FieldIdx::from_u32(idx)).unwrap_or(&Value::Uninit) + } else { + return None; + } + } + ( + PlaceElem::ConstantIndex { offset, min_length: _, from_end: false }, + Value::Aggregate { fields, .. }, + ) => fields + .get(FieldIdx::from_u32(offset.try_into().ok()?)) + .unwrap_or(&Value::Uninit), + _ => return None, + }; + } + Some(this) + } + + fn project_mut(&mut self, proj: &[PlaceElem<'_>]) -> Option<&mut Value<'tcx>> { + let mut this = self; + for proj in proj { + this = match (proj, this) { + (PlaceElem::Field(idx, _), Value::Aggregate { fields, .. }) => { + fields.ensure_contains_elem(*idx, || Value::Uninit) + } + (PlaceElem::Field(..), val @ Value::Uninit) => { + *val = + Value::Aggregate { variant: VariantIdx::ZERO, fields: Default::default() }; + val.project_mut(&[*proj])? 
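+                    // Added note: a field write into an `Uninit` value first materializes an
+                    // empty aggregate, then recurses once so the freshly created `fields`
+                    // vector is the one that gets extended.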
+ } + _ => return None, + }; + } + Some(this) + } + + fn immediate(&self) -> Option<&ImmTy<'tcx>> { + match self { + Value::Immediate(op) => Some(op), + _ => None, + } + } +} + +impl<'tcx> LayoutOfHelpers<'tcx> for ConstPropagator<'_, 'tcx> { + type LayoutOfResult = Result<TyAndLayout<'tcx>, LayoutError<'tcx>>; + + #[inline] + fn handle_layout_err(&self, err: LayoutError<'tcx>, _: Span, _: Ty<'tcx>) -> LayoutError<'tcx> { + err + } +} + +impl HasDataLayout for ConstPropagator<'_, '_> { + #[inline] + fn data_layout(&self) -> &TargetDataLayout { + &self.tcx.data_layout + } +} + +impl<'tcx> ty::layout::HasTyCtxt<'tcx> for ConstPropagator<'_, 'tcx> { + #[inline] + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } +} + +impl<'tcx> ty::layout::HasTypingEnv<'tcx> for ConstPropagator<'_, 'tcx> { + #[inline] + fn typing_env(&self) -> ty::TypingEnv<'tcx> { + self.typing_env + } +} + +impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { + fn new(body: &'mir Body<'tcx>, tcx: TyCtxt<'tcx>) -> ConstPropagator<'mir, 'tcx> { + let def_id = body.source.def_id(); + // FIXME(#132279): This is used during the phase transition from analysis + // to runtime, so we have to manually specify the correct typing mode. + let typing_env = ty::TypingEnv::post_analysis(tcx, body.source.def_id()); + let can_const_prop = CanConstProp::check(tcx, typing_env, body); + let ecx = InterpCx::new(tcx, tcx.def_span(def_id), typing_env, DummyMachine); + + ConstPropagator { + ecx, + tcx, + typing_env, + worklist: vec![START_BLOCK], + visited_blocks: DenseBitSet::new_empty(body.basic_blocks.len()), + locals: IndexVec::from_elem_n(Value::Uninit, body.local_decls.len()), + body, + can_const_prop, + written_only_inside_own_block_locals: Default::default(), + } + } + + fn local_decls(&self) -> &'mir LocalDecls<'tcx> { + &self.body.local_decls + } + + fn get_const(&self, place: Place<'tcx>) -> Option<&Value<'tcx>> { + self.locals[place.local].project(&place.projection, self) + } + + /// Remove `local` from the pool of `Locals`. Allows writing to them, + /// but not reading from them anymore. + fn remove_const(&mut self, local: Local) { + self.locals[local] = Value::Uninit; + self.written_only_inside_own_block_locals.remove(&local); + } + + fn access_mut(&mut self, place: &Place<'_>) -> Option<&mut Value<'tcx>> { + match self.can_const_prop[place.local] { + ConstPropMode::NoPropagation => return None, + ConstPropMode::OnlyInsideOwnBlock => { + self.written_only_inside_own_block_locals.insert(place.local); + } + ConstPropMode::FullConstProp => {} + } + self.locals[place.local].project_mut(place.projection) + } + + fn lint_root(&self, source_info: SourceInfo) -> Option<HirId> { + source_info.scope.lint_root(&self.body.source_scopes) + } + + fn use_ecx<F, T>(&mut self, f: F) -> Option<T> + where + F: FnOnce(&mut Self) -> InterpResult<'tcx, T>, + { + f(self) + .map_err_info(|err| { + trace!("InterpCx operation failed: {:?}", err); + // Some errors shouldn't come up because creating them causes + // an allocation, which we should avoid. When that happens, + // dedicated error variants should be introduced instead. + assert!( + !err.kind().formatted_string(), + "known panics lint encountered formatting error: {}", + format_interp_error(self.ecx.tcx.dcx(), err), + ); + err + }) + .discard_err() + } + + /// Returns the value, if any, of evaluating `c`. 
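+    /// Added note (hedged): evaluation happens in the shared `InterpCx` with `DummyMachine`,
+    /// and only immediate results are kept; memory-backed (`MPlace`) constants yield `None`
+    /// through `as_mplace_or_imm().right()` below.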
+ fn eval_constant(&mut self, c: &ConstOperand<'tcx>) -> Option<ImmTy<'tcx>> { + // FIXME we need to revisit this for #67176 + if c.has_param() { + return None; + } + + // Normalization needed b/c known panics lint runs in + // `mir_drops_elaborated_and_const_checked`, which happens before + // optimized MIR. Only after optimizing the MIR can we guarantee + // that the `PostAnalysisNormalize` pass has happened and that the body's consts + // are normalized, so any call to resolve before that needs to be + // manually normalized. + let val = self.tcx.try_normalize_erasing_regions(self.typing_env, c.const_).ok()?; + + self.use_ecx(|this| this.ecx.eval_mir_constant(&val, c.span, None))? + .as_mplace_or_imm() + .right() + } + + /// Returns the value, if any, of evaluating `place`. + #[instrument(level = "trace", skip(self), ret)] + fn eval_place(&mut self, place: Place<'tcx>) -> Option<ImmTy<'tcx>> { + match self.get_const(place)? { + Value::Immediate(imm) => Some(imm.clone()), + Value::Aggregate { .. } => None, + Value::Uninit => None, + } + } + + /// Returns the value, if any, of evaluating `op`. Calls upon `eval_constant` + /// or `eval_place`, depending on the variant of `Operand` used. + fn eval_operand(&mut self, op: &Operand<'tcx>) -> Option<ImmTy<'tcx>> { + match *op { + Operand::Constant(ref c) => self.eval_constant(c), + Operand::Move(place) | Operand::Copy(place) => self.eval_place(place), + } + } + + fn report_assert_as_lint( + &self, + location: Location, + lint_kind: AssertLintKind, + assert_kind: AssertKind<impl Debug>, + ) { + let source_info = self.body.source_info(location); + if let Some(lint_root) = self.lint_root(*source_info) { + let span = source_info.span; + self.tcx.emit_node_span_lint( + lint_kind.lint(), + lint_root, + span, + AssertLint { span, assert_kind, lint_kind }, + ); + } + } + + fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>, location: Location) -> Option<()> { + let arg = self.eval_operand(arg)?; + // The only operator that can overflow is `Neg`. + if op == UnOp::Neg && arg.layout.ty.is_integral() { + // Compute this as `0 - arg` so we can use `SubWithOverflow` to check for overflow. + let (arg, overflow) = self.use_ecx(|this| { + let arg = this.ecx.read_immediate(&arg)?; + let (_res, overflow) = this + .ecx + .binary_op(BinOp::SubWithOverflow, &ImmTy::from_int(0, arg.layout), &arg)? + .to_scalar_pair(); + interp_ok((arg, overflow.to_bool()?)) + })?; + if overflow { + self.report_assert_as_lint( + location, + AssertLintKind::ArithmeticOverflow, + AssertKind::OverflowNeg(arg.to_const_int()), + ); + return None; + } + } + + Some(()) + } + + fn check_binary_op( + &mut self, + op: BinOp, + left: &Operand<'tcx>, + right: &Operand<'tcx>, + location: Location, + ) -> Option<()> { + let r = + self.eval_operand(right).and_then(|r| self.use_ecx(|this| this.ecx.read_immediate(&r))); + let l = + self.eval_operand(left).and_then(|l| self.use_ecx(|this| this.ecx.read_immediate(&l))); + // Check for exceeding shifts *even if* we cannot evaluate the LHS. + if matches!(op, BinOp::Shr | BinOp::Shl) { + let r = r.clone()?; + // We need the type of the LHS. We cannot use `place_layout` as that is the type + // of the result, which for checked binops is not the same! 
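+            // Illustrative example (added): for `1_u8 << 9` with overflow checks, the assigned
+            // place has type `(u8, bool)`, but the shift amount must be compared against the
+            // 8 bits of the `u8` LHS, hence the explicit `left_ty` below.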
+ let left_ty = left.ty(self.local_decls(), self.tcx); + let left_size = self.ecx.layout_of(left_ty).ok()?.size; + let right_size = r.layout.size; + let r_bits = r.to_scalar().to_bits(right_size).discard_err(); + if r_bits.is_some_and(|b| b >= left_size.bits() as u128) { + debug!("check_binary_op: reporting assert for {:?}", location); + let panic = AssertKind::Overflow( + op, + // Invent a dummy value, the diagnostic ignores it anyway + ConstInt::new( + ScalarInt::try_from_uint(1_u8, left_size).unwrap(), + left_ty.is_signed(), + left_ty.is_ptr_sized_integral(), + ), + r.to_const_int(), + ); + self.report_assert_as_lint(location, AssertLintKind::ArithmeticOverflow, panic); + return None; + } + } + + // Div/Rem are handled via the assertions they trigger. + // But for Add/Sub/Mul, those assertions only exist in debug builds, and we want to + // lint in release builds as well, so we check on the operation instead. + // So normalize to the "overflowing" operator, and then ensure that it + // actually is an overflowing operator. + let op = op.wrapping_to_overflowing().unwrap_or(op); + // The remaining operators are handled through `wrapping_to_overflowing`. + if let (Some(l), Some(r)) = (l, r) + && l.layout.ty.is_integral() + && op.is_overflowing() + && self.use_ecx(|this| { + let (_res, overflow) = this.ecx.binary_op(op, &l, &r)?.to_scalar_pair(); + overflow.to_bool() + })? + { + self.report_assert_as_lint( + location, + AssertLintKind::ArithmeticOverflow, + AssertKind::Overflow(op, l.to_const_int(), r.to_const_int()), + ); + return None; + } + + Some(()) + } + + fn check_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) -> Option<()> { + // Perform any special handling for specific Rvalue types. + // Generally, checks here fall into one of two categories: + // 1. Additional checking to provide useful lints to the user + // - In this case, we will do some validation and then fall through to the + // end of the function which evals the assignment. + // 2. Working around bugs in other parts of the compiler + // - In this case, we'll return `None` from this function to stop evaluation. + match rvalue { + // Additional checking: give lints to the user if an overflow would occur. + // We do this here and not in the `Assert` terminator as that terminator is + // only sometimes emitted (overflow checks can be disabled), but we want to always + // lint. + Rvalue::UnaryOp(op, arg) => { + trace!("checking UnaryOp(op = {:?}, arg = {:?})", op, arg); + self.check_unary_op(*op, arg, location)?; + } + Rvalue::BinaryOp(op, box (left, right)) => { + trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right); + self.check_binary_op(*op, left, right, location)?; + } + + // Do not try creating references (#67862) + Rvalue::RawPtr(_, place) | Rvalue::Ref(_, _, place) => { + trace!("skipping RawPtr | Ref for {:?}", place); + + // This may be creating mutable references or immutable references to cells. + // If that happens, the pointed to value could be mutated via that reference. + // Since we aren't tracking references, the const propagator loses track of what + // value the local has right now. + // Thus, all locals that have their reference taken + // must not take part in propagation. + self.remove_const(place.local); + + return None; + } + Rvalue::ThreadLocalRef(def_id) => { + trace!("skipping ThreadLocalRef({:?})", def_id); + + return None; + } + + // There's no other checking to do at this time. + Rvalue::Aggregate(..) + | Rvalue::Use(..) + | Rvalue::CopyForDeref(..) 
+ | Rvalue::Repeat(..) + | Rvalue::Len(..) + | Rvalue::Cast(..) + | Rvalue::ShallowInitBox(..) + | Rvalue::Discriminant(..) + | Rvalue::NullaryOp(..) + | Rvalue::WrapUnsafeBinder(..) => {} + } + + // FIXME we need to revisit this for #67176 + if rvalue.has_param() { + return None; + } + if !rvalue.ty(self.local_decls(), self.tcx).is_sized(self.tcx, self.typing_env) { + // the interpreter doesn't support unsized locals (only unsized arguments), + // but rustc does (in a kinda broken way), so we have to skip them here + return None; + } + + Some(()) + } + + fn check_assertion( + &mut self, + expected: bool, + msg: &AssertKind<Operand<'tcx>>, + cond: &Operand<'tcx>, + location: Location, + ) { + let Some(value) = &self.eval_operand(cond) else { return }; + trace!("assertion on {:?} should be {:?}", value, expected); + + let expected = Scalar::from_bool(expected); + let Some(value_const) = self.use_ecx(|this| this.ecx.read_scalar(value)) else { return }; + + if expected != value_const { + // Poison all places this operand references so that further code + // doesn't use the invalid value + if let Some(place) = cond.place() { + self.remove_const(place.local); + } + + enum DbgVal<T> { + Val(T), + Underscore, + } + impl<T: std::fmt::Debug> std::fmt::Debug for DbgVal<T> { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Val(val) => val.fmt(fmt), + Self::Underscore => fmt.write_str("_"), + } + } + } + let mut eval_to_int = |op| { + // This can be `None` if the lhs wasn't const propagated and we just + // triggered the assert on the value of the rhs. + self.eval_operand(op) + .and_then(|op| self.ecx.read_immediate(&op).discard_err()) + .map_or(DbgVal::Underscore, |op| DbgVal::Val(op.to_const_int())) + }; + let msg = match msg { + AssertKind::DivisionByZero(op) => AssertKind::DivisionByZero(eval_to_int(op)), + AssertKind::RemainderByZero(op) => AssertKind::RemainderByZero(eval_to_int(op)), + AssertKind::Overflow(bin_op @ (BinOp::Div | BinOp::Rem), op1, op2) => { + // Division overflow is *UB* in the MIR, and different than the + // other overflow checks. + AssertKind::Overflow(*bin_op, eval_to_int(op1), eval_to_int(op2)) + } + AssertKind::BoundsCheck { len, index } => { + let len = eval_to_int(len); + let index = eval_to_int(index); + AssertKind::BoundsCheck { len, index } + } + // Remaining overflow errors are already covered by checks on the binary operators. + AssertKind::Overflow(..) | AssertKind::OverflowNeg(_) => return, + // Need proper const propagator for these. 
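+                // (Added, hedged) e.g. `AssertKind::MisalignedPointerDereference`, whose
+                // operands this lint does not currently evaluate.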
+ _ => return, + }; + self.report_assert_as_lint(location, AssertLintKind::UnconditionalPanic, msg); + } + } + + fn ensure_not_propagated(&self, local: Local) { + if cfg!(debug_assertions) { + let val = self.get_const(local.into()); + assert!( + matches!(val, Some(Value::Uninit)) + || self + .layout_of(self.local_decls()[local].ty) + .map_or(true, |layout| layout.is_zst()), + "failed to remove values for `{local:?}`, value={val:?}", + ) + } + } + + #[instrument(level = "trace", skip(self), ret)] + fn eval_rvalue(&mut self, rvalue: &Rvalue<'tcx>, dest: &Place<'tcx>) -> Option<()> { + if !dest.projection.is_empty() { + return None; + } + use rustc_middle::mir::Rvalue::*; + let layout = self.ecx.layout_of(dest.ty(self.body, self.tcx).ty).ok()?; + trace!(?layout); + + let val: Value<'_> = match *rvalue { + ThreadLocalRef(_) => return None, + + Use(ref operand) | WrapUnsafeBinder(ref operand, _) => { + self.eval_operand(operand)?.into() + } + + CopyForDeref(place) => self.eval_place(place)?.into(), + + BinaryOp(bin_op, box (ref left, ref right)) => { + let left = self.eval_operand(left)?; + let left = self.use_ecx(|this| this.ecx.read_immediate(&left))?; + + let right = self.eval_operand(right)?; + let right = self.use_ecx(|this| this.ecx.read_immediate(&right))?; + + let val = self.use_ecx(|this| this.ecx.binary_op(bin_op, &left, &right))?; + if matches!(val.layout.backend_repr, BackendRepr::ScalarPair(..)) { + // FIXME `Value` should properly support pairs in `Immediate`... but currently + // it does not. + let (val, overflow) = val.to_pair(&self.ecx); + Value::Aggregate { + variant: VariantIdx::ZERO, + fields: [val.into(), overflow.into()].into_iter().collect(), + } + } else { + val.into() + } + } + + UnaryOp(un_op, ref operand) => { + let operand = self.eval_operand(operand)?; + let val = self.use_ecx(|this| this.ecx.read_immediate(&operand))?; + + let val = self.use_ecx(|this| this.ecx.unary_op(un_op, &val))?; + val.into() + } + + Aggregate(ref kind, ref fields) => Value::Aggregate { + fields: fields + .iter() + .map(|field| self.eval_operand(field).map_or(Value::Uninit, Value::Immediate)) + .collect(), + variant: match **kind { + AggregateKind::Adt(_, variant, _, _, _) => variant, + AggregateKind::Array(_) + | AggregateKind::Tuple + | AggregateKind::RawPtr(_, _) + | AggregateKind::Closure(_, _) + | AggregateKind::Coroutine(_, _) + | AggregateKind::CoroutineClosure(_, _) => VariantIdx::ZERO, + }, + }, + + Repeat(ref op, n) => { + trace!(?op, ?n); + return None; + } + + Len(place) => { + let len = if let ty::Array(_, n) = place.ty(self.local_decls(), self.tcx).ty.kind() + { + n.try_to_target_usize(self.tcx)? + } else { + match self.get_const(place)? { + Value::Immediate(src) => src.len(&self.ecx).discard_err()?, + Value::Aggregate { fields, .. } => fields.len() as u64, + Value::Uninit => return None, + } + }; + ImmTy::from_scalar(Scalar::from_target_usize(len, self), layout).into() + } + + Ref(..) | RawPtr(..) => return None, + + NullaryOp(ref null_op, ty) => { + let op_layout = self.ecx.layout_of(ty).ok()?; + let val = match null_op { + NullOp::SizeOf => op_layout.size.bytes(), + NullOp::AlignOf => op_layout.align.abi.bytes(), + NullOp::OffsetOf(fields) => self + .tcx + .offset_of_subfield(self.typing_env, op_layout, fields.iter()) + .bytes(), + NullOp::UbChecks => return None, + NullOp::ContractChecks => return None, + }; + ImmTy::from_scalar(Scalar::from_target_usize(val, self), layout).into() + } + + ShallowInitBox(..) 
=> return None, + + Cast(ref kind, ref value, to) => match kind { + CastKind::IntToInt | CastKind::IntToFloat => { + let value = self.eval_operand(value)?; + let value = self.ecx.read_immediate(&value).discard_err()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.int_to_int_or_float(&value, to).discard_err()?; + res.into() + } + CastKind::FloatToFloat | CastKind::FloatToInt => { + let value = self.eval_operand(value)?; + let value = self.ecx.read_immediate(&value).discard_err()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.float_to_float_or_int(&value, to).discard_err()?; + res.into() + } + CastKind::Transmute => { + let value = self.eval_operand(value)?; + let to = self.ecx.layout_of(to).ok()?; + // `offset` for immediates only supports scalar/scalar-pair ABIs, + // so bail out if the target is not one. + match (value.layout.backend_repr, to.backend_repr) { + (BackendRepr::Scalar(..), BackendRepr::Scalar(..)) => {} + (BackendRepr::ScalarPair(..), BackendRepr::ScalarPair(..)) => {} + _ => return None, + } + + value.offset(Size::ZERO, to, &self.ecx).discard_err()?.into() + } + _ => return None, + }, + + Discriminant(place) => { + let variant = match self.get_const(place)? { + Value::Immediate(op) => { + let op = op.clone(); + self.use_ecx(|this| this.ecx.read_discriminant(&op))? + } + Value::Aggregate { variant, .. } => *variant, + Value::Uninit => return None, + }; + let imm = self.use_ecx(|this| { + this.ecx.discriminant_for_variant( + place.ty(this.local_decls(), this.tcx).ty, + variant, + ) + })?; + imm.into() + } + }; + trace!(?val); + + *self.access_mut(dest)? = val; + + Some(()) + } +} + +impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { + fn visit_body(&mut self, body: &Body<'tcx>) { + while let Some(bb) = self.worklist.pop() { + if !self.visited_blocks.insert(bb) { + continue; + } + + let data = &body.basic_blocks[bb]; + self.visit_basic_block_data(bb, data); + } + } + + fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { + self.super_operand(operand, location); + } + + fn visit_const_operand(&mut self, constant: &ConstOperand<'tcx>, location: Location) { + trace!("visit_const_operand: {:?}", constant); + self.super_const_operand(constant, location); + self.eval_constant(constant); + } + + fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) { + self.super_assign(place, rvalue, location); + + let Some(()) = self.check_rvalue(rvalue, location) else { return }; + + match self.can_const_prop[place.local] { + // Do nothing if the place is indirect. + _ if place.is_indirect() => {} + ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), + ConstPropMode::OnlyInsideOwnBlock | ConstPropMode::FullConstProp => { + if self.eval_rvalue(rvalue, place).is_none() { + // Const prop failed, so erase the destination, ensuring that whatever happens + // from here on, does not know about the previous value. + // This is important in case we have + // ```rust + // let mut x = 42; + // x = SOME_MUTABLE_STATIC; + // // x must now be uninit + // ``` + // FIXME: we overzealously erase the entire local, because that's easier to + // implement. + trace!( + "propagation into {:?} failed. 
+ Nuking the entire site from orbit, it's the only way to be sure", + place, + ); + self.remove_const(place.local); + } + } + } + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + trace!("visit_statement: {:?}", statement); + + // We want to evaluate operands before any change to the assigned-to value, + // so we recurse first. + self.super_statement(statement, location); + + match statement.kind { + StatementKind::SetDiscriminant { ref place, variant_index } => { + match self.can_const_prop[place.local] { + // Do nothing if the place is indirect. + _ if place.is_indirect() => {} + ConstPropMode::NoPropagation => self.ensure_not_propagated(place.local), + ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { + match self.access_mut(place) { + Some(Value::Aggregate { variant, .. }) => *variant = variant_index, + _ => self.remove_const(place.local), + } + } + } + } + StatementKind::StorageLive(local) => { + self.remove_const(local); + } + StatementKind::StorageDead(local) => { + self.remove_const(local); + } + _ => {} + } + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + self.super_terminator(terminator, location); + match &terminator.kind { + TerminatorKind::Assert { expected, msg, cond, .. } => { + self.check_assertion(*expected, msg, cond, location); + } + TerminatorKind::SwitchInt { discr, targets } => { + if let Some(ref value) = self.eval_operand(discr) + && let Some(value_const) = self.use_ecx(|this| this.ecx.read_scalar(value)) + && let Some(constant) = value_const.to_bits(value_const.size()).discard_err() + { + // We managed to evaluate the discriminant, so we know we only need to visit + // one target. + let target = targets.target_for_value(constant); + self.worklist.push(target); + return; + } + // We failed to evaluate the discriminant, fallback to visiting all successors. + } + // None of these have Operands to const-propagate. + TerminatorKind::Goto { .. } + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::TailCall { .. } + | TerminatorKind::Unreachable + | TerminatorKind::Drop { .. } + | TerminatorKind::Yield { .. } + | TerminatorKind::CoroutineDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::Call { .. } + | TerminatorKind::InlineAsm { .. } => {} + } + + self.worklist.extend(terminator.successors()); + } + + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) { + self.super_basic_block_data(block, data); + + // We remove all Locals which are restricted in propagation to their containing blocks and + // which were modified in the current block. + // Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`. + let mut written_only_inside_own_block_locals = + std::mem::take(&mut self.written_only_inside_own_block_locals); + + // This loop can get very hot for some bodies: it check each local in each bb. + // To avoid this quadratic behaviour, we only clear the locals that were modified inside + // the current block. + // The order in which we remove consts does not matter. 
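+        // Added example (hedged): a local assigned in two different blocks is demoted to
+        // `OnlyInsideOwnBlock` by `CanConstProp`; its value is trusted while we stay in the
+        // assigning block and is forgotten here as soon as that block ends.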
+ #[allow(rustc::potential_query_instability)] + for local in written_only_inside_own_block_locals.drain() { + debug_assert_eq!(self.can_const_prop[local], ConstPropMode::OnlyInsideOwnBlock); + self.remove_const(local); + } + self.written_only_inside_own_block_locals = written_only_inside_own_block_locals; + + if cfg!(debug_assertions) { + for (local, &mode) in self.can_const_prop.iter_enumerated() { + match mode { + ConstPropMode::FullConstProp => {} + ConstPropMode::NoPropagation | ConstPropMode::OnlyInsideOwnBlock => { + self.ensure_not_propagated(local); + } + } + } + } + } +} + +/// The maximum number of bytes that we'll allocate space for a local or the return value. +/// Needed for #66397, because otherwise we eval into large places and that can cause OOM or just +/// Severely regress performance. +const MAX_ALLOC_LIMIT: u64 = 1024; + +/// The mode that `ConstProp` is allowed to run in for a given `Local`. +#[derive(Clone, Copy, Debug, PartialEq)] +enum ConstPropMode { + /// The `Local` can be propagated into and reads of this `Local` can also be propagated. + FullConstProp, + /// The `Local` can only be propagated into and from its own block. + OnlyInsideOwnBlock, + /// The `Local` cannot be part of propagation at all. Any statement + /// referencing it either for reading or writing will not get propagated. + NoPropagation, +} + +/// A visitor that determines locals in a MIR body +/// that can be const propagated +struct CanConstProp { + can_const_prop: IndexVec<Local, ConstPropMode>, + // False at the beginning. Once set, no more assignments are allowed to that local. + found_assignment: DenseBitSet<Local>, +} + +impl CanConstProp { + /// Returns true if `local` can be propagated + fn check<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + body: &Body<'tcx>, + ) -> IndexVec<Local, ConstPropMode> { + let mut cpv = CanConstProp { + can_const_prop: IndexVec::from_elem(ConstPropMode::FullConstProp, &body.local_decls), + found_assignment: DenseBitSet::new_empty(body.local_decls.len()), + }; + for (local, val) in cpv.can_const_prop.iter_enumerated_mut() { + let ty = body.local_decls[local].ty; + if ty.is_async_drop_in_place_coroutine(tcx) { + // No const propagation for async drop coroutine (AsyncDropGlue). + // Otherwise, tcx.layout_of(typing_env.as_query_input(ty)) will be called + // (early layout request for async drop coroutine) to calculate layout size. + // Layout for `async_drop_in_place<T>::{closure}` may only be known with known T. + *val = ConstPropMode::NoPropagation; + continue; + } else if ty.is_union() { + // Unions are incompatible with the current implementation of + // const prop because Rust has no concept of an active + // variant of a union + *val = ConstPropMode::NoPropagation; + } else { + match tcx.layout_of(typing_env.as_query_input(ty)) { + Ok(layout) if layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) => {} + // Either the layout fails to compute, then we can't use this local anyway + // or the local is too large, then we don't want to. + _ => { + *val = ConstPropMode::NoPropagation; + continue; + } + } + } + } + // Consider that arguments are assigned on entry. + for arg in body.args_iter() { + cpv.found_assignment.insert(arg); + } + cpv.visit_body(body); + cpv.can_const_prop + } +} + +impl<'tcx> Visitor<'tcx> for CanConstProp { + fn visit_place(&mut self, place: &Place<'tcx>, mut context: PlaceContext, loc: Location) { + use rustc_middle::mir::visit::PlaceContext::*; + + // Dereferencing just read the address of `place.local`. 
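+        // Added example: for `(*_1).0 = 5`, only the pointer stored in `_1` is read, so the
+        // use of `_1` is downgraded to a non-mutating copy below.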
+ if place.projection.first() == Some(&PlaceElem::Deref) { + context = NonMutatingUse(NonMutatingUseContext::Copy); + } + + self.visit_local(place.local, context, loc); + self.visit_projection(place.as_ref(), context, loc); + } + + fn visit_local(&mut self, local: Local, context: PlaceContext, _: Location) { + use rustc_middle::mir::visit::PlaceContext::*; + match context { + // These are just stores, where the storing is not propagatable, but there may be later + // mutations of the same local via `Store` + | MutatingUse(MutatingUseContext::Call) + | MutatingUse(MutatingUseContext::AsmOutput) + | MutatingUse(MutatingUseContext::Deinit) + // Actual store that can possibly even propagate a value + | MutatingUse(MutatingUseContext::Store) + | MutatingUse(MutatingUseContext::SetDiscriminant) => { + if !self.found_assignment.insert(local) { + match &mut self.can_const_prop[local] { + // If the local can only get propagated in its own block, then we don't have + // to worry about multiple assignments, as we'll nuke the const state at the + // end of the block anyway, and inside the block we overwrite previous + // states as applicable. + ConstPropMode::OnlyInsideOwnBlock => {} + ConstPropMode::NoPropagation => {} + other @ ConstPropMode::FullConstProp => { + trace!( + "local {:?} can't be propagated because of multiple assignments. Previous state: {:?}", + local, other, + ); + *other = ConstPropMode::OnlyInsideOwnBlock; + } + } + } + } + // Reading constants is allowed an arbitrary number of times + NonMutatingUse(NonMutatingUseContext::Copy) + | NonMutatingUse(NonMutatingUseContext::Move) + | NonMutatingUse(NonMutatingUseContext::Inspect) + | NonMutatingUse(NonMutatingUseContext::PlaceMention) + | NonUse(_) => {} + + // These could be propagated with a smarter analysis or just some careful thinking about + // whether they'd be fine right now. + MutatingUse(MutatingUseContext::Yield) + | MutatingUse(MutatingUseContext::Drop) + | MutatingUse(MutatingUseContext::Retag) + // These can't ever be propagated under any scheme, as we can't reason about indirect + // mutation. + | NonMutatingUse(NonMutatingUseContext::SharedBorrow) + | NonMutatingUse(NonMutatingUseContext::FakeBorrow) + | NonMutatingUse(NonMutatingUseContext::RawBorrow) + | MutatingUse(MutatingUseContext::Borrow) + | MutatingUse(MutatingUseContext::RawBorrow) => { + trace!("local {:?} can't be propagated because it's used: {:?}", local, context); + self.can_const_prop[local] = ConstPropMode::NoPropagation; + } + MutatingUse(MutatingUseContext::Projection) + | NonMutatingUse(NonMutatingUseContext::Projection) => bug!("visit_place should not pass {context:?} for {local:?}"), + } + } +} diff --git a/compiler/rustc_mir_transform/src/large_enums.rs b/compiler/rustc_mir_transform/src/large_enums.rs new file mode 100644 index 00000000000..1a91d6bd7da --- /dev/null +++ b/compiler/rustc_mir_transform/src/large_enums.rs @@ -0,0 +1,249 @@ +use rustc_abi::{HasDataLayout, Size, TagEncoding, Variants}; +use rustc_data_structures::fx::FxHashMap; +use rustc_middle::mir::interpret::AllocId; +use rustc_middle::mir::*; +use rustc_middle::ty::util::IntTypeExt; +use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt}; +use rustc_session::Session; + +use crate::patch::MirPatch; + +/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large +/// enough discrepancy between them. +/// +/// i.e. 
If there are two variants: +/// ``` +/// enum Example { +/// Small, +/// Large([u32; 1024]), +/// } +/// ``` +/// Instead of emitting moves of the large variant, perform a memcpy instead. +/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD). +/// +/// In summary, what this does is at runtime determine which enum variant is active, +/// and instead of copying all the bytes of the largest possible variant, +/// copy only the bytes for the currently active variant. +pub(super) struct EnumSizeOpt { + pub(crate) discrepancy: u64, +} + +impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt { + fn is_enabled(&self, sess: &Session) -> bool { + // There are some differences in behavior on wasm and ARM that are not properly + // understood, so we conservatively treat this optimization as unsound: + // https://github.com/rust-lang/rust/pull/85158#issuecomment-1101836457 + sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // NOTE: This pass may produce different MIR based on the alignment of the target + // platform, but it will still be valid. + + let mut alloc_cache = FxHashMap::default(); + let typing_env = body.typing_env(tcx); + + let mut patch = MirPatch::new(body); + + for (block, data) in body.basic_blocks.as_mut().iter_enumerated_mut() { + for (statement_index, st) in data.statements.iter_mut().enumerate() { + let StatementKind::Assign(box ( + lhs, + Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), + )) = &st.kind + else { + continue; + }; + + let location = Location { block, statement_index }; + + let ty = lhs.ty(&body.local_decls, tcx).ty; + + let Some((adt_def, num_variants, alloc_id)) = + self.candidate(tcx, typing_env, ty, &mut alloc_cache) + else { + continue; + }; + + let span = st.source_info.span; + + let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64); + let size_array_local = patch.new_temp(tmp_ty, span); + + let store_live = StatementKind::StorageLive(size_array_local); + + let place = Place::from(size_array_local); + let constant_vals = ConstOperand { + span, + user_ty: None, + const_: Const::Val( + ConstValue::Indirect { alloc_id, offset: Size::ZERO }, + tmp_ty, + ), + }; + let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals))); + let const_assign = StatementKind::Assign(Box::new((place, rval))); + + let discr_place = + Place::from(patch.new_temp(adt_def.repr().discr_type().to_ty(tcx), span)); + let store_discr = + StatementKind::Assign(Box::new((discr_place, Rvalue::Discriminant(*rhs)))); + + let discr_cast_place = Place::from(patch.new_temp(tcx.types.usize, span)); + let cast_discr = StatementKind::Assign(Box::new(( + discr_cast_place, + Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_place), tcx.types.usize), + ))); + + let size_place = Place::from(patch.new_temp(tcx.types.usize, span)); + let store_size = StatementKind::Assign(Box::new(( + size_place, + Rvalue::Use(Operand::Copy(Place { + local: size_array_local, + projection: tcx.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]), + })), + ))); + + let dst = Place::from(patch.new_temp(Ty::new_mut_ptr(tcx, ty), span)); + let dst_ptr = + StatementKind::Assign(Box::new((dst, Rvalue::RawPtr(RawPtrKind::Mut, *lhs)))); + + let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8); + let dst_cast_place = Place::from(patch.new_temp(dst_cast_ty, span)); + let dst_cast = StatementKind::Assign(Box::new(( + dst_cast_place, + Rvalue::Cast(CastKind::PtrToPtr, 
Operand::Copy(dst), dst_cast_ty), + ))); + + let src = Place::from(patch.new_temp(Ty::new_imm_ptr(tcx, ty), span)); + let src_ptr = + StatementKind::Assign(Box::new((src, Rvalue::RawPtr(RawPtrKind::Const, *rhs)))); + + let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8); + let src_cast_place = Place::from(patch.new_temp(src_cast_ty, span)); + let src_cast = StatementKind::Assign(Box::new(( + src_cast_place, + Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty), + ))); + + let deinit_old = StatementKind::Deinit(Box::new(dst)); + + let copy_bytes = StatementKind::Intrinsic(Box::new( + NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping { + src: Operand::Copy(src_cast_place), + dst: Operand::Copy(dst_cast_place), + count: Operand::Copy(size_place), + }), + )); + + let store_dead = StatementKind::StorageDead(size_array_local); + + let stmts = [ + store_live, + const_assign, + store_discr, + cast_discr, + store_size, + dst_ptr, + dst_cast, + src_ptr, + src_cast, + deinit_old, + copy_bytes, + store_dead, + ]; + for stmt in stmts { + patch.add_statement(location, stmt); + } + + st.make_nop(); + } + } + + patch.apply(body); + } + + fn is_required(&self) -> bool { + false + } +} + +impl EnumSizeOpt { + fn candidate<'tcx>( + &self, + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + ty: Ty<'tcx>, + alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>, + ) -> Option<(AdtDef<'tcx>, usize, AllocId)> { + let adt_def = match ty.kind() { + ty::Adt(adt_def, _args) if adt_def.is_enum() => adt_def, + _ => return None, + }; + let layout = tcx.layout_of(typing_env.as_query_input(ty)).ok()?; + let variants = match &layout.variants { + Variants::Single { .. } | Variants::Empty => return None, + Variants::Multiple { tag_encoding: TagEncoding::Niche { .. }, .. } => return None, + + Variants::Multiple { variants, .. } if variants.len() <= 1 => return None, + Variants::Multiple { variants, .. } => variants, + }; + let min = variants.iter().map(|v| v.size).min().unwrap(); + let max = variants.iter().map(|v| v.size).max().unwrap(); + if max.bytes() - min.bytes() < self.discrepancy { + return None; + } + + let num_discrs = adt_def.discriminants(tcx).count(); + if variants.iter_enumerated().any(|(var_idx, _)| { + let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val; + (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs) + }) { + return None; + } + if let Some(alloc_id) = alloc_cache.get(&ty) { + return Some((*adt_def, num_discrs, *alloc_id)); + } + + let data_layout = tcx.data_layout(); + let ptr_sized_int = data_layout.ptr_sized_integer(); + let target_bytes = ptr_sized_int.size().bytes() as usize; + let mut data = vec![0; target_bytes * num_discrs]; + + // We use a macro because `$bytes` can be u32 or u64. + macro_rules! 
encode_store { + ($curr_idx: expr, $endian: expr, $bytes: expr) => { + let bytes = match $endian { + rustc_abi::Endian::Little => $bytes.to_le_bytes(), + rustc_abi::Endian::Big => $bytes.to_be_bytes(), + }; + for (i, b) in bytes.into_iter().enumerate() { + data[$curr_idx + i] = b; + } + }; + } + + for (var_idx, layout) in variants.iter_enumerated() { + let curr_idx = + target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize; + let sz = layout.size; + match ptr_sized_int { + rustc_abi::Integer::I32 => { + encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32); + } + rustc_abi::Integer::I64 => { + encode_store!(curr_idx, data_layout.endian, sz.bytes()); + } + _ => unreachable!(), + }; + } + let alloc = interpret::Allocation::from_bytes( + data, + tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi, + Mutability::Not, + (), + ); + let alloc = tcx.reserve_and_set_memory_alloc(tcx.mk_const_alloc(alloc)); + Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc))) + } +} diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs new file mode 100644 index 00000000000..6b32254b051 --- /dev/null +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -0,0 +1,804 @@ +// tidy-alphabetical-start +#![feature(array_windows)] +#![feature(assert_matches)] +#![feature(box_patterns)] +#![feature(const_type_name)] +#![feature(cow_is_borrowed)] +#![feature(file_buffered)] +#![feature(if_let_guard)] +#![feature(impl_trait_in_assoc_type)] +#![feature(try_blocks)] +#![feature(yeet_expr)] +// tidy-alphabetical-end + +use hir::ConstContext; +use required_consts::RequiredConstsVisitor; +use rustc_const_eval::check_consts::{self, ConstCx}; +use rustc_const_eval::util; +use rustc_data_structures::fx::FxIndexSet; +use rustc_data_structures::steal::Steal; +use rustc_hir as hir; +use rustc_hir::def::{CtorKind, DefKind}; +use rustc_hir::def_id::LocalDefId; +use rustc_index::IndexVec; +use rustc_middle::mir::{ + AnalysisPhase, Body, CallSource, ClearCrossCrate, ConstOperand, ConstQualifs, LocalDecl, + MirPhase, Operand, Place, ProjectionElem, Promoted, RuntimePhase, Rvalue, START_BLOCK, + SourceInfo, Statement, StatementKind, TerminatorKind, +}; +use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt}; +use rustc_middle::util::Providers; +use rustc_middle::{bug, query, span_bug}; +use rustc_mir_build::builder::build_mir; +use rustc_span::source_map::Spanned; +use rustc_span::{DUMMY_SP, sym}; +use tracing::debug; + +#[macro_use] +mod pass_manager; + +use std::sync::LazyLock; + +use pass_manager::{self as pm, Lint, MirLint, MirPass, WithMinOptLevel}; + +mod check_pointers; +mod cost_checker; +mod cross_crate_inline; +mod deduce_param_attrs; +mod elaborate_drop; +mod errors; +mod ffi_unwind_calls; +mod lint; +mod lint_tail_expr_drop_order; +mod patch; +mod shim; +mod ssa; + +/// We import passes via this macro so that we can have a static list of pass names +/// (used to verify CLI arguments). It takes a list of modules, followed by the passes +/// declared within them. +/// ```ignore,macro-test +/// declare_passes! 
{ +/// // Declare a single pass from the module `abort_unwinding_calls` +/// mod abort_unwinding_calls : AbortUnwindingCalls; +/// // When passes are grouped together as an enum, declare the two constituent passes +/// mod add_call_guards : AddCallGuards { +/// AllCallEdges, +/// CriticalCallEdges +/// }; +/// // Declares multiple pass groups, each containing their own constituent passes +/// mod simplify : SimplifyCfg { +/// Initial, +/// /* omitted */ +/// }, SimplifyLocals { +/// BeforeConstProp, +/// /* omitted */ +/// }; +/// } +/// ``` +macro_rules! declare_passes { + ( + $( + $vis:vis mod $mod_name:ident : $($pass_name:ident $( { $($ident:ident),* } )?),+ $(,)?; + )* + ) => { + $( + $vis mod $mod_name; + $( + // Make sure the type name is correct + #[allow(unused_imports)] + use $mod_name::$pass_name as _; + )+ + )* + + static PASS_NAMES: LazyLock<FxIndexSet<&str>> = LazyLock::new(|| [ + // Fake marker pass + "PreCodegen", + $( + $( + stringify!($pass_name), + $( + $( + $mod_name::$pass_name::$ident.name(), + )* + )? + )+ + )* + ].into_iter().collect()); + }; +} + +declare_passes! { + mod abort_unwinding_calls : AbortUnwindingCalls; + mod add_call_guards : AddCallGuards { AllCallEdges, CriticalCallEdges }; + mod add_moves_for_packed_drops : AddMovesForPackedDrops; + mod add_retag : AddRetag; + mod add_subtyping_projections : Subtyper; + mod check_inline : CheckForceInline; + mod check_call_recursion : CheckCallRecursion, CheckDropRecursion; + mod check_alignment : CheckAlignment; + mod check_enums : CheckEnums; + mod check_const_item_mutation : CheckConstItemMutation; + mod check_null : CheckNull; + mod check_packed_ref : CheckPackedRef; + // This pass is public to allow external drivers to perform MIR cleanup + pub mod cleanup_post_borrowck : CleanupPostBorrowck; + + mod copy_prop : CopyProp; + mod coroutine : StateTransform; + mod coverage : InstrumentCoverage; + mod ctfe_limit : CtfeLimit; + mod dataflow_const_prop : DataflowConstProp; + mod dead_store_elimination : DeadStoreElimination { + Initial, + Final + }; + mod deref_separator : Derefer; + mod dest_prop : DestinationPropagation; + pub mod dump_mir : Marker; + mod early_otherwise_branch : EarlyOtherwiseBranch; + mod elaborate_box_derefs : ElaborateBoxDerefs; + mod elaborate_drops : ElaborateDrops; + mod function_item_references : FunctionItemReferences; + mod gvn : GVN; + // Made public so that `mir_drops_elaborated_and_const_checked` can be overridden + // by custom rustc drivers, running all the steps by themselves. See #114628. 
+ pub mod inline : Inline, ForceInline; + mod impossible_predicates : ImpossiblePredicates; + mod instsimplify : InstSimplify { BeforeInline, AfterSimplifyCfg }; + mod jump_threading : JumpThreading; + mod known_panics_lint : KnownPanicsLint; + mod large_enums : EnumSizeOpt; + mod lower_intrinsics : LowerIntrinsics; + mod lower_slice_len : LowerSliceLenCalls; + mod match_branches : MatchBranchSimplification; + mod mentioned_items : MentionedItems; + mod multiple_return_terminators : MultipleReturnTerminators; + mod nrvo : RenameReturnPlace; + mod post_drop_elaboration : CheckLiveDrops; + mod prettify : ReorderBasicBlocks, ReorderLocals; + mod promote_consts : PromoteTemps; + mod ref_prop : ReferencePropagation; + mod remove_noop_landing_pads : RemoveNoopLandingPads; + mod remove_place_mention : RemovePlaceMention; + mod remove_storage_markers : RemoveStorageMarkers; + mod remove_uninit_drops : RemoveUninitDrops; + mod remove_unneeded_drops : RemoveUnneededDrops; + mod remove_zsts : RemoveZsts; + mod required_consts : RequiredConstsVisitor; + mod post_analysis_normalize : PostAnalysisNormalize; + mod sanity_check : SanityCheck; + // This pass is public to allow external drivers to perform MIR cleanup + pub mod simplify : + SimplifyCfg { + Initial, + PromoteConsts, + RemoveFalseEdges, + PostAnalysis, + PreOptimizations, + Final, + MakeShim, + AfterUnreachableEnumBranching + }, + SimplifyLocals { + BeforeConstProp, + AfterGVN, + Final + }; + mod simplify_branches : SimplifyConstCondition { + AfterConstProp, + Final + }; + mod simplify_comparison_integral : SimplifyComparisonIntegral; + mod single_use_consts : SingleUseConsts; + mod sroa : ScalarReplacementOfAggregates; + mod strip_debuginfo : StripDebugInfo; + mod unreachable_enum_branching : UnreachableEnumBranching; + mod unreachable_prop : UnreachablePropagation; + mod validate : Validator; +} + +rustc_fluent_macro::fluent_messages! { "../messages.ftl" } + +pub fn provide(providers: &mut Providers) { + coverage::query::provide(providers); + ffi_unwind_calls::provide(providers); + shim::provide(providers); + cross_crate_inline::provide(providers); + providers.queries = query::Providers { + mir_keys, + mir_built, + mir_const_qualif, + mir_promoted, + mir_drops_elaborated_and_const_checked, + mir_for_ctfe, + mir_coroutine_witnesses: coroutine::mir_coroutine_witnesses, + optimized_mir, + is_mir_available, + is_ctfe_mir_available: is_mir_available, + mir_callgraph_reachable: inline::cycle::mir_callgraph_reachable, + mir_inliner_callees: inline::cycle::mir_inliner_callees, + promoted_mir, + deduced_param_attrs: deduce_param_attrs::deduced_param_attrs, + coroutine_by_move_body_def_id: coroutine::coroutine_by_move_body_def_id, + ..providers.queries + }; +} + +fn remap_mir_for_const_eval_select<'tcx>( + tcx: TyCtxt<'tcx>, + mut body: Body<'tcx>, + context: hir::Constness, +) -> Body<'tcx> { + for bb in body.basic_blocks.as_mut().iter_mut() { + let terminator = bb.terminator.as_mut().expect("invalid terminator"); + match terminator.kind { + TerminatorKind::Call { + func: Operand::Constant(box ConstOperand { ref const_, .. }), + ref mut args, + destination, + target, + unwind, + fn_span, + .. 
+ } if let ty::FnDef(def_id, _) = *const_.ty().kind() + && tcx.is_intrinsic(def_id, sym::const_eval_select) => + { + let Ok([tupled_args, called_in_const, called_at_rt]) = take_array(args) else { + unreachable!() + }; + let ty = tupled_args.node.ty(&body.local_decls, tcx); + let fields = ty.tuple_fields(); + let num_args = fields.len(); + let func = + if context == hir::Constness::Const { called_in_const } else { called_at_rt }; + let (method, place): (fn(Place<'tcx>) -> Operand<'tcx>, Place<'tcx>) = + match tupled_args.node { + Operand::Constant(_) => { + // There is no good way of extracting a tuple arg from a constant + // (const generic stuff) so we just create a temporary and deconstruct + // that. + let local = body.local_decls.push(LocalDecl::new(ty, fn_span)); + bb.statements.push(Statement { + source_info: SourceInfo::outermost(fn_span), + kind: StatementKind::Assign(Box::new(( + local.into(), + Rvalue::Use(tupled_args.node.clone()), + ))), + }); + (Operand::Move, local.into()) + } + Operand::Move(place) => (Operand::Move, place), + Operand::Copy(place) => (Operand::Copy, place), + }; + let place_elems = place.projection; + let arguments = (0..num_args) + .map(|x| { + let mut place_elems = place_elems.to_vec(); + place_elems.push(ProjectionElem::Field(x.into(), fields[x])); + let projection = tcx.mk_place_elems(&place_elems); + let place = Place { local: place.local, projection }; + Spanned { node: method(place), span: DUMMY_SP } + }) + .collect(); + terminator.kind = TerminatorKind::Call { + func: func.node, + args: arguments, + destination, + target, + unwind, + call_source: CallSource::Misc, + fn_span, + }; + } + _ => {} + } + } + body +} + +fn take_array<T, const N: usize>(b: &mut Box<[T]>) -> Result<[T; N], Box<[T]>> { + let b: Box<[T; N]> = std::mem::take(b).try_into()?; + Ok(*b) +} + +fn is_mir_available(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { + tcx.mir_keys(()).contains(&def_id) +} + +/// Finds the full set of `DefId`s within the current crate that have +/// MIR associated with them. +fn mir_keys(tcx: TyCtxt<'_>, (): ()) -> FxIndexSet<LocalDefId> { + // All body-owners have MIR associated with them. + let mut set: FxIndexSet<_> = tcx.hir_body_owners().collect(); + + // Remove the fake bodies for `global_asm!`, since they're not useful + // to be emitted (`--emit=mir`) or encoded (in metadata). + set.retain(|&def_id| !matches!(tcx.def_kind(def_id), DefKind::GlobalAsm)); + + // Coroutine-closures (e.g. async closures) have an additional by-move MIR + // body that isn't in the HIR. + for body_owner in tcx.hir_body_owners() { + if let DefKind::Closure = tcx.def_kind(body_owner) + && tcx.needs_coroutine_by_move_body_def_id(body_owner.to_def_id()) + { + set.insert(tcx.coroutine_by_move_body_def_id(body_owner).expect_local()); + } + } + + // tuple struct/variant constructors have MIR, but they don't have a BodyId, + // so we need to build them separately. + for item in tcx.hir_crate_items(()).free_items() { + if let DefKind::Struct | DefKind::Enum = tcx.def_kind(item.owner_id) { + for variant in tcx.adt_def(item.owner_id).variants() { + if let Some((CtorKind::Fn, ctor_def_id)) = variant.ctor { + set.insert(ctor_def_id.expect_local()); + } + } + } + } + + set +} + +fn mir_const_qualif(tcx: TyCtxt<'_>, def: LocalDefId) -> ConstQualifs { + // N.B., this `borrow()` is guaranteed to be valid (i.e., the value + // cannot yet be stolen), because `mir_promoted()`, which steals + // from `mir_built()`, forces this query to execute before + // performing the steal. 
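The `take_array` helper defined above is small enough to exercise on its own. The sketch below copies its body verbatim and adds a hypothetical call site to show what the `Ok`/`Err` split means for the argument box; everything in `main` is invented for illustration, and the body of `mir_const_qualif` continues right after it.

```rust
// Same body as the helper in this file; `main` is a hypothetical call site.
fn take_array<T, const N: usize>(b: &mut Box<[T]>) -> Result<[T; N], Box<[T]>> {
    let b: Box<[T; N]> = std::mem::take(b).try_into()?;
    Ok(*b)
}

fn main() {
    let mut args: Box<[i32]> = vec![1, 2, 3].into_boxed_slice();
    // The requested arity matches, so we get the array by value...
    let [a, b, c] = take_array::<i32, 3>(&mut args).unwrap();
    assert_eq!((a, b, c), (1, 2, 3));
    // ...and the original box has been replaced with an empty one by `mem::take`.
    assert!(args.is_empty());

    // On an arity mismatch, the original elements come back in the `Err` variant.
    let mut args: Box<[i32]> = vec![1, 2].into_boxed_slice();
    assert!(take_array::<i32, 3>(&mut args).is_err());
}
```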
+ let body = &tcx.mir_built(def).borrow(); + let ccx = check_consts::ConstCx::new(tcx, body); + // No need to const-check a non-const `fn`. + match ccx.const_kind { + Some(ConstContext::Const { .. } | ConstContext::Static(_) | ConstContext::ConstFn) => {} + None => span_bug!( + tcx.def_span(def), + "`mir_const_qualif` should only be called on const fns and const items" + ), + } + + if body.return_ty().references_error() { + // It's possible to reach here without an error being emitted (#121103). + tcx.dcx().span_delayed_bug(body.span, "mir_const_qualif: MIR had errors"); + return Default::default(); + } + + let mut validator = check_consts::check::Checker::new(&ccx); + validator.check_body(); + + // We return the qualifs in the return place for every MIR body, even though it is only used + // when deciding to promote a reference to a `const` for now. + validator.qualifs_in_return_place() +} + +fn mir_built(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> { + let mut body = build_mir(tcx, def); + + pass_manager::dump_mir_for_phase_change(tcx, &body); + + pm::run_passes( + tcx, + &mut body, + &[ + // MIR-level lints. + &Lint(check_inline::CheckForceInline), + &Lint(check_call_recursion::CheckCallRecursion), + &Lint(check_packed_ref::CheckPackedRef), + &Lint(check_const_item_mutation::CheckConstItemMutation), + &Lint(function_item_references::FunctionItemReferences), + // What we need to do constant evaluation. + &simplify::SimplifyCfg::Initial, + &Lint(sanity_check::SanityCheck), + ], + None, + pm::Optimizations::Allowed, + ); + tcx.alloc_steal_mir(body) +} + +/// Compute the main MIR body and the list of MIR bodies of the promoteds. +fn mir_promoted( + tcx: TyCtxt<'_>, + def: LocalDefId, +) -> (&Steal<Body<'_>>, &Steal<IndexVec<Promoted, Body<'_>>>) { + // Ensure that we compute the `mir_const_qualif` for constants at + // this point, before we steal the mir-const result. + // Also this means promotion can rely on all const checks having been done. + + let const_qualifs = match tcx.def_kind(def) { + DefKind::Fn | DefKind::AssocFn | DefKind::Closure + if tcx.constness(def) == hir::Constness::Const + || tcx.is_const_default_method(def.to_def_id()) => + { + tcx.mir_const_qualif(def) + } + DefKind::AssocConst + | DefKind::Const + | DefKind::Static { .. } + | DefKind::InlineConst + | DefKind::AnonConst => tcx.mir_const_qualif(def), + _ => ConstQualifs::default(), + }; + + // the `has_ffi_unwind_calls` query uses the raw mir, so make sure it is run. + tcx.ensure_done().has_ffi_unwind_calls(def); + + // the `by_move_body` query uses the raw mir, so make sure it is run. + if tcx.needs_coroutine_by_move_body_def_id(def.to_def_id()) { + tcx.ensure_done().coroutine_by_move_body_def_id(def); + } + + let mut body = tcx.mir_built(def).steal(); + if let Some(error_reported) = const_qualifs.tainted_by_errors { + body.tainted_by_errors = Some(error_reported); + } + + // Collect `required_consts` *before* promotion, so if there are any consts being promoted + // we still add them to the list in the outer MIR body. + RequiredConstsVisitor::compute_required_consts(&mut body); + + // What we need to run borrowck etc. 
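The borrow-before-steal ordering that `mir_built`, `mir_const_qualif`, and `mir_promoted` rely on is easier to see with a simplified stand-in for `Steal`. This is a minimal sketch, not rustc's real type, which differs in detail; the promotion pipeline itself continues right after this block.

```rust
use std::cell::{Ref, RefCell};

/// Simplified stand-in: read any number of times, but only one caller may take the value.
struct Steal<T>(RefCell<Option<T>>);

impl<T> Steal<T> {
    fn new(value: T) -> Self {
        Steal(RefCell::new(Some(value)))
    }
    fn borrow(&self) -> Ref<'_, T> {
        Ref::map(self.0.borrow(), |v| v.as_ref().expect("value already stolen"))
    }
    fn steal(&self) -> T {
        self.0.borrow_mut().take().expect("value already stolen")
    }
}

fn main() {
    let mir_built = Steal::new(String::from("built MIR"));
    // Readers such as `mir_const_qualif` must run first ...
    assert_eq!(&*mir_built.borrow(), "built MIR");
    // ... because the stealing consumer (`mir_promoted` in the real pipeline) runs once
    // and leaves nothing behind: any later `borrow()` or `steal()` would panic.
    let promoted_input = mir_built.steal();
    assert_eq!(promoted_input, "built MIR");
}
```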
+ let promote_pass = promote_consts::PromoteTemps::default(); + pm::run_passes( + tcx, + &mut body, + &[&promote_pass, &simplify::SimplifyCfg::PromoteConsts, &coverage::InstrumentCoverage], + Some(MirPhase::Analysis(AnalysisPhase::Initial)), + pm::Optimizations::Allowed, + ); + + lint_tail_expr_drop_order::run_lint(tcx, def, &body); + + let promoted = promote_pass.promoted_fragments.into_inner(); + (tcx.alloc_steal_mir(body), tcx.alloc_steal_promoted(promoted)) +} + +/// Compute the MIR that is used during CTFE (and thus has no optimizations run on it) +fn mir_for_ctfe(tcx: TyCtxt<'_>, def_id: LocalDefId) -> &Body<'_> { + tcx.arena.alloc(inner_mir_for_ctfe(tcx, def_id)) +} + +fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> { + // FIXME: don't duplicate this between the optimized_mir/mir_for_ctfe queries + if tcx.is_constructor(def.to_def_id()) { + // There's no reason to run all of the MIR passes on constructors when + // we can just output the MIR we want directly. This also saves const + // qualification and borrow checking the trouble of special casing + // constructors. + return shim::build_adt_ctor(tcx, def.to_def_id()); + } + + let body = tcx.mir_drops_elaborated_and_const_checked(def); + let body = match tcx.hir_body_const_context(def) { + // consts and statics do not have `optimized_mir`, so we can steal the body instead of + // cloning it. + Some(hir::ConstContext::Const { .. } | hir::ConstContext::Static(_)) => body.steal(), + Some(hir::ConstContext::ConstFn) => body.borrow().clone(), + None => bug!("`mir_for_ctfe` called on non-const {def:?}"), + }; + + let mut body = remap_mir_for_const_eval_select(tcx, body, hir::Constness::Const); + pm::run_passes(tcx, &mut body, &[&ctfe_limit::CtfeLimit], None, pm::Optimizations::Allowed); + + body +} + +/// Obtain just the main MIR (no promoteds) and run some cleanups on it. This also runs +/// mir borrowck *before* doing so in order to ensure that borrowck can be run and doesn't +/// end up missing the source MIR due to stealing happening. +fn mir_drops_elaborated_and_const_checked(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> { + if tcx.is_coroutine(def.to_def_id()) { + tcx.ensure_done().mir_coroutine_witnesses(def); + } + + // We only need to borrowck non-synthetic MIR. + let tainted_by_errors = if !tcx.is_synthetic_mir(def) { + tcx.mir_borrowck(tcx.typeck_root_def_id(def.to_def_id()).expect_local()).err() + } else { + None + }; + + let is_fn_like = tcx.def_kind(def).is_fn_like(); + if is_fn_like { + // Do not compute the mir call graph without said call graph actually being used. + if pm::should_run_pass(tcx, &inline::Inline, pm::Optimizations::Allowed) + || inline::ForceInline::should_run_pass_for_callee(tcx, def.to_def_id()) + { + tcx.ensure_done().mir_inliner_callees(ty::InstanceKind::Item(def.to_def_id())); + } + } + + let (body, _) = tcx.mir_promoted(def); + let mut body = body.steal(); + + if let Some(error_reported) = tainted_by_errors { + body.tainted_by_errors = Some(error_reported); + } + + // Also taint the body if it's within a top-level item that is not well formed. + // + // We do this check here and not during `mir_promoted` because that may result + // in borrowck cycles if WF requires looking into an opaque hidden type. + let root = tcx.typeck_root_def_id(def.to_def_id()); + match tcx.def_kind(root) { + DefKind::Fn + | DefKind::AssocFn + | DefKind::Static { .. 
} + | DefKind::Const + | DefKind::AssocConst => { + if let Err(guar) = tcx.ensure_ok().check_well_formed(root.expect_local()) { + body.tainted_by_errors = Some(guar); + } + } + _ => {} + } + + run_analysis_to_runtime_passes(tcx, &mut body); + + tcx.alloc_steal_mir(body) +} + +// Made public so that `mir_drops_elaborated_and_const_checked` can be overridden +// by custom rustc drivers, running all the steps by themselves. See #114628. +pub fn run_analysis_to_runtime_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + assert!(body.phase == MirPhase::Analysis(AnalysisPhase::Initial)); + let did = body.source.def_id(); + + debug!("analysis_mir_cleanup({:?})", did); + run_analysis_cleanup_passes(tcx, body); + assert!(body.phase == MirPhase::Analysis(AnalysisPhase::PostCleanup)); + + // Do a little drop elaboration before const-checking if `const_precise_live_drops` is enabled. + if check_consts::post_drop_elaboration::checking_enabled(&ConstCx::new(tcx, body)) { + pm::run_passes( + tcx, + body, + &[ + &remove_uninit_drops::RemoveUninitDrops, + &simplify::SimplifyCfg::RemoveFalseEdges, + &Lint(post_drop_elaboration::CheckLiveDrops), + ], + None, + pm::Optimizations::Allowed, + ); + } + + debug!("runtime_mir_lowering({:?})", did); + run_runtime_lowering_passes(tcx, body); + assert!(body.phase == MirPhase::Runtime(RuntimePhase::Initial)); + + debug!("runtime_mir_cleanup({:?})", did); + run_runtime_cleanup_passes(tcx, body); + assert!(body.phase == MirPhase::Runtime(RuntimePhase::PostCleanup)); +} + +// FIXME(JakobDegen): Can we make these lists of passes consts? + +/// After this series of passes, no lifetime analysis based on borrowing can be done. +fn run_analysis_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let passes: &[&dyn MirPass<'tcx>] = &[ + &impossible_predicates::ImpossiblePredicates, + &cleanup_post_borrowck::CleanupPostBorrowck, + &remove_noop_landing_pads::RemoveNoopLandingPads, + &simplify::SimplifyCfg::PostAnalysis, + &deref_separator::Derefer, + ]; + + pm::run_passes( + tcx, + body, + passes, + Some(MirPhase::Analysis(AnalysisPhase::PostCleanup)), + pm::Optimizations::Allowed, + ); +} + +/// Returns the sequence of passes that lowers analysis to runtime MIR. +fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let passes: &[&dyn MirPass<'tcx>] = &[ + // These next passes must be executed together. + &add_call_guards::CriticalCallEdges, + // Must be done before drop elaboration because we need to drop opaque types, too. + &post_analysis_normalize::PostAnalysisNormalize, + // Calling this after `PostAnalysisNormalize` ensures that we don't deal with opaque types. + &add_subtyping_projections::Subtyper, + &elaborate_drops::ElaborateDrops, + // Needs to happen after drop elaboration. + &Lint(check_call_recursion::CheckDropRecursion), + // This will remove extraneous landing pads which are no longer + // necessary as well as forcing any call in a non-unwinding + // function calling a possibly-unwinding function to abort the process. + &abort_unwinding_calls::AbortUnwindingCalls, + // AddMovesForPackedDrops needs to run after drop + // elaboration. + &add_moves_for_packed_drops::AddMovesForPackedDrops, + // `AddRetag` needs to run after `ElaborateDrops` but before `ElaborateBoxDerefs`. + // Otherwise it should run fairly late, but before optimizations begin. 
+ &add_retag::AddRetag, + &elaborate_box_derefs::ElaborateBoxDerefs, + &coroutine::StateTransform, + &Lint(known_panics_lint::KnownPanicsLint), + ]; + pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial))); +} + +/// Returns the sequence of passes that do the initial cleanup of runtime MIR. +fn run_runtime_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let passes: &[&dyn MirPass<'tcx>] = &[ + &lower_intrinsics::LowerIntrinsics, + &remove_place_mention::RemovePlaceMention, + &simplify::SimplifyCfg::PreOptimizations, + ]; + + pm::run_passes( + tcx, + body, + passes, + Some(MirPhase::Runtime(RuntimePhase::PostCleanup)), + pm::Optimizations::Allowed, + ); + + // Clear this by anticipation. Optimizations and runtime MIR have no reason to look + // into this information, which is meant for borrowck diagnostics. + for decl in &mut body.local_decls { + decl.local_info = ClearCrossCrate::Clear; + } +} + +pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + fn o1<T>(x: T) -> WithMinOptLevel<T> { + WithMinOptLevel(1, x) + } + + let def_id = body.source.def_id(); + let optimizations = if tcx.def_kind(def_id).has_codegen_attrs() + && tcx.codegen_fn_attrs(def_id).optimize.do_not_optimize() + { + pm::Optimizations::Suppressed + } else { + pm::Optimizations::Allowed + }; + + // The main optimizations that we do on MIR. + pm::run_passes( + tcx, + body, + &[ + // Add some UB checks before any UB gets optimized away. + &check_alignment::CheckAlignment, + &check_null::CheckNull, + &check_enums::CheckEnums, + // Before inlining: trim down MIR with passes to reduce inlining work. + + // Has to be done before inlining, otherwise actual call will be almost always inlined. + // Also simple, so can just do first. + &lower_slice_len::LowerSliceLenCalls, + // Perform instsimplify before inline to eliminate some trivial calls (like clone + // shims). + &instsimplify::InstSimplify::BeforeInline, + // Perform inlining of `#[rustc_force_inline]`-annotated callees. + &inline::ForceInline, + // Perform inlining, which may add a lot of code. + &inline::Inline, + // Code from other crates may have storage markers, so this needs to happen after + // inlining. + &remove_storage_markers::RemoveStorageMarkers, + // Inlining and instantiation may introduce ZST and useless drops. + &remove_zsts::RemoveZsts, + &remove_unneeded_drops::RemoveUnneededDrops, + // Type instantiation may create uninhabited enums. + // Also eliminates some unreachable branches based on variants of enums. + &unreachable_enum_branching::UnreachableEnumBranching, + &unreachable_prop::UnreachablePropagation, + &o1(simplify::SimplifyCfg::AfterUnreachableEnumBranching), + // Inlining may have introduced a lot of redundant code and a large move pattern. + // Now, we need to shrink the generated MIR. + &ref_prop::ReferencePropagation, + &sroa::ScalarReplacementOfAggregates, + &multiple_return_terminators::MultipleReturnTerminators, + // After simplifycfg, it allows us to discover new opportunities for peephole + // optimizations. 
+ &instsimplify::InstSimplify::AfterSimplifyCfg, + &simplify::SimplifyLocals::BeforeConstProp, + &dead_store_elimination::DeadStoreElimination::Initial, + &gvn::GVN, + &simplify::SimplifyLocals::AfterGVN, + &match_branches::MatchBranchSimplification, + &dataflow_const_prop::DataflowConstProp, + &single_use_consts::SingleUseConsts, + &o1(simplify_branches::SimplifyConstCondition::AfterConstProp), + &jump_threading::JumpThreading, + &early_otherwise_branch::EarlyOtherwiseBranch, + &simplify_comparison_integral::SimplifyComparisonIntegral, + &dest_prop::DestinationPropagation, + &o1(simplify_branches::SimplifyConstCondition::Final), + &o1(remove_noop_landing_pads::RemoveNoopLandingPads), + &o1(simplify::SimplifyCfg::Final), + // After the last SimplifyCfg, because this wants one-block functions. + &strip_debuginfo::StripDebugInfo, + ©_prop::CopyProp, + &dead_store_elimination::DeadStoreElimination::Final, + &nrvo::RenameReturnPlace, + &simplify::SimplifyLocals::Final, + &multiple_return_terminators::MultipleReturnTerminators, + &large_enums::EnumSizeOpt { discrepancy: 128 }, + // Some cleanup necessary at least for LLVM and potentially other codegen backends. + &add_call_guards::CriticalCallEdges, + // Cleanup for human readability, off by default. + &prettify::ReorderBasicBlocks, + &prettify::ReorderLocals, + // Dump the end result for testing and debugging purposes. + &dump_mir::Marker("PreCodegen"), + ], + Some(MirPhase::Runtime(RuntimePhase::Optimized)), + optimizations, + ); +} + +/// Optimize the MIR and prepare it for codegen. +fn optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> &Body<'_> { + tcx.arena.alloc(inner_optimized_mir(tcx, did)) +} + +fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> { + if tcx.is_constructor(did.to_def_id()) { + // There's no reason to run all of the MIR passes on constructors when + // we can just output the MIR we want directly. This also saves const + // qualification and borrow checking the trouble of special casing + // constructors. + return shim::build_adt_ctor(tcx, did.to_def_id()); + } + + match tcx.hir_body_const_context(did) { + // Run the `mir_for_ctfe` query, which depends on `mir_drops_elaborated_and_const_checked` + // which we are going to steal below. Thus we need to run `mir_for_ctfe` first, so it + // computes and caches its result. + Some(hir::ConstContext::ConstFn) => tcx.ensure_done().mir_for_ctfe(did), + None => {} + Some(other) => panic!("do not use `optimized_mir` for constants: {other:?}"), + } + debug!("about to call mir_drops_elaborated..."); + let body = tcx.mir_drops_elaborated_and_const_checked(did).steal(); + let mut body = remap_mir_for_const_eval_select(tcx, body, hir::Constness::NotConst); + + if body.tainted_by_errors.is_some() { + return body; + } + + // Before doing anything, remember which items are being mentioned so that the set of items + // visited does not depend on the optimization level. + // We do not use `run_passes` for this as that might skip the pass if `injection_phase` is set. + mentioned_items::MentionedItems.run_pass(tcx, &mut body); + + // If `mir_drops_elaborated_and_const_checked` found that the current body has unsatisfiable + // predicates, it will shrink the MIR to a single `unreachable` terminator. + // More generally, if MIR is a lone `unreachable`, there is nothing to optimize. 
+ if let TerminatorKind::Unreachable = body.basic_blocks[START_BLOCK].terminator().kind + && body.basic_blocks[START_BLOCK].statements.is_empty() + { + return body; + } + + run_optimization_passes(tcx, &mut body); + + body +} + +/// Fetch all the promoteds of an item and prepare their MIR bodies to be ready for +/// constant evaluation once all generic parameters become known. +fn promoted_mir(tcx: TyCtxt<'_>, def: LocalDefId) -> &IndexVec<Promoted, Body<'_>> { + if tcx.is_constructor(def.to_def_id()) { + return tcx.arena.alloc(IndexVec::new()); + } + + if !tcx.is_synthetic_mir(def) { + tcx.ensure_done().mir_borrowck(tcx.typeck_root_def_id(def.to_def_id()).expect_local()); + } + let mut promoted = tcx.mir_promoted(def).1.steal(); + + for body in &mut promoted { + run_analysis_to_runtime_passes(tcx, body); + } + + tcx.arena.alloc(promoted) +} diff --git a/compiler/rustc_mir_transform/src/lint.rs b/compiler/rustc_mir_transform/src/lint.rs new file mode 100644 index 00000000000..f472c7cb493 --- /dev/null +++ b/compiler/rustc_mir_transform/src/lint.rs @@ -0,0 +1,152 @@ +//! This pass statically detects code which has undefined behaviour or is likely to be erroneous. +//! It can be used to locate problems in MIR building or optimizations. It assumes that all code +//! can be executed, so it has false positives. + +use std::borrow::Cow; + +use rustc_data_structures::fx::FxHashSet; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::mir::visit::{PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_mir_dataflow::impls::{MaybeStorageDead, MaybeStorageLive, always_storage_live_locals}; +use rustc_mir_dataflow::{Analysis, ResultsCursor}; + +pub(super) fn lint_body<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, when: String) { + let always_live_locals = &always_storage_live_locals(body); + + let maybe_storage_live = MaybeStorageLive::new(Cow::Borrowed(always_live_locals)) + .iterate_to_fixpoint(tcx, body, None) + .into_results_cursor(body); + + let maybe_storage_dead = MaybeStorageDead::new(Cow::Borrowed(always_live_locals)) + .iterate_to_fixpoint(tcx, body, None) + .into_results_cursor(body); + + let mut lint = Lint { + tcx, + when, + body, + is_fn_like: tcx.def_kind(body.source.def_id()).is_fn_like(), + always_live_locals, + maybe_storage_live, + maybe_storage_dead, + places: Default::default(), + }; + for (bb, data) in traversal::reachable(body) { + lint.visit_basic_block_data(bb, data); + } +} + +struct Lint<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + when: String, + body: &'a Body<'tcx>, + is_fn_like: bool, + always_live_locals: &'a DenseBitSet<Local>, + maybe_storage_live: ResultsCursor<'a, 'tcx, MaybeStorageLive<'a>>, + maybe_storage_dead: ResultsCursor<'a, 'tcx, MaybeStorageDead<'a>>, + places: FxHashSet<PlaceRef<'tcx>>, +} + +impl<'a, 'tcx> Lint<'a, 'tcx> { + #[track_caller] + fn fail(&self, location: Location, msg: impl AsRef<str>) { + let span = self.body.source_info(location).span; + self.tcx.sess.dcx().span_delayed_bug( + span, + format!( + "broken MIR in {:?} ({}) at {:?}:\n{}", + self.body.source.instance, + self.when, + location, + msg.as_ref() + ), + ); + } +} + +impl<'a, 'tcx> Visitor<'tcx> for Lint<'a, 'tcx> { + fn visit_local(&mut self, local: Local, context: PlaceContext, location: Location) { + if context.is_use() { + self.maybe_storage_dead.seek_after_primary_effect(location); + if self.maybe_storage_dead.get().contains(local) { + self.fail(location, format!("use of local {local:?}, which has no storage here")); + } + } + } + + fn 
visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match &statement.kind { + StatementKind::Assign(box (dest, rvalue)) => { + if let Rvalue::Use(Operand::Copy(src) | Operand::Move(src)) = rvalue { + // The sides of an assignment must not alias. Currently this just checks whether + // the places are identical. + if dest == src { + self.fail( + location, + "encountered `Assign` statement with overlapping memory", + ); + } + } + } + StatementKind::StorageLive(local) => { + self.maybe_storage_live.seek_before_primary_effect(location); + if self.maybe_storage_live.get().contains(*local) { + self.fail( + location, + format!("StorageLive({local:?}) which already has storage here"), + ); + } + } + _ => {} + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + match &terminator.kind { + TerminatorKind::Return => { + if self.is_fn_like { + self.maybe_storage_live.seek_after_primary_effect(location); + for local in self.maybe_storage_live.get().iter() { + if !self.always_live_locals.contains(local) { + self.fail( + location, + format!( + "local {local:?} still has storage when returning from function" + ), + ); + } + } + } + } + TerminatorKind::Call { args, destination, .. } => { + // The call destination place and Operand::Move place used as an argument might be + // passed by a reference to the callee. Consequently they must be non-overlapping. + // Currently this simply checks for duplicate places. + self.places.clear(); + self.places.insert(destination.as_ref()); + let mut has_duplicates = false; + for arg in args { + if let Operand::Move(place) = &arg.node { + has_duplicates |= !self.places.insert(place.as_ref()); + } + } + if has_duplicates { + self.fail( + location, + format!( + "encountered overlapping memory in `Move` arguments to `Call` terminator: {:?}", + terminator.kind, + ), + ); + } + } + _ => {} + } + + self.super_terminator(terminator, location); + } +} diff --git a/compiler/rustc_mir_transform/src/lint_tail_expr_drop_order.rs b/compiler/rustc_mir_transform/src/lint_tail_expr_drop_order.rs new file mode 100644 index 00000000000..1bd770a8526 --- /dev/null +++ b/compiler/rustc_mir_transform/src/lint_tail_expr_drop_order.rs @@ -0,0 +1,544 @@ +use std::cell::RefCell; +use std::collections::hash_map; +use std::rc::Rc; + +use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap}; +use rustc_data_structures::unord::{UnordMap, UnordSet}; +use rustc_errors::Subdiagnostic; +use rustc_hir::CRATE_HIR_ID; +use rustc_hir::def_id::LocalDefId; +use rustc_index::bit_set::MixedBitSet; +use rustc_index::{IndexSlice, IndexVec}; +use rustc_macros::{LintDiagnostic, Subdiagnostic}; +use rustc_middle::bug; +use rustc_middle::mir::{ + self, BasicBlock, Body, ClearCrossCrate, Local, Location, Place, StatementKind, TerminatorKind, + dump_mir, +}; +use rustc_middle::ty::significant_drop_order::{ + extract_component_with_significant_dtor, ty_dtor_span, +}; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_mir_dataflow::impls::MaybeInitializedPlaces; +use rustc_mir_dataflow::move_paths::{LookupResult, MoveData, MovePathIndex}; +use rustc_mir_dataflow::{Analysis, MaybeReachable, ResultsCursor}; +use rustc_session::lint::builtin::TAIL_EXPR_DROP_ORDER; +use rustc_session::lint::{self}; +use rustc_span::{DUMMY_SP, Span, Symbol}; +use tracing::debug; + +fn place_has_common_prefix<'tcx>(left: &Place<'tcx>, right: &Place<'tcx>) -> bool { + left.local == right.local + && 
left.projection.iter().zip(right.projection).all(|(left, right)| left == right) +} + +/// Cache entry of `drop` at a `BasicBlock` +#[derive(Debug, Clone, Copy)] +enum MovePathIndexAtBlock { + /// We know nothing yet + Unknown, + /// We know that the `drop` here has no effect + None, + /// We know that the `drop` here will invoke a destructor + Some(MovePathIndex), +} + +struct DropsReachable<'a, 'mir, 'tcx> { + body: &'a Body<'tcx>, + place: &'a Place<'tcx>, + drop_span: &'a mut Option<Span>, + move_data: &'a MoveData<'tcx>, + maybe_init: &'a mut ResultsCursor<'mir, 'tcx, MaybeInitializedPlaces<'mir, 'tcx>>, + block_drop_value_info: &'a mut IndexSlice<BasicBlock, MovePathIndexAtBlock>, + collected_drops: &'a mut MixedBitSet<MovePathIndex>, + visited: FxHashMap<BasicBlock, Rc<RefCell<MixedBitSet<MovePathIndex>>>>, +} + +impl<'a, 'mir, 'tcx> DropsReachable<'a, 'mir, 'tcx> { + fn visit(&mut self, block: BasicBlock) { + let move_set_size = self.move_data.move_paths.len(); + let make_new_path_set = || Rc::new(RefCell::new(MixedBitSet::new_empty(move_set_size))); + + let data = &self.body.basic_blocks[block]; + let Some(terminator) = &data.terminator else { return }; + // Given that we observe these dropped locals here at `block` so far, we will try to update + // the successor blocks. An occupied entry at `block` in `self.visited` signals that we + // have visited `block` before. + let dropped_local_here = + Rc::clone(self.visited.entry(block).or_insert_with(make_new_path_set)); + // We could have invoked reverse lookup for a `MovePathIndex` every time, but unfortunately + // it is expensive. Let's cache them in `self.block_drop_value_info`. + match self.block_drop_value_info[block] { + MovePathIndexAtBlock::Some(dropped) => { + dropped_local_here.borrow_mut().insert(dropped); + } + MovePathIndexAtBlock::Unknown => { + if let TerminatorKind::Drop { place, .. } = &terminator.kind + && let LookupResult::Exact(idx) | LookupResult::Parent(Some(idx)) = + self.move_data.rev_lookup.find(place.as_ref()) + { + // Since we are working with MIRs at a very early stage, observing a `drop` + // terminator is not indicative enough that the drop will definitely happen. + // That is decided in the drop elaboration pass instead. Therefore, we need to + // consult with the maybe-initialization information. + self.maybe_init.seek_before_primary_effect(Location { + block, + statement_index: data.statements.len(), + }); + + // Check if the drop of `place` under inspection is really in effect. This is + // true only when `place` may have been initialized along a control flow path + // from a BID to the drop program point today. In other words, this is where + // the drop of `place` will happen in the future instead. + if let MaybeReachable::Reachable(maybe_init) = self.maybe_init.get() + && maybe_init.contains(idx) + { + // We also cache the drop information, so that we do not need to check on + // data-flow cursor again. 
+ self.block_drop_value_info[block] = MovePathIndexAtBlock::Some(idx); + dropped_local_here.borrow_mut().insert(idx); + } else { + self.block_drop_value_info[block] = MovePathIndexAtBlock::None; + } + } + } + MovePathIndexAtBlock::None => {} + } + + for succ in terminator.successors() { + let target = &self.body.basic_blocks[succ]; + if target.is_cleanup { + continue; + } + + // As long as we are passing through a new block, or new dropped places to propagate, + // we will proceed with `succ` + let dropped_local_there = match self.visited.entry(succ) { + hash_map::Entry::Occupied(occupied_entry) => { + if succ == block + || !occupied_entry.get().borrow_mut().union(&*dropped_local_here.borrow()) + { + // `succ` has been visited but no new drops observed so far, + // so we can bail on `succ` until new drop information arrives + continue; + } + Rc::clone(occupied_entry.get()) + } + hash_map::Entry::Vacant(vacant_entry) => Rc::clone( + vacant_entry.insert(Rc::new(RefCell::new(dropped_local_here.borrow().clone()))), + ), + }; + if let Some(terminator) = &target.terminator + && let TerminatorKind::Drop { + place: dropped_place, + target: _, + unwind: _, + replace: _, + drop: _, + async_fut: _, + } = &terminator.kind + && place_has_common_prefix(dropped_place, self.place) + { + // We have now reached the current drop of the `place`. + // Let's check the observed dropped places in. + self.collected_drops.union(&*dropped_local_there.borrow()); + if self.drop_span.is_none() { + // FIXME(@dingxiangfei2009): it turns out that `self.body.source_scopes` are + // still a bit wonky. There is a high chance that this span still points to a + // block rather than a statement semicolon. + *self.drop_span = Some(terminator.source_info.span); + } + // Now we have discovered a simple control flow path from a future drop point + // to the current drop point. + // We will not continue from there. + } else { + self.visit(succ) + } + } + } +} + +/// Check if a moved place at `idx` is a part of a BID. +/// The use of this check is that we will consider drops on these +/// as a drop of the overall BID and, thus, we can exclude it from the diagnosis. +fn place_descendent_of_bids<'tcx>( + mut idx: MovePathIndex, + move_data: &MoveData<'tcx>, + bids: &UnordSet<&Place<'tcx>>, +) -> bool { + loop { + let path = &move_data.move_paths[idx]; + if bids.contains(&path.place) { + return true; + } + if let Some(parent) = path.parent { + idx = parent; + } else { + return false; + } + } +} + +/// The core of the lint `tail-expr-drop-order` +pub(crate) fn run_lint<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId, body: &Body<'tcx>) { + if matches!(tcx.def_kind(def_id), rustc_hir::def::DefKind::SyntheticCoroutineBody) { + // A synthetic coroutine has no HIR body and it is enough to just analyse the original body + return; + } + if body.span.edition().at_least_rust_2024() + || tcx.lints_that_dont_need_to_run(()).contains(&lint::LintId::of(TAIL_EXPR_DROP_ORDER)) + { + return; + } + + // FIXME(typing_env): This should be able to reveal the opaques local to the + // body using the typeck results. + let typing_env = ty::TypingEnv::non_body_analysis(tcx, def_id); + + // ## About BIDs in blocks ## + // Track the set of blocks that contain a backwards-incompatible drop (BID) + // and, for each block, the vector of locations. + // + // We group them per-block because they tend to scheduled in the same drop ladder block. 
+ let mut bid_per_block = FxIndexMap::default(); + let mut bid_places = UnordSet::new(); + + let mut ty_dropped_components = UnordMap::default(); + for (block, data) in body.basic_blocks.iter_enumerated() { + for (statement_index, stmt) in data.statements.iter().enumerate() { + if let StatementKind::BackwardIncompatibleDropHint { place, reason: _ } = &stmt.kind { + let ty = place.ty(body, tcx).ty; + if ty_dropped_components + .entry(ty) + .or_insert_with(|| extract_component_with_significant_dtor(tcx, typing_env, ty)) + .is_empty() + { + continue; + } + bid_per_block + .entry(block) + .or_insert(vec![]) + .push((Location { block, statement_index }, &**place)); + bid_places.insert(&**place); + } + } + } + if bid_per_block.is_empty() { + return; + } + + dump_mir(tcx, false, "lint_tail_expr_drop_order", &0 as _, body, |_, _| Ok(())); + let locals_with_user_names = collect_user_names(body); + let is_closure_like = tcx.is_closure_like(def_id.to_def_id()); + + // Compute the "maybe initialized" information for this body. + // When we encounter a DROP of some place P we only care + // about the drop if `P` may be initialized. + let move_data = MoveData::gather_moves(body, tcx, |_| true); + let mut maybe_init = MaybeInitializedPlaces::new(tcx, body, &move_data) + .iterate_to_fixpoint(tcx, body, None) + .into_results_cursor(body); + let mut block_drop_value_info = + IndexVec::from_elem_n(MovePathIndexAtBlock::Unknown, body.basic_blocks.len()); + for (&block, candidates) in &bid_per_block { + // We will collect drops on locals on paths between BID points to their actual drop locations + // into `all_locals_dropped`. + let mut all_locals_dropped = MixedBitSet::new_empty(move_data.move_paths.len()); + let mut drop_span = None; + for &(_, place) in candidates.iter() { + let mut collected_drops = MixedBitSet::new_empty(move_data.move_paths.len()); + // ## On detecting change in relative drop order ## + // Iterate through each BID-containing block `block`. + // If the place `P` targeted by the BID is "maybe initialized", + // then search forward to find the actual `DROP(P)` point. + // Everything dropped between the BID and the actual drop point + // is something whose relative drop order will change. + DropsReachable { + body, + place, + drop_span: &mut drop_span, + move_data: &move_data, + maybe_init: &mut maybe_init, + block_drop_value_info: &mut block_drop_value_info, + collected_drops: &mut collected_drops, + visited: Default::default(), + } + .visit(block); + // Compute the set `all_locals_dropped` of local variables that are dropped + // after the BID point but before the current drop point. + // + // These are the variables whose drop impls will be reordered with respect + // to `place`. + all_locals_dropped.union(&collected_drops); + } + + // We shall now exclude some local bindings for the following cases. + { + let mut to_exclude = MixedBitSet::new_empty(all_locals_dropped.domain_size()); + // We will now do subtraction from the candidate dropped locals, because of the + // following reasons. 
+ for path_idx in all_locals_dropped.iter() { + let move_path = &move_data.move_paths[path_idx]; + let dropped_local = move_path.place.local; + // a) A return value _0 will eventually be used + // Example: + // fn f() -> Droppy { + // let _x = Droppy; + // Droppy + // } + // _0 holds the literal `Droppy` and rightfully `_x` has to be dropped first + if dropped_local == Local::ZERO { + debug!(?dropped_local, "skip return value"); + to_exclude.insert(path_idx); + continue; + } + // b) If we are analysing a closure, the captures are still dropped last. + // This is part of the closure capture lifetime contract. + // They are similar to the return value _0 with respect to lifetime rules. + if is_closure_like && matches!(dropped_local, ty::CAPTURE_STRUCT_LOCAL) { + debug!(?dropped_local, "skip closure captures"); + to_exclude.insert(path_idx); + continue; + } + // c) Sometimes we collect places that are projections into the BID locals, + // so they are considered dropped now. + // Example: + // struct NotVeryDroppy(Droppy); + // impl Drop for Droppy {..} + // fn f() -> NotVeryDroppy { + // let x = NotVeryDroppy(droppy()); + // { + // let y: Droppy = x.0; + // NotVeryDroppy(y) + // } + // } + // `y` takes `x.0`, which invalidates `x` as a complete `NotVeryDroppy` + // so there is no point in linting against `x` any more. + if place_descendent_of_bids(path_idx, &move_data, &bid_places) { + debug!(?dropped_local, "skip descendent of bids"); + to_exclude.insert(path_idx); + continue; + } + let observer_ty = move_path.place.ty(body, tcx).ty; + // d) The collected local has no custom destructor that passes our ecosystem filter. + if ty_dropped_components + .entry(observer_ty) + .or_insert_with(|| { + extract_component_with_significant_dtor(tcx, typing_env, observer_ty) + }) + .is_empty() + { + debug!(?dropped_local, "skip non-droppy types"); + to_exclude.insert(path_idx); + continue; + } + } + // Suppose that all BIDs point into the same local, + // we can remove the this local from the observed drops, + // so that we can focus our diagnosis more on the others. + if candidates.iter().all(|&(_, place)| candidates[0].1.local == place.local) { + for path_idx in all_locals_dropped.iter() { + if move_data.move_paths[path_idx].place.local == candidates[0].1.local { + to_exclude.insert(path_idx); + } + } + } + all_locals_dropped.subtract(&to_exclude); + } + if all_locals_dropped.is_empty() { + // No drop effect is observable, so let us move on. + continue; + } + + // ## The final work to assemble the diagnosis ## + // First collect or generate fresh names for local variable bindings and temporary values. + let local_names = assign_observables_names( + all_locals_dropped + .iter() + .map(|path_idx| move_data.move_paths[path_idx].place.local) + .chain(candidates.iter().map(|(_, place)| place.local)), + &locals_with_user_names, + ); + + let mut lint_root = None; + let mut local_labels = vec![]; + // We now collect the types with custom destructors. + for &(_, place) in candidates { + let linted_local_decl = &body.local_decls[place.local]; + let Some(&(ref name, is_generated_name)) = local_names.get(&place.local) else { + bug!("a name should have been assigned") + }; + let name = name.as_str(); + + if lint_root.is_none() + && let ClearCrossCrate::Set(data) = + &body.source_scopes[linted_local_decl.source_info.scope].local_data + { + lint_root = Some(data.lint_root); + } + + // Collect spans of the custom destructors. 
+ let mut seen_dyn = false; + let destructors = ty_dropped_components + .get(&linted_local_decl.ty) + .unwrap() + .iter() + .filter_map(|&ty| { + if let Some(span) = ty_dtor_span(tcx, ty) { + Some(DestructorLabel { span, name, dtor_kind: "concrete" }) + } else if matches!(ty.kind(), ty::Dynamic(..)) { + if seen_dyn { + None + } else { + seen_dyn = true; + Some(DestructorLabel { span: DUMMY_SP, name, dtor_kind: "dyn" }) + } + } else { + None + } + }) + .collect(); + local_labels.push(LocalLabel { + span: linted_local_decl.source_info.span, + destructors, + name, + is_generated_name, + is_dropped_first_edition_2024: true, + }); + } + + // Similarly, custom destructors of the observed drops. + for path_idx in all_locals_dropped.iter() { + let place = &move_data.move_paths[path_idx].place; + // We are not using the type of the local because the drop may be partial. + let observer_ty = place.ty(body, tcx).ty; + + let observer_local_decl = &body.local_decls[place.local]; + let Some(&(ref name, is_generated_name)) = local_names.get(&place.local) else { + bug!("a name should have been assigned") + }; + let name = name.as_str(); + + let mut seen_dyn = false; + let destructors = extract_component_with_significant_dtor(tcx, typing_env, observer_ty) + .into_iter() + .filter_map(|ty| { + if let Some(span) = ty_dtor_span(tcx, ty) { + Some(DestructorLabel { span, name, dtor_kind: "concrete" }) + } else if matches!(ty.kind(), ty::Dynamic(..)) { + if seen_dyn { + None + } else { + seen_dyn = true; + Some(DestructorLabel { span: DUMMY_SP, name, dtor_kind: "dyn" }) + } + } else { + None + } + }) + .collect(); + local_labels.push(LocalLabel { + span: observer_local_decl.source_info.span, + destructors, + name, + is_generated_name, + is_dropped_first_edition_2024: false, + }); + } + + let span = local_labels[0].span; + tcx.emit_node_span_lint( + lint::builtin::TAIL_EXPR_DROP_ORDER, + lint_root.unwrap_or(CRATE_HIR_ID), + span, + TailExprDropOrderLint { local_labels, drop_span, _epilogue: () }, + ); + } +} + +/// Extract binding names if available for diagnosis +fn collect_user_names(body: &Body<'_>) -> FxIndexMap<Local, Symbol> { + let mut names = FxIndexMap::default(); + for var_debug_info in &body.var_debug_info { + if let mir::VarDebugInfoContents::Place(place) = &var_debug_info.value + && let Some(local) = place.local_or_deref_local() + { + names.entry(local).or_insert(var_debug_info.name); + } + } + names +} + +/// Assign names for anonymous or temporary values for diagnosis +fn assign_observables_names( + locals: impl IntoIterator<Item = Local>, + user_names: &FxIndexMap<Local, Symbol>, +) -> FxIndexMap<Local, (String, bool)> { + let mut names = FxIndexMap::default(); + let mut assigned_names = FxHashSet::default(); + let mut idx = 0u64; + let mut fresh_name = || { + idx += 1; + (format!("#{idx}"), true) + }; + for local in locals { + let name = if let Some(name) = user_names.get(&local) { + let name = name.as_str(); + if assigned_names.contains(name) { fresh_name() } else { (name.to_owned(), false) } + } else { + fresh_name() + }; + assigned_names.insert(name.0.clone()); + names.insert(local, name); + } + names +} + +#[derive(LintDiagnostic)] +#[diag(mir_transform_tail_expr_drop_order)] +struct TailExprDropOrderLint<'a> { + #[subdiagnostic] + local_labels: Vec<LocalLabel<'a>>, + #[label(mir_transform_drop_location)] + drop_span: Option<Span>, + #[note(mir_transform_note_epilogue)] + _epilogue: (), +} + +struct LocalLabel<'a> { + span: Span, + name: &'a str, + is_generated_name: bool, + 
is_dropped_first_edition_2024: bool,
+    destructors: Vec<DestructorLabel<'a>>,
+}
+
+/// A custom `Subdiagnostic` implementation so that the notes are delivered in a specific order
+impl Subdiagnostic for LocalLabel<'_> {
+    fn add_to_diag<G: rustc_errors::EmissionGuarantee>(self, diag: &mut rustc_errors::Diag<'_, G>) {
+        // Because the parent diagnostic also uses these fields, we need to remove each one
+        // before adding it again.
+        diag.remove_arg("name");
+        diag.arg("name", self.name);
+        diag.remove_arg("is_generated_name");
+        diag.arg("is_generated_name", self.is_generated_name);
+        diag.remove_arg("is_dropped_first_edition_2024");
+        diag.arg("is_dropped_first_edition_2024", self.is_dropped_first_edition_2024);
+        let msg = diag.eagerly_translate(crate::fluent_generated::mir_transform_tail_expr_local);
+        diag.span_label(self.span, msg);
+        for dtor in self.destructors {
+            dtor.add_to_diag(diag);
+        }
+        let msg =
+            diag.eagerly_translate(crate::fluent_generated::mir_transform_label_local_epilogue);
+        diag.span_label(self.span, msg);
+    }
+}
+
+#[derive(Subdiagnostic)]
+#[note(mir_transform_tail_expr_dtor)]
+struct DestructorLabel<'a> {
+    #[primary_span]
+    span: Span,
+    dtor_kind: &'static str,
+    name: &'a str,
+}
diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs
new file mode 100644
index 00000000000..fa29ab985b7
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs
@@ -0,0 +1,389 @@
+//! Lowers intrinsic calls
+
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_middle::{bug, span_bug};
+use rustc_span::sym;
+
+use crate::take_array;
+
+pub(super) struct LowerIntrinsics;
+
+impl<'tcx> crate::MirPass<'tcx> for LowerIntrinsics {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        let local_decls = &body.local_decls;
+        for block in body.basic_blocks.as_mut() {
+            let terminator = block.terminator.as_mut().unwrap();
+            if let TerminatorKind::Call { func, args, destination, target, ..
} = + &mut terminator.kind + && let ty::FnDef(def_id, generic_args) = *func.ty(local_decls, tcx).kind() + && let Some(intrinsic) = tcx.intrinsic(def_id) + { + match intrinsic.name { + sym::unreachable => { + terminator.kind = TerminatorKind::Unreachable; + } + sym::ub_checks => { + let target = target.unwrap(); + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::NullaryOp(NullOp::UbChecks, tcx.types.bool), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::contract_checks => { + let target = target.unwrap(); + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::NullaryOp(NullOp::ContractChecks, tcx.types.bool), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::forget => { + let target = target.unwrap(); + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Use(Operand::Constant(Box::new(ConstOperand { + span: terminator.source_info.span, + user_ty: None, + const_: Const::zero_sized(tcx.types.unit), + }))), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::copy_nonoverlapping => { + let target = target.unwrap(); + let Ok([src, dst, count]) = take_array(args) else { + bug!("Wrong arguments for copy_non_overlapping intrinsic"); + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Intrinsic(Box::new( + NonDivergingIntrinsic::CopyNonOverlapping( + rustc_middle::mir::CopyNonOverlapping { + src: src.node, + dst: dst.node, + count: count.node, + }, + ), + )), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::assume => { + let target = target.unwrap(); + let Ok([arg]) = take_array(args) else { + bug!("Wrong arguments for assume intrinsic"); + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Intrinsic(Box::new( + NonDivergingIntrinsic::Assume(arg.node), + )), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::wrapping_add + | sym::wrapping_sub + | sym::wrapping_mul + | sym::three_way_compare + | sym::unchecked_add + | sym::unchecked_sub + | sym::unchecked_mul + | sym::unchecked_div + | sym::unchecked_rem + | sym::unchecked_shl + | sym::unchecked_shr => { + let target = target.unwrap(); + let Ok([lhs, rhs]) = take_array(args) else { + bug!("Wrong arguments for {} intrinsic", intrinsic.name); + }; + let bin_op = match intrinsic.name { + sym::wrapping_add => BinOp::Add, + sym::wrapping_sub => BinOp::Sub, + sym::wrapping_mul => BinOp::Mul, + sym::three_way_compare => BinOp::Cmp, + sym::unchecked_add => BinOp::AddUnchecked, + sym::unchecked_sub => BinOp::SubUnchecked, + sym::unchecked_mul => BinOp::MulUnchecked, + sym::unchecked_div => BinOp::Div, + sym::unchecked_rem => BinOp::Rem, + sym::unchecked_shl => BinOp::ShlUnchecked, + sym::unchecked_shr => BinOp::ShrUnchecked, + _ => bug!("unexpected intrinsic"), + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::BinaryOp(bin_op, Box::new((lhs.node, rhs.node))), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::add_with_overflow | sym::sub_with_overflow | sym::mul_with_overflow => { + let target = target.unwrap(); + let Ok([lhs, rhs]) = take_array(args) else { + 
bug!("Wrong arguments for {} intrinsic", intrinsic.name); + }; + let bin_op = match intrinsic.name { + sym::add_with_overflow => BinOp::AddWithOverflow, + sym::sub_with_overflow => BinOp::SubWithOverflow, + sym::mul_with_overflow => BinOp::MulWithOverflow, + _ => bug!("unexpected intrinsic"), + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::BinaryOp(bin_op, Box::new((lhs.node, rhs.node))), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::size_of | sym::align_of => { + let target = target.unwrap(); + let tp_ty = generic_args.type_at(0); + let null_op = match intrinsic.name { + sym::size_of => NullOp::SizeOf, + sym::align_of => NullOp::AlignOf, + _ => bug!("unexpected intrinsic"), + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::NullaryOp(null_op, tp_ty), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::read_via_copy => { + let Ok([arg]) = take_array(args) else { + span_bug!(terminator.source_info.span, "Wrong number of arguments"); + }; + let derefed_place = if let Some(place) = arg.node.place() + && let Some(local) = place.as_local() + { + tcx.mk_place_deref(local.into()) + } else { + span_bug!( + terminator.source_info.span, + "Only passing a local is supported" + ); + }; + // Add new statement at the end of the block that does the read, and patch + // up the terminator. + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Use(Operand::Copy(derefed_place)), + ))), + }); + terminator.kind = match *target { + None => { + // No target means this read something uninhabited, + // so it must be unreachable. 
+ TerminatorKind::Unreachable + } + Some(target) => TerminatorKind::Goto { target }, + } + } + sym::write_via_move => { + let target = target.unwrap(); + let Ok([ptr, val]) = take_array(args) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for write_via_move intrinsic", + ); + }; + let derefed_place = if let Some(place) = ptr.node.place() + && let Some(local) = place.as_local() + { + tcx.mk_place_deref(local.into()) + } else { + span_bug!( + terminator.source_info.span, + "Only passing a local is supported" + ); + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + derefed_place, + Rvalue::Use(val.node), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::discriminant_value => { + let target = target.unwrap(); + let Ok([arg]) = take_array(args) else { + span_bug!( + terminator.source_info.span, + "Wrong arguments for discriminant_value intrinsic" + ); + }; + let arg = arg.node.place().unwrap(); + let arg = tcx.mk_place_deref(arg); + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Discriminant(arg), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::offset => { + let target = target.unwrap(); + let Ok([ptr, delta]) = take_array(args) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for offset intrinsic", + ); + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::BinaryOp(BinOp::Offset, Box::new((ptr.node, delta.node))), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::slice_get_unchecked => { + let target = target.unwrap(); + let Ok([ptrish, index]) = take_array(args) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for {intrinsic:?}", + ); + }; + + let place = ptrish.node.place().unwrap(); + assert!(!place.is_indirect()); + let updated_place = place.project_deeper( + &[ + ProjectionElem::Deref, + ProjectionElem::Index( + index.node.place().unwrap().as_local().unwrap(), + ), + ], + tcx, + ); + + let ret_ty = generic_args.type_at(0); + let rvalue = match *ret_ty.kind() { + ty::RawPtr(_, Mutability::Not) => { + Rvalue::RawPtr(RawPtrKind::Const, updated_place) + } + ty::RawPtr(_, Mutability::Mut) => { + Rvalue::RawPtr(RawPtrKind::Mut, updated_place) + } + ty::Ref(region, _, Mutability::Not) => { + Rvalue::Ref(region, BorrowKind::Shared, updated_place) + } + ty::Ref(region, _, Mutability::Mut) => Rvalue::Ref( + region, + BorrowKind::Mut { kind: MutBorrowKind::Default }, + updated_place, + ), + _ => bug!("Unknown return type {ret_ty:?}"), + }; + + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new((*destination, rvalue))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::transmute | sym::transmute_unchecked => { + let dst_ty = destination.ty(local_decls, tcx).ty; + let Ok([arg]) = take_array(args) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for transmute intrinsic", + ); + }; + + // Always emit the cast, even if we transmute to an uninhabited type, + // because that lets CTFE and codegen generate better error messages + // when such a transmute actually ends up reachable. 
+ block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Cast(CastKind::Transmute, arg.node, dst_ty), + ))), + }); + if let Some(target) = *target { + terminator.kind = TerminatorKind::Goto { target }; + } else { + terminator.kind = TerminatorKind::Unreachable; + } + } + sym::aggregate_raw_ptr => { + let Ok([data, meta]) = take_array(args) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for aggregate_raw_ptr intrinsic", + ); + }; + let target = target.unwrap(); + let pointer_ty = generic_args.type_at(0); + let kind = if let ty::RawPtr(pointee_ty, mutability) = pointer_ty.kind() { + AggregateKind::RawPtr(*pointee_ty, *mutability) + } else { + span_bug!( + terminator.source_info.span, + "Return type of aggregate_raw_ptr intrinsic must be a raw pointer", + ); + }; + let fields = [data.node, meta.node]; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Aggregate(Box::new(kind), fields.into()), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::ptr_metadata => { + let Ok([ptr]) = take_array(args) else { + span_bug!( + terminator.source_info.span, + "Wrong number of arguments for ptr_metadata intrinsic", + ); + }; + let target = target.unwrap(); + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::UnaryOp(UnOp::PtrMetadata, ptr.node), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + _ => {} + } + } + } + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/lower_slice_len.rs b/compiler/rustc_mir_transform/src/lower_slice_len.rs new file mode 100644 index 00000000000..aca80e36e33 --- /dev/null +++ b/compiler/rustc_mir_transform/src/lower_slice_len.rs @@ -0,0 +1,68 @@ +//! This pass lowers calls to core::slice::len to just PtrMetadata op. +//! It should run before inlining! + +use rustc_hir::def_id::DefId; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +pub(super) struct LowerSliceLenCalls; + +impl<'tcx> crate::MirPass<'tcx> for LowerSliceLenCalls { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let language_items = tcx.lang_items(); + let Some(slice_len_fn_item_def_id) = language_items.slice_len_fn() else { + // there is no lang item to compare to :) + return; + }; + + // The one successor remains unchanged, so no need to invalidate + let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); + for block in basic_blocks { + // lower `<[_]>::len` calls + lower_slice_len_call(block, slice_len_fn_item_def_id); + } + } + + fn is_required(&self) -> bool { + false + } +} + +fn lower_slice_len_call<'tcx>(block: &mut BasicBlockData<'tcx>, slice_len_fn_item_def_id: DefId) { + let terminator = block.terminator(); + if let TerminatorKind::Call { + func, + args, + destination, + target: Some(bb), + call_source: CallSource::Normal, + .. + } = &terminator.kind + // some heuristics for fast rejection + && let [arg] = &args[..] 
+ && let Some((fn_def_id, _)) = func.const_fn_def() + && fn_def_id == slice_len_fn_item_def_id + { + // perform modifications from something like: + // _5 = core::slice::<impl [u8]>::len(move _6) -> bb1 + // into: + // _5 = PtrMetadata(move _6) + // goto bb1 + + // make new RValue for Len + let r_value = Rvalue::UnaryOp(UnOp::PtrMetadata, arg.node.clone()); + let len_statement_kind = StatementKind::Assign(Box::new((*destination, r_value))); + let add_statement = + Statement { kind: len_statement_kind, source_info: terminator.source_info }; + + // modify terminator into simple Goto + let new_terminator_kind = TerminatorKind::Goto { target: *bb }; + + block.statements.push(add_statement); + block.terminator_mut().kind = new_terminator_kind; + } +} diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs new file mode 100644 index 00000000000..5e511f1a418 --- /dev/null +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -0,0 +1,531 @@ +use std::iter; + +use rustc_abi::Integer; +use rustc_index::IndexSlice; +use rustc_middle::mir::*; +use rustc_middle::ty::layout::{IntegerExt, TyAndLayout}; +use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt}; +use tracing::instrument; + +use super::simplify::simplify_cfg; +use crate::patch::MirPatch; + +pub(super) struct MatchBranchSimplification; + +impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 1 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let typing_env = body.typing_env(tcx); + let mut apply_patch = false; + let mut patch = MirPatch::new(body); + for (bb, bb_data) in body.basic_blocks.iter_enumerated() { + match &bb_data.terminator().kind { + TerminatorKind::SwitchInt { + discr: Operand::Copy(_) | Operand::Move(_), + targets, + .. + // We require that the possible target blocks don't contain this block. + } if !targets.all_targets().contains(&bb) => {} + // Only optimize switch int statements + _ => continue, + }; + + if SimplifyToIf.simplify(tcx, body, &mut patch, bb, typing_env).is_some() { + apply_patch = true; + continue; + } + if SimplifyToExp::default().simplify(tcx, body, &mut patch, bb, typing_env).is_some() { + apply_patch = true; + continue; + } + } + + if apply_patch { + patch.apply(body); + simplify_cfg(tcx, body); + } + } + + fn is_required(&self) -> bool { + false + } +} + +trait SimplifyMatch<'tcx> { + /// Simplifies a match statement, returning `Some` if the simplification succeeds, `None` + /// otherwise. Generic code is written here, and we generally don't need a custom + /// implementation. + fn simplify( + &mut self, + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + patch: &mut MirPatch<'tcx>, + switch_bb_idx: BasicBlock, + typing_env: ty::TypingEnv<'tcx>, + ) -> Option<()> { + let bbs = &body.basic_blocks; + let TerminatorKind::SwitchInt { discr, targets, .. } = + &bbs[switch_bb_idx].terminator().kind + else { + unreachable!(); + }; + + let discr_ty = discr.ty(body.local_decls(), tcx); + self.can_simplify(tcx, targets, typing_env, bbs, discr_ty)?; + + // Take ownership of items now that we know we can optimize. + let discr = discr.clone(); + + // Introduce a temporary for the discriminant value. 
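+ // The merged statements emitted by `new_stmts` below read the discriminant
+ // through this local (via `Operand::Copy`) rather than through the original operand.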
+ let source_info = bbs[switch_bb_idx].terminator().source_info; + let discr_local = patch.new_temp(discr_ty, source_info.span); + + let (_, first) = targets.iter().next().unwrap(); + let statement_index = bbs[switch_bb_idx].statements.len(); + let parent_end = Location { block: switch_bb_idx, statement_index }; + patch.add_statement(parent_end, StatementKind::StorageLive(discr_local)); + patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr)); + self.new_stmts(tcx, targets, typing_env, patch, parent_end, bbs, discr_local, discr_ty); + patch.add_statement(parent_end, StatementKind::StorageDead(discr_local)); + patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone()); + Some(()) + } + + /// Check that the BBs to be simplified satisfies all distinct and + /// that the terminator are the same. + /// There are also conditions for different ways of simplification. + fn can_simplify( + &mut self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + typing_env: ty::TypingEnv<'tcx>, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + discr_ty: Ty<'tcx>, + ) -> Option<()>; + + fn new_stmts( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + typing_env: ty::TypingEnv<'tcx>, + patch: &mut MirPatch<'tcx>, + parent_end: Location, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ); +} + +struct SimplifyToIf; + +/// If a source block is found that switches between two blocks that are exactly +/// the same modulo const bool assignments (e.g., one assigns true another false +/// to the same place), merge a target block statements into the source block, +/// using Eq / Ne comparison with switch value where const bools value differ. +/// +/// For example: +/// +/// ```ignore (MIR) +/// bb0: { +/// switchInt(move _3) -> [42_isize: bb1, otherwise: bb2]; +/// } +/// +/// bb1: { +/// _2 = const true; +/// goto -> bb3; +/// } +/// +/// bb2: { +/// _2 = const false; +/// goto -> bb3; +/// } +/// ``` +/// +/// into: +/// +/// ```ignore (MIR) +/// bb0: { +/// _2 = Eq(move _3, const 42_isize); +/// goto -> bb3; +/// } +/// ``` +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { + #[instrument(level = "debug", skip(self, tcx), ret)] + fn can_simplify( + &mut self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + typing_env: ty::TypingEnv<'tcx>, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + _discr_ty: Ty<'tcx>, + ) -> Option<()> { + let (first, second) = match targets.all_targets() { + &[first, otherwise] => (first, otherwise), + &[first, second, otherwise] if bbs[otherwise].is_empty_unreachable() => (first, second), + _ => { + return None; + } + }; + + // We require that the possible target blocks all be distinct. + if first == second { + return None; + } + // Check that destinations are identical, and if not, then don't optimize this block + if bbs[first].terminator().kind != bbs[second].terminator().kind { + return None; + } + + // Check that blocks are assignments of consts to the same place or same statement, + // and match up 1-1, if not don't optimize this block. + let first_stmts = &bbs[first].statements; + let second_stmts = &bbs[second].statements; + if first_stmts.len() != second_stmts.len() { + return None; + } + for (f, s) in iter::zip(first_stmts, second_stmts) { + match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. + (f_s, s_s) if f_s == s_s => {} + + // If two statements are const bool assignments to the same place, we can optimize. 
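+ // (Only the check happens here; the actual merge into an `Eq`/`Ne` of the
+ // discriminant is done later in `new_stmts`.)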
+ ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.const_.ty().is_bool() + && s_c.const_.ty().is_bool() + && f_c.const_.try_eval_bool(tcx, typing_env).is_some() + && s_c.const_.try_eval_bool(tcx, typing_env).is_some() => {} + + // Otherwise we cannot optimize. Try another block. + _ => return None, + } + } + Some(()) + } + + fn new_stmts( + &self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + typing_env: ty::TypingEnv<'tcx>, + patch: &mut MirPatch<'tcx>, + parent_end: Location, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) { + let ((val, first), second) = match (targets.all_targets(), targets.all_values()) { + (&[first, otherwise], &[val]) => ((val, first), otherwise), + (&[first, second, otherwise], &[val, _]) if bbs[otherwise].is_empty_unreachable() => { + ((val, first), second) + } + _ => unreachable!(), + }; + + // We already checked that first and second are different blocks, + // and bb_idx has a different terminator from both of them. + let first = &bbs[first]; + let second = &bbs[second]; + for (f, s) in iter::zip(&first.statements, &second.statements) { + match (&f.kind, &s.kind) { + (f_s, s_s) if f_s == s_s => { + patch.add_statement(parent_end, f.kind.clone()); + } + + ( + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), + ) => { + // From earlier loop we know that we are dealing with bool constants only: + let f_b = f_c.const_.try_eval_bool(tcx, typing_env).unwrap(); + let s_b = s_c.const_.try_eval_bool(tcx, typing_env).unwrap(); + if f_b == s_b { + // Same value in both blocks. Use statement as is. + patch.add_statement(parent_end, f.kind.clone()); + } else { + // Different value between blocks. Make value conditional on switch + // condition. + let size = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap().size; + let const_cmp = Operand::const_from_scalar( + tcx, + discr_ty, + rustc_const_eval::interpret::Scalar::from_uint(val, size), + rustc_span::DUMMY_SP, + ); + let op = if f_b { BinOp::Eq } else { BinOp::Ne }; + let rhs = Rvalue::BinaryOp( + op, + Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), + ); + patch.add_assign(parent_end, *lhs, rhs); + } + } + + _ => unreachable!(), + } + } + } +} + +/// Check if the cast constant using `IntToInt` is equal to the target constant. +fn can_cast( + tcx: TyCtxt<'_>, + src_val: impl Into<u128>, + src_layout: TyAndLayout<'_>, + cast_ty: Ty<'_>, + target_scalar: ScalarInt, +) -> bool { + let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap(); + let v = match src_layout.ty.kind() { + ty::Uint(_) => from_scalar.to_uint(src_layout.size), + ty::Int(_) => from_scalar.to_int(src_layout.size) as u128, + // We can also transform the values of other integer representations (such as char), + // although this may not be practical in real-world scenarios. 
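+ // (Supporting `char` would presumably mean starting from its `u32` scalar
+ // before the truncating cast below; for now any non-integer source bails out.)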
+ _ => return false, + }; + let size = match *cast_ty.kind() { + ty::Int(t) => Integer::from_int_ty(&tcx, t).size(), + ty::Uint(t) => Integer::from_uint_ty(&tcx, t).size(), + _ => return false, + }; + let v = size.truncate(v); + let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap(); + cast_scalar == target_scalar +} + +#[derive(Default)] +struct SimplifyToExp { + transform_kinds: Vec<TransformKind>, +} + +#[derive(Clone, Copy, Debug)] +enum ExpectedTransformKind<'a, 'tcx> { + /// Identical statements. + Same(&'a StatementKind<'tcx>), + /// Assignment statements have the same value. + SameByEq { place: &'a Place<'tcx>, ty: Ty<'tcx>, scalar: ScalarInt }, + /// Enum variant comparison type. + Cast { place: &'a Place<'tcx>, ty: Ty<'tcx> }, +} + +enum TransformKind { + Same, + Cast, +} + +impl From<ExpectedTransformKind<'_, '_>> for TransformKind { + fn from(compare_type: ExpectedTransformKind<'_, '_>) -> Self { + match compare_type { + ExpectedTransformKind::Same(_) => TransformKind::Same, + ExpectedTransformKind::SameByEq { .. } => TransformKind::Same, + ExpectedTransformKind::Cast { .. } => TransformKind::Cast, + } + } +} + +/// If we find that the value of match is the same as the assignment, +/// merge a target block statements into the source block, +/// using cast to transform different integer types. +/// +/// For example: +/// +/// ```ignore (MIR) +/// bb0: { +/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; +/// } +/// +/// bb1: { +/// unreachable; +/// } +/// +/// bb2: { +/// _0 = const 1_i16; +/// goto -> bb5; +/// } +/// +/// bb3: { +/// _0 = const 2_i16; +/// goto -> bb5; +/// } +/// +/// bb4: { +/// _0 = const 3_i16; +/// goto -> bb5; +/// } +/// ``` +/// +/// into: +/// +/// ```ignore (MIR) +/// bb0: { +/// _0 = _3 as i16 (IntToInt); +/// goto -> bb5; +/// } +/// ``` +impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { + #[instrument(level = "debug", skip(self, tcx), ret)] + fn can_simplify( + &mut self, + tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + typing_env: ty::TypingEnv<'tcx>, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + discr_ty: Ty<'tcx>, + ) -> Option<()> { + if targets.iter().len() < 2 || targets.iter().len() > 64 { + return None; + } + // We require that the possible target blocks all be distinct. + if !targets.is_distinct() { + return None; + } + if !bbs[targets.otherwise()].is_empty_unreachable() { + return None; + } + let mut target_iter = targets.iter(); + let (first_case_val, first_target) = target_iter.next().unwrap(); + let first_terminator_kind = &bbs[first_target].terminator().kind; + // Check that destinations are identical, and if not, then don't optimize this block + if !targets + .iter() + .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind) + { + return None; + } + + let discr_layout = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap(); + let first_stmts = &bbs[first_target].statements; + let (second_case_val, second_target) = target_iter.next().unwrap(); + let second_stmts = &bbs[second_target].statements; + if first_stmts.len() != second_stmts.len() { + return None; + } + + // We first compare the two branches, and then the other branches need to fulfill the same + // conditions. + let mut expected_transform_kinds = Vec::new(); + for (f, s) in iter::zip(first_stmts, second_stmts) { + let compare_type = match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. 
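+ // (Statements classified as `Same` are later re-emitted verbatim by `new_stmts`.)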
+ (f_s, s_s) if f_s == s_s => ExpectedTransformKind::Same(f_s), + + // If two statements are assignments with the match values to the same place, we + // can optimize. + ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.const_.ty() == s_c.const_.ty() + && f_c.const_.ty().is_integral() => + { + match ( + f_c.const_.try_eval_scalar_int(tcx, typing_env), + s_c.const_.try_eval_scalar_int(tcx, typing_env), + ) { + (Some(f), Some(s)) if f == s => ExpectedTransformKind::SameByEq { + place: lhs_f, + ty: f_c.const_.ty(), + scalar: f, + }, + // Enum variants can also be simplified to an assignment statement, + // if we can use `IntToInt` cast to get an equal value. + (Some(f), Some(s)) + if (can_cast( + tcx, + first_case_val, + discr_layout, + f_c.const_.ty(), + f, + ) && can_cast( + tcx, + second_case_val, + discr_layout, + f_c.const_.ty(), + s, + )) => + { + ExpectedTransformKind::Cast { place: lhs_f, ty: f_c.const_.ty() } + } + _ => { + return None; + } + } + } + + // Otherwise we cannot optimize. Try another block. + _ => return None, + }; + expected_transform_kinds.push(compare_type); + } + + // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step. + for (other_val, other_target) in target_iter { + let other_stmts = &bbs[other_target].statements; + if expected_transform_kinds.len() != other_stmts.len() { + return None; + } + for (f, s) in iter::zip(&expected_transform_kinds, other_stmts) { + match (*f, &s.kind) { + (ExpectedTransformKind::Same(f_s), s_s) if f_s == s_s => {} + ( + ExpectedTransformKind::SameByEq { place: lhs_f, ty: f_ty, scalar }, + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && s_c.const_.ty() == f_ty + && s_c.const_.try_eval_scalar_int(tcx, typing_env) == Some(scalar) => {} + ( + ExpectedTransformKind::Cast { place: lhs_f, ty: f_ty }, + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if let Some(f) = s_c.const_.try_eval_scalar_int(tcx, typing_env) + && lhs_f == lhs_s + && s_c.const_.ty() == f_ty + && can_cast(tcx, other_val, discr_layout, f_ty, f) => {} + _ => return None, + } + } + } + self.transform_kinds = expected_transform_kinds.into_iter().map(|c| c.into()).collect(); + Some(()) + } + + fn new_stmts( + &self, + _tcx: TyCtxt<'tcx>, + targets: &SwitchTargets, + _typing_env: ty::TypingEnv<'tcx>, + patch: &mut MirPatch<'tcx>, + parent_end: Location, + bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + discr_local: Local, + discr_ty: Ty<'tcx>, + ) { + let (_, first) = targets.iter().next().unwrap(); + let first = &bbs[first]; + + for (t, s) in iter::zip(&self.transform_kinds, &first.statements) { + match (t, &s.kind) { + (TransformKind::Same, _) => { + patch.add_statement(parent_end, s.kind.clone()); + } + ( + TransformKind::Cast, + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + ) => { + let operand = Operand::Copy(Place::from(discr_local)); + let r_val = if f_c.const_.ty() == discr_ty { + Rvalue::Use(operand) + } else { + Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty()) + }; + patch.add_assign(parent_end, *lhs, r_val); + } + _ => unreachable!(), + } + } + } +} diff --git a/compiler/rustc_mir_transform/src/mentioned_items.rs b/compiler/rustc_mir_transform/src/mentioned_items.rs new file mode 100644 index 00000000000..9fd8d81d64a --- /dev/null +++ 
b/compiler/rustc_mir_transform/src/mentioned_items.rs @@ -0,0 +1,123 @@ +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::{self, Location, MentionedItem}; +use rustc_middle::ty::adjustment::PointerCoercion; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_session::Session; +use rustc_span::source_map::Spanned; + +pub(super) struct MentionedItems; + +struct MentionedItemsVisitor<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + body: &'a mir::Body<'tcx>, + mentioned_items: Vec<Spanned<MentionedItem<'tcx>>>, +} + +impl<'tcx> crate::MirPass<'tcx> for MentionedItems { + fn is_enabled(&self, _sess: &Session) -> bool { + // If this pass is skipped the collector assume that nothing got mentioned! We could + // potentially skip it in opt-level 0 if we are sure that opt-level will never *remove* uses + // of anything, but that still seems fragile. Furthermore, even debug builds use level 1, so + // special-casing level 0 is just not worth it. + true + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { + let mut visitor = MentionedItemsVisitor { tcx, body, mentioned_items: Vec::new() }; + visitor.visit_body(body); + body.set_mentioned_items(visitor.mentioned_items); + } + + fn is_required(&self) -> bool { + true + } +} + +// This visitor is carefully in sync with the one in `rustc_monomorphize::collector`. We are +// visiting the exact same places but then instead of monomorphizing and creating `MonoItems`, we +// have to remain generic and just recording the relevant information in `mentioned_items`, where it +// will then be monomorphized later during "mentioned items" collection. +impl<'tcx> Visitor<'tcx> for MentionedItemsVisitor<'_, 'tcx> { + fn visit_terminator(&mut self, terminator: &mir::Terminator<'tcx>, location: Location) { + self.super_terminator(terminator, location); + let span = || self.body.source_info(location).span; + match &terminator.kind { + mir::TerminatorKind::Call { func, .. } | mir::TerminatorKind::TailCall { func, .. } => { + let callee_ty = func.ty(self.body, self.tcx); + self.mentioned_items + .push(Spanned { node: MentionedItem::Fn(callee_ty), span: span() }); + } + mir::TerminatorKind::Drop { place, .. } => { + let ty = place.ty(self.body, self.tcx).ty; + self.mentioned_items.push(Spanned { node: MentionedItem::Drop(ty), span: span() }); + } + mir::TerminatorKind::InlineAsm { operands, .. } => { + for op in operands { + match *op { + mir::InlineAsmOperand::SymFn { ref value } => { + self.mentioned_items.push(Spanned { + node: MentionedItem::Fn(value.const_.ty()), + span: span(), + }); + } + _ => {} + } + } + } + _ => {} + } + } + + fn visit_rvalue(&mut self, rvalue: &mir::Rvalue<'tcx>, location: Location) { + self.super_rvalue(rvalue, location); + let span = || self.body.source_info(location).span; + match *rvalue { + // We need to detect unsizing casts that required vtables. + mir::Rvalue::Cast( + mir::CastKind::PointerCoercion(PointerCoercion::Unsize, _) + | mir::CastKind::PointerCoercion(PointerCoercion::DynStar, _), + ref operand, + target_ty, + ) => { + // This isn't monomorphized yet so we can't tell what the actual types are -- just + // add everything that may involve a vtable. 
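+ // E.g. `&[u8; 4] -> &[u8]` never needs a vtable, while `&T -> &dyn Trait`
+ // (or a `dyn*` coercion) may, depending on what `T` ends up being.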
+ let source_ty = operand.ty(self.body, self.tcx); + let may_involve_vtable = match ( + source_ty.builtin_deref(true).map(|t| t.kind()), + target_ty.builtin_deref(true).map(|t| t.kind()), + ) { + // &str/&[T] unsizing + (Some(ty::Array(..)), Some(ty::Str | ty::Slice(..))) => false, + + _ => true, + }; + if may_involve_vtable { + self.mentioned_items.push(Spanned { + node: MentionedItem::UnsizeCast { source_ty, target_ty }, + span: span(), + }); + } + } + // Similarly, record closures that are turned into function pointers. + mir::Rvalue::Cast( + mir::CastKind::PointerCoercion(PointerCoercion::ClosureFnPointer(_), _), + ref operand, + _, + ) => { + let source_ty = operand.ty(self.body, self.tcx); + self.mentioned_items + .push(Spanned { node: MentionedItem::Closure(source_ty), span: span() }); + } + // And finally, function pointer reification casts. + mir::Rvalue::Cast( + mir::CastKind::PointerCoercion(PointerCoercion::ReifyFnPointer, _), + ref operand, + _, + ) => { + let fn_ty = operand.ty(self.body, self.tcx); + self.mentioned_items.push(Spanned { node: MentionedItem::Fn(fn_ty), span: span() }); + } + _ => {} + } + } +} diff --git a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs new file mode 100644 index 00000000000..f59b849e85c --- /dev/null +++ b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs @@ -0,0 +1,41 @@ +//! This pass removes jumps to basic blocks containing only a return, and replaces them with a +//! return instead. + +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +use crate::simplify; + +pub(super) struct MultipleReturnTerminators; + +impl<'tcx> crate::MirPass<'tcx> for MultipleReturnTerminators { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 4 + } + + fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // find basic blocks with no statement and a return terminator + let mut bbs_simple_returns = DenseBitSet::new_empty(body.basic_blocks.len()); + let bbs = body.basic_blocks_mut(); + for (idx, bb) in bbs.iter_enumerated() { + if bb.statements.is_empty() && bb.terminator().kind == TerminatorKind::Return { + bbs_simple_returns.insert(idx); + } + } + + for bb in bbs { + if let TerminatorKind::Goto { target } = bb.terminator().kind + && bbs_simple_returns.contains(target) + { + bb.terminator_mut().kind = TerminatorKind::Return; + } + } + + simplify::remove_dead_blocks(body) + } + + fn is_required(&self) -> bool { + false + } +} diff --git a/compiler/rustc_mir_transform/src/nrvo.rs b/compiler/rustc_mir_transform/src/nrvo.rs new file mode 100644 index 00000000000..965002aae04 --- /dev/null +++ b/compiler/rustc_mir_transform/src/nrvo.rs @@ -0,0 +1,234 @@ +//! See the docs for [`RenameReturnPlace`]. + +use rustc_hir::Mutability; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::bug; +use rustc_middle::mir::visit::{MutVisitor, NonUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::{self, BasicBlock, Local, Location}; +use rustc_middle::ty::TyCtxt; +use tracing::{debug, trace}; + +/// This pass looks for MIR that always copies the same local into the return place and eliminates +/// the copy by renaming all uses of that local to `_0`. +/// +/// This allows LLVM to perform an optimization similar to the named return value optimization +/// (NRVO) that is guaranteed in C++. 
This avoids a stack allocation and `memcpy` for the +/// relatively common pattern of allocating a buffer on the stack, mutating it, and returning it by +/// value like so: +/// +/// ```rust +/// fn foo(init: fn(&mut [u8; 1024])) -> [u8; 1024] { +/// let mut buf = [0; 1024]; +/// init(&mut buf); +/// buf +/// } +/// ``` +/// +/// For now, this pass is very simple and only capable of eliminating a single copy. A more general +/// version of copy propagation, such as the one based on non-overlapping live ranges in [#47954] and +/// [#71003], could yield even more benefits. +/// +/// [#47954]: https://github.com/rust-lang/rust/pull/47954 +/// [#71003]: https://github.com/rust-lang/rust/pull/71003 +pub(super) struct RenameReturnPlace; + +impl<'tcx> crate::MirPass<'tcx> for RenameReturnPlace { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + // unsound: #111005 + sess.mir_opt_level() > 0 && sess.opts.unstable_opts.unsound_mir_opts + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { + let def_id = body.source.def_id(); + let Some(returned_local) = local_eligible_for_nrvo(body) else { + debug!("`{:?}` was ineligible for NRVO", def_id); + return; + }; + + debug!( + "`{:?}` was eligible for NRVO, making {:?} the return place", + def_id, returned_local + ); + + RenameToReturnPlace { tcx, to_rename: returned_local }.visit_body_preserves_cfg(body); + + // Clean up the `NOP`s we inserted for statements made useless by our renaming. + for block_data in body.basic_blocks.as_mut_preserves_cfg() { + block_data.statements.retain(|stmt| stmt.kind != mir::StatementKind::Nop); + } + + // Overwrite the debuginfo of `_0` with that of the renamed local. + let (renamed_decl, ret_decl) = + body.local_decls.pick2_mut(returned_local, mir::RETURN_PLACE); + + // Sometimes, the return place is assigned a local of a different but coercible type, for + // example `&mut T` instead of `&T`. Overwriting the `LocalInfo` for the return place means + // its type may no longer match the return type of its function. This doesn't cause a + // problem in codegen because these two types are layout-compatible, but may be unexpected. + debug!("_0: {:?} = {:?}: {:?}", ret_decl.ty, returned_local, renamed_decl.ty); + ret_decl.clone_from(renamed_decl); + + // The return place is always mutable. + ret_decl.mutability = Mutability::Mut; + } + + fn is_required(&self) -> bool { + false + } +} + +/// MIR that is eligible for the NRVO must fulfill two conditions: +/// 1. The return place must not be read prior to the `Return` terminator. +/// 2. A simple assignment of a whole local to the return place (e.g., `_0 = _1`) must be the +/// only definition of the return place reaching the `Return` terminator. +/// +/// If the MIR fulfills both these conditions, this function returns the `Local` that is assigned +/// to the return place along all possible paths through the control-flow graph. +fn local_eligible_for_nrvo(body: &mir::Body<'_>) -> Option<Local> { + if IsReturnPlaceRead::run(body) { + return None; + } + + let mut copied_to_return_place = None; + for block in body.basic_blocks.indices() { + // Look for blocks with a `Return` terminator. + if !matches!(body[block].terminator().kind, mir::TerminatorKind::Return) { + continue; + } + + // Look for an assignment of a single local to the return place prior to the `Return`. + let returned_local = find_local_assigned_to_return_place(block, body)?; + match body.local_kind(returned_local) { + // FIXME: Can we do this for arguments as well? 
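+ // For now arguments are rejected outright: they already carry a caller-supplied
+ // value on entry, so renaming one to `_0` would need a more careful analysis.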
+ mir::LocalKind::Arg => return None, + + mir::LocalKind::ReturnPointer => bug!("Return place was assigned to itself?"), + mir::LocalKind::Temp => {} + } + + // If multiple different locals are copied to the return place. We can't pick a + // single one to rename. + if copied_to_return_place.is_some_and(|old| old != returned_local) { + return None; + } + + copied_to_return_place = Some(returned_local); + } + + copied_to_return_place +} + +fn find_local_assigned_to_return_place(start: BasicBlock, body: &mir::Body<'_>) -> Option<Local> { + let mut block = start; + let mut seen = DenseBitSet::new_empty(body.basic_blocks.len()); + + // Iterate as long as `block` has exactly one predecessor that we have not yet visited. + while seen.insert(block) { + trace!("Looking for assignments to `_0` in {:?}", block); + + let local = body[block].statements.iter().rev().find_map(as_local_assigned_to_return_place); + if local.is_some() { + return local; + } + + match body.basic_blocks.predecessors()[block].as_slice() { + &[pred] => block = pred, + _ => return None, + } + } + + None +} + +// If this statement is an assignment of an unprojected local to the return place, +// return that local. +fn as_local_assigned_to_return_place(stmt: &mir::Statement<'_>) -> Option<Local> { + if let mir::StatementKind::Assign(box (lhs, rhs)) = &stmt.kind { + if lhs.as_local() == Some(mir::RETURN_PLACE) { + if let mir::Rvalue::Use(mir::Operand::Copy(rhs) | mir::Operand::Move(rhs)) = rhs { + return rhs.as_local(); + } + } + } + + None +} + +struct RenameToReturnPlace<'tcx> { + to_rename: Local, + tcx: TyCtxt<'tcx>, +} + +/// Replaces all uses of `self.to_rename` with `_0`. +impl<'tcx> MutVisitor<'tcx> for RenameToReturnPlace<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_statement(&mut self, stmt: &mut mir::Statement<'tcx>, loc: Location) { + // Remove assignments of the local being replaced to the return place, since it is now the + // return place: + // _0 = _1 + if as_local_assigned_to_return_place(stmt) == Some(self.to_rename) { + stmt.kind = mir::StatementKind::Nop; + return; + } + + // Remove storage annotations for the local being replaced: + // StorageLive(_1) + if let mir::StatementKind::StorageLive(local) | mir::StatementKind::StorageDead(local) = + stmt.kind + { + if local == self.to_rename { + stmt.kind = mir::StatementKind::Nop; + return; + } + } + + self.super_statement(stmt, loc) + } + + fn visit_terminator(&mut self, terminator: &mut mir::Terminator<'tcx>, loc: Location) { + // Ignore the implicit "use" of the return place in a `Return` statement. + if let mir::TerminatorKind::Return = terminator.kind { + return; + } + + self.super_terminator(terminator, loc); + } + + fn visit_local(&mut self, l: &mut Local, ctxt: PlaceContext, _: Location) { + if *l == mir::RETURN_PLACE { + assert_eq!(ctxt, PlaceContext::NonUse(NonUseContext::VarDebugInfo)); + } else if *l == self.to_rename { + *l = mir::RETURN_PLACE; + } + } +} + +struct IsReturnPlaceRead(bool); + +impl IsReturnPlaceRead { + fn run(body: &mir::Body<'_>) -> bool { + let mut vis = IsReturnPlaceRead(false); + vis.visit_body(body); + vis.0 + } +} + +impl<'tcx> Visitor<'tcx> for IsReturnPlaceRead { + fn visit_local(&mut self, l: Local, ctxt: PlaceContext, _: Location) { + if l == mir::RETURN_PLACE && ctxt.is_use() && !ctxt.is_place_assignment() { + self.0 = true; + } + } + + fn visit_terminator(&mut self, terminator: &mir::Terminator<'tcx>, loc: Location) { + // Ignore the implicit "use" of the return place in a `Return` statement. 
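+ // Otherwise the `Return` terminator's own read of `_0` would make every body
+ // count as reading the return place and NRVO would never apply.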
+ if let mir::TerminatorKind::Return = terminator.kind { + return; + } + + self.super_terminator(terminator, loc); + } +} diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs new file mode 100644 index 00000000000..7a8d3ba1ff1 --- /dev/null +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -0,0 +1,351 @@ +use std::cell::RefCell; +use std::collections::hash_map::Entry; + +use rustc_data_structures::fx::{FxHashMap, FxIndexSet}; +use rustc_middle::mir::{self, Body, MirPhase, RuntimePhase}; +use rustc_middle::ty::TyCtxt; +use rustc_session::Session; +use tracing::trace; + +use crate::lint::lint_body; +use crate::{errors, validate}; + +thread_local! { + /// Maps MIR pass names to a snake case form to match profiling naming style + static PASS_TO_PROFILER_NAMES: RefCell<FxHashMap<&'static str, &'static str>> = { + RefCell::new(FxHashMap::default()) + }; +} + +/// Converts a MIR pass name into a snake case form to match the profiling naming style. +fn to_profiler_name(type_name: &'static str) -> &'static str { + PASS_TO_PROFILER_NAMES.with(|names| match names.borrow_mut().entry(type_name) { + Entry::Occupied(e) => *e.get(), + Entry::Vacant(e) => { + let snake_case: String = type_name + .chars() + .flat_map(|c| { + if c.is_ascii_uppercase() { + vec!['_', c.to_ascii_lowercase()] + } else if c == '-' { + vec!['_'] + } else { + vec![c] + } + }) + .collect(); + let result = &*String::leak(format!("mir_pass{}", snake_case)); + e.insert(result); + result + } + }) +} + +// const wrapper for `if let Some((_, tail)) = name.rsplit_once(':') { tail } else { name }` +const fn c_name(name: &'static str) -> &'static str { + // FIXME(const-hack) Simplify the implementation once more `str` methods get const-stable. + // and inline into call site + let bytes = name.as_bytes(); + let mut i = bytes.len(); + while i > 0 && bytes[i - 1] != b':' { + i = i - 1; + } + let (_, bytes) = bytes.split_at(i); + match std::str::from_utf8(bytes) { + Ok(name) => name, + Err(_) => name, + } +} + +/// A streamlined trait that you can implement to create a pass; the +/// pass will be named after the type, and it will consist of a main +/// loop that goes over each available MIR and applies `run_pass`. +pub(super) trait MirPass<'tcx> { + fn name(&self) -> &'static str { + // FIXME(const-hack) Simplify the implementation once more `str` methods get const-stable. + // See copypaste in `MirLint` + const { + let name = std::any::type_name::<Self>(); + c_name(name) + } + } + + fn profiler_name(&self) -> &'static str { + to_profiler_name(self.name()) + } + + /// Returns `true` if this pass is enabled with the current combination of compiler flags. + fn is_enabled(&self, _sess: &Session) -> bool { + true + } + + /// Returns `true` if this pass can be overridden by `-Zenable-mir-passes`. This should be + /// true for basically every pass other than those that are necessary for correctness. + fn can_be_overridden(&self) -> bool { + true + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>); + + fn is_mir_dump_enabled(&self) -> bool { + true + } + + /// Returns `true` if this pass must be run (i.e. it is required for soundness). + /// For passes which are strictly optimizations, this should return `false`. + /// If this is `false`, `#[optimize(none)]` will disable the pass. + fn is_required(&self) -> bool; +} + +/// Just like `MirPass`, except it cannot mutate `Body`, and MIR dumping is +/// disabled (via the `Lint` adapter). 
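+/// A minimal sketch of what an implementation looks like (`CheckSomething` is a
+/// made-up name; the real lints live elsewhere in this crate):
+///
+/// ```ignore (illustrative)
+/// struct CheckSomething;
+///
+/// impl<'tcx> MirLint<'tcx> for CheckSomething {
+///     fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
+///         // inspect `body` and emit diagnostics via `tcx`; the body cannot be mutated
+///     }
+/// }
+/// ```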
+pub(super) trait MirLint<'tcx> { + fn name(&self) -> &'static str { + // FIXME(const-hack) Simplify the implementation once more `str` methods get const-stable. + // See copypaste in `MirPass` + const { + let name = std::any::type_name::<Self>(); + c_name(name) + } + } + + fn is_enabled(&self, _sess: &Session) -> bool { + true + } + + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>); +} + +/// An adapter for `MirLint`s that implements `MirPass`. +#[derive(Debug, Clone)] +pub(super) struct Lint<T>(pub T); + +impl<'tcx, T> MirPass<'tcx> for Lint<T> +where + T: MirLint<'tcx>, +{ + fn name(&self) -> &'static str { + self.0.name() + } + + fn is_enabled(&self, sess: &Session) -> bool { + self.0.is_enabled(sess) + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + self.0.run_lint(tcx, body) + } + + fn is_mir_dump_enabled(&self) -> bool { + false + } + + fn is_required(&self) -> bool { + true + } +} + +pub(super) struct WithMinOptLevel<T>(pub u32, pub T); + +impl<'tcx, T> MirPass<'tcx> for WithMinOptLevel<T> +where + T: MirPass<'tcx>, +{ + fn name(&self) -> &'static str { + self.1.name() + } + + fn is_enabled(&self, sess: &Session) -> bool { + sess.mir_opt_level() >= self.0 as usize + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + self.1.run_pass(tcx, body) + } + + fn is_required(&self) -> bool { + self.1.is_required() + } +} + +/// Whether to allow non-[required] optimizations +/// +/// [required]: MirPass::is_required +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub(crate) enum Optimizations { + Suppressed, + Allowed, +} + +/// Run the sequence of passes without validating the MIR after each pass. The MIR is still +/// validated at the end. +pub(super) fn run_passes_no_validate<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + passes: &[&dyn MirPass<'tcx>], + phase_change: Option<MirPhase>, +) { + run_passes_inner(tcx, body, passes, phase_change, false, Optimizations::Allowed); +} + +/// The optional `phase_change` is applied after executing all the passes, if present +pub(super) fn run_passes<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + passes: &[&dyn MirPass<'tcx>], + phase_change: Option<MirPhase>, + optimizations: Optimizations, +) { + run_passes_inner(tcx, body, passes, phase_change, true, optimizations); +} + +pub(super) fn should_run_pass<'tcx, P>( + tcx: TyCtxt<'tcx>, + pass: &P, + optimizations: Optimizations, +) -> bool +where + P: MirPass<'tcx> + ?Sized, +{ + let name = pass.name(); + + if !pass.can_be_overridden() { + return pass.is_enabled(tcx.sess); + } + + let overridden_passes = &tcx.sess.opts.unstable_opts.mir_enable_passes; + let overridden = + overridden_passes.iter().rev().find(|(s, _)| s == &*name).map(|(_name, polarity)| { + trace!( + pass = %name, + "{} as requested by flag", + if *polarity { "Running" } else { "Not running" }, + ); + *polarity + }); + let suppressed = !pass.is_required() && matches!(optimizations, Optimizations::Suppressed); + overridden.unwrap_or_else(|| !suppressed && pass.is_enabled(tcx.sess)) +} + +fn run_passes_inner<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + passes: &[&dyn MirPass<'tcx>], + phase_change: Option<MirPhase>, + validate_each: bool, + optimizations: Optimizations, +) { + let overridden_passes = &tcx.sess.opts.unstable_opts.mir_enable_passes; + trace!(?overridden_passes); + + let named_passes: FxIndexSet<_> = + overridden_passes.iter().map(|(name, _)| name.as_str()).collect(); + + for &name in named_passes.difference(&*crate::PASS_NAMES) { + 
tcx.dcx().emit_warn(errors::UnknownPassName { name }); + } + + // Verify that no passes are missing from the `declare_passes` invocation + #[cfg(debug_assertions)] + #[allow(rustc::diagnostic_outside_of_impl)] + #[allow(rustc::untranslatable_diagnostic)] + { + let used_passes: FxIndexSet<_> = passes.iter().map(|p| p.name()).collect(); + + let undeclared = used_passes.difference(&*crate::PASS_NAMES).collect::<Vec<_>>(); + if let Some((name, rest)) = undeclared.split_first() { + let mut err = + tcx.dcx().struct_bug(format!("pass `{name}` is not declared in `PASS_NAMES`")); + for name in rest { + err.note(format!("pass `{name}` is also not declared in `PASS_NAMES`")); + } + err.emit(); + } + } + + let prof_arg = tcx.sess.prof.enabled().then(|| format!("{:?}", body.source.def_id())); + + if !body.should_skip() { + let validate = validate_each & tcx.sess.opts.unstable_opts.validate_mir; + let lint = tcx.sess.opts.unstable_opts.lint_mir; + + for pass in passes { + let name = pass.name(); + + if !should_run_pass(tcx, *pass, optimizations) { + continue; + }; + + let dump_enabled = pass.is_mir_dump_enabled(); + + if dump_enabled { + dump_mir_for_pass(tcx, body, name, false); + } + + if let Some(prof_arg) = &prof_arg { + tcx.sess + .prof + .generic_activity_with_arg(pass.profiler_name(), &**prof_arg) + .run(|| pass.run_pass(tcx, body)); + } else { + pass.run_pass(tcx, body); + } + + if dump_enabled { + dump_mir_for_pass(tcx, body, name, true); + } + if validate { + validate_body(tcx, body, format!("after pass {name}")); + } + if lint { + lint_body(tcx, body, format!("after pass {name}")); + } + + body.pass_count += 1; + } + } + + if let Some(new_phase) = phase_change { + if body.phase >= new_phase { + panic!("Invalid MIR phase transition from {:?} to {:?}", body.phase, new_phase); + } + + body.phase = new_phase; + body.pass_count = 0; + + dump_mir_for_phase_change(tcx, body); + + let validate = + (validate_each & tcx.sess.opts.unstable_opts.validate_mir & !body.should_skip()) + || new_phase == MirPhase::Runtime(RuntimePhase::Optimized); + let lint = tcx.sess.opts.unstable_opts.lint_mir & !body.should_skip(); + if validate { + validate_body(tcx, body, format!("after phase change to {}", new_phase.name())); + } + if lint { + lint_body(tcx, body, format!("after phase change to {}", new_phase.name())); + } + + body.pass_count = 1; + } +} + +pub(super) fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: String) { + validate::Validator { when }.run_pass(tcx, body); +} + +fn dump_mir_for_pass<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, pass_name: &str, is_after: bool) { + mir::dump_mir( + tcx, + true, + pass_name, + if is_after { &"after" } else { &"before" }, + body, + |_, _| Ok(()), + ); +} + +pub(super) fn dump_mir_for_phase_change<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + assert_eq!(body.pass_count, 0); + mir::dump_mir(tcx, true, body.phase.name(), &"after", body, |_, _| Ok(())) +} diff --git a/compiler/rustc_mir_transform/src/patch.rs b/compiler/rustc_mir_transform/src/patch.rs new file mode 100644 index 00000000000..a872eae15f1 --- /dev/null +++ b/compiler/rustc_mir_transform/src/patch.rs @@ -0,0 +1,299 @@ +use rustc_index::{Idx, IndexVec}; +use rustc_middle::mir::*; +use rustc_middle::ty::Ty; +use rustc_span::Span; +use tracing::debug; + +/// This struct lets you "patch" a MIR body, i.e. modify it. 
You can queue up +/// various changes, such as the addition of new statements and basic blocks +/// and replacement of terminators, and then apply the queued changes all at +/// once with `apply`. This is useful for MIR transformation passes. +pub(crate) struct MirPatch<'tcx> { + term_patch_map: IndexVec<BasicBlock, Option<TerminatorKind<'tcx>>>, + new_blocks: Vec<BasicBlockData<'tcx>>, + new_statements: Vec<(Location, StatementKind<'tcx>)>, + new_locals: Vec<LocalDecl<'tcx>>, + resume_block: Option<BasicBlock>, + // Only for unreachable in cleanup path. + unreachable_cleanup_block: Option<BasicBlock>, + // Only for unreachable not in cleanup path. + unreachable_no_cleanup_block: Option<BasicBlock>, + // Cached block for UnwindTerminate (with reason) + terminate_block: Option<(BasicBlock, UnwindTerminateReason)>, + body_span: Span, + next_local: usize, +} + +impl<'tcx> MirPatch<'tcx> { + /// Creates a new, empty patch. + pub(crate) fn new(body: &Body<'tcx>) -> Self { + let mut result = MirPatch { + term_patch_map: IndexVec::from_elem(None, &body.basic_blocks), + new_blocks: vec![], + new_statements: vec![], + new_locals: vec![], + next_local: body.local_decls.len(), + resume_block: None, + unreachable_cleanup_block: None, + unreachable_no_cleanup_block: None, + terminate_block: None, + body_span: body.span, + }; + + for (bb, block) in body.basic_blocks.iter_enumerated() { + // Check if we already have a resume block + if matches!(block.terminator().kind, TerminatorKind::UnwindResume) + && block.statements.is_empty() + { + result.resume_block = Some(bb); + continue; + } + + // Check if we already have an unreachable block + if matches!(block.terminator().kind, TerminatorKind::Unreachable) + && block.statements.is_empty() + { + if block.is_cleanup { + result.unreachable_cleanup_block = Some(bb); + } else { + result.unreachable_no_cleanup_block = Some(bb); + } + continue; + } + + // Check if we already have a terminate block + if let TerminatorKind::UnwindTerminate(reason) = block.terminator().kind + && block.statements.is_empty() + { + result.terminate_block = Some((bb, reason)); + continue; + } + } + + result + } + + pub(crate) fn resume_block(&mut self) -> BasicBlock { + if let Some(bb) = self.resume_block { + return bb; + } + + let bb = self.new_block(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + source_info: SourceInfo::outermost(self.body_span), + kind: TerminatorKind::UnwindResume, + }), + is_cleanup: true, + }); + self.resume_block = Some(bb); + bb + } + + pub(crate) fn unreachable_cleanup_block(&mut self) -> BasicBlock { + if let Some(bb) = self.unreachable_cleanup_block { + return bb; + } + + let bb = self.new_block(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + source_info: SourceInfo::outermost(self.body_span), + kind: TerminatorKind::Unreachable, + }), + is_cleanup: true, + }); + self.unreachable_cleanup_block = Some(bb); + bb + } + + pub(crate) fn unreachable_no_cleanup_block(&mut self) -> BasicBlock { + if let Some(bb) = self.unreachable_no_cleanup_block { + return bb; + } + + let bb = self.new_block(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + source_info: SourceInfo::outermost(self.body_span), + kind: TerminatorKind::Unreachable, + }), + is_cleanup: false, + }); + self.unreachable_no_cleanup_block = Some(bb); + bb + } + + pub(crate) fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock { + if let Some((cached_bb, cached_reason)) = self.terminate_block + && reason == 
cached_reason + { + return cached_bb; + } + + let bb = self.new_block(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + source_info: SourceInfo::outermost(self.body_span), + kind: TerminatorKind::UnwindTerminate(reason), + }), + is_cleanup: true, + }); + self.terminate_block = Some((bb, reason)); + bb + } + + /// Has a replacement of this block's terminator been queued in this patch? + pub(crate) fn is_term_patched(&self, bb: BasicBlock) -> bool { + self.term_patch_map[bb].is_some() + } + + /// Universal getter for block data, either it is in 'old' blocks or in patched ones + pub(crate) fn block<'a>( + &'a self, + body: &'a Body<'tcx>, + bb: BasicBlock, + ) -> &'a BasicBlockData<'tcx> { + match bb.index().checked_sub(body.basic_blocks.len()) { + Some(new) => &self.new_blocks[new], + None => &body[bb], + } + } + + pub(crate) fn terminator_loc(&self, body: &Body<'tcx>, bb: BasicBlock) -> Location { + let offset = self.block(body, bb).statements.len(); + Location { block: bb, statement_index: offset } + } + + /// Queues the addition of a new temporary with additional local info. + pub(crate) fn new_local_with_info( + &mut self, + ty: Ty<'tcx>, + span: Span, + local_info: LocalInfo<'tcx>, + ) -> Local { + let index = self.next_local; + self.next_local += 1; + let mut new_decl = LocalDecl::new(ty, span); + **new_decl.local_info.as_mut().unwrap_crate_local() = local_info; + self.new_locals.push(new_decl); + Local::new(index) + } + + /// Queues the addition of a new temporary. + pub(crate) fn new_temp(&mut self, ty: Ty<'tcx>, span: Span) -> Local { + let index = self.next_local; + self.next_local += 1; + self.new_locals.push(LocalDecl::new(ty, span)); + Local::new(index) + } + + /// Returns the type of a local that's newly-added in the patch. + pub(crate) fn local_ty(&self, local: Local) -> Ty<'tcx> { + let local = local.as_usize(); + assert!(local < self.next_local); + let new_local_idx = self.new_locals.len() - (self.next_local - local); + self.new_locals[new_local_idx].ty + } + + /// Queues the addition of a new basic block. + pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock { + let block = self.term_patch_map.next_index(); + debug!("MirPatch: new_block: {:?}: {:?}", block, data); + self.new_blocks.push(data); + self.term_patch_map.push(None); + block + } + + /// Queues the replacement of a block's terminator. + pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) { + assert!(self.term_patch_map[block].is_none()); + debug!("MirPatch: patch_terminator({:?}, {:?})", block, new); + self.term_patch_map[block] = Some(new); + } + + /// Queues the insertion of a statement at a given location. The statement + /// currently at that location, and all statements that follow, are shifted + /// down. If multiple statements are queued for addition at the same + /// location, the final statement order after calling `apply` will match + /// the queue insertion order. + /// + /// E.g. if we have `s0` at location `loc` and do these calls: + /// + /// p.add_statement(loc, s1); + /// p.add_statement(loc, s2); + /// p.apply(body); + /// + /// then the final order will be `s1, s2, s0`, with `s1` at `loc`. + pub(crate) fn add_statement(&mut self, loc: Location, stmt: StatementKind<'tcx>) { + debug!("MirPatch: add_statement({:?}, {:?})", loc, stmt); + self.new_statements.push((loc, stmt)); + } + + /// Like `add_statement`, but specialized for assignments. 
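+ /// (A shorthand for `add_statement(loc, StatementKind::Assign(Box::new((place, rv))))`.)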
+ pub(crate) fn add_assign(&mut self, loc: Location, place: Place<'tcx>, rv: Rvalue<'tcx>) { + self.add_statement(loc, StatementKind::Assign(Box::new((place, rv)))); + } + + /// Applies the queued changes. + pub(crate) fn apply(self, body: &mut Body<'tcx>) { + debug!( + "MirPatch: {:?} new temps, starting from index {}: {:?}", + self.new_locals.len(), + body.local_decls.len(), + self.new_locals + ); + debug!( + "MirPatch: {} new blocks, starting from index {}", + self.new_blocks.len(), + body.basic_blocks.len() + ); + let bbs = if self.term_patch_map.is_empty() && self.new_blocks.is_empty() { + body.basic_blocks.as_mut_preserves_cfg() + } else { + body.basic_blocks.as_mut() + }; + bbs.extend(self.new_blocks); + body.local_decls.extend(self.new_locals); + for (src, patch) in self.term_patch_map.into_iter_enumerated() { + if let Some(patch) = patch { + debug!("MirPatch: patching block {:?}", src); + bbs[src].terminator_mut().kind = patch; + } + } + + let mut new_statements = self.new_statements; + + // This must be a stable sort to provide the ordering described in the + // comment for `add_statement`. + new_statements.sort_by_key(|s| s.0); + + let mut delta = 0; + let mut last_bb = START_BLOCK; + for (mut loc, stmt) in new_statements { + if loc.block != last_bb { + delta = 0; + last_bb = loc.block; + } + debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta); + loc.statement_index += delta; + let source_info = Self::source_info_for_index(&body[loc.block], loc); + body[loc.block] + .statements + .insert(loc.statement_index, Statement { source_info, kind: stmt }); + delta += 1; + } + } + + fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo { + match data.statements.get(loc.statement_index) { + Some(stmt) => stmt.source_info, + None => data.terminator().source_info, + } + } + + pub(crate) fn source_info_for_location(&self, body: &Body<'tcx>, loc: Location) -> SourceInfo { + let data = self.block(body, loc.block); + Self::source_info_for_index(data, loc) + } +} diff --git a/compiler/rustc_mir_transform/src/post_analysis_normalize.rs b/compiler/rustc_mir_transform/src/post_analysis_normalize.rs new file mode 100644 index 00000000000..5599dee4cca --- /dev/null +++ b/compiler/rustc_mir_transform/src/post_analysis_normalize.rs @@ -0,0 +1,81 @@ +//! Normalizes MIR in `TypingMode::PostAnalysis` mode, most notably revealing +//! its opaques. We also only normalize specializable associated items once in +//! `PostAnalysis` mode. + +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; + +pub(super) struct PostAnalysisNormalize; + +impl<'tcx> crate::MirPass<'tcx> for PostAnalysisNormalize { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // FIXME(#132279): This is used during the phase transition from analysis + // to runtime, so we have to manually specify the correct typing mode. 
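+ // `post_analysis` reveals opaque types to their hidden types, which is what
+ // lets the visitor below normalize them out of the body.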
+ let typing_env = ty::TypingEnv::post_analysis(tcx, body.source.def_id()); + PostAnalysisNormalizeVisitor { tcx, typing_env }.visit_body_preserves_cfg(body); + } + + fn is_required(&self) -> bool { + true + } +} + +struct PostAnalysisNormalizeVisitor<'tcx> { + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for PostAnalysisNormalizeVisitor<'tcx> { + #[inline] + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + #[inline] + fn visit_place( + &mut self, + place: &mut Place<'tcx>, + _context: PlaceContext, + _location: Location, + ) { + if !self.tcx.next_trait_solver_globally() { + // `OpaqueCast` projections are only needed if there are opaque types on which projections + // are performed. After the `PostAnalysisNormalize` pass, all opaque types are replaced with their + // hidden types, so we don't need these projections anymore. + // + // Performance optimization: don't reintern if there is no `OpaqueCast` to remove. + if place.projection.iter().any(|elem| matches!(elem, ProjectionElem::OpaqueCast(_))) { + place.projection = self.tcx.mk_place_elems( + &place + .projection + .into_iter() + .filter(|elem| !matches!(elem, ProjectionElem::OpaqueCast(_))) + .collect::<Vec<_>>(), + ); + }; + } + self.super_place(place, _context, _location); + } + + #[inline] + fn visit_const_operand(&mut self, constant: &mut ConstOperand<'tcx>, location: Location) { + // We have to use `try_normalize_erasing_regions` here, since it's + // possible that we visit impossible-to-satisfy where clauses here, + // see #91745 + if let Ok(c) = self.tcx.try_normalize_erasing_regions(self.typing_env, constant.const_) { + constant.const_ = c; + } + self.super_const_operand(constant, location); + } + + #[inline] + fn visit_ty(&mut self, ty: &mut Ty<'tcx>, _: TyContext) { + // We have to use `try_normalize_erasing_regions` here, since it's + // possible that we visit impossible-to-satisfy where clauses here, + // see #91745 + if let Ok(t) = self.tcx.try_normalize_erasing_regions(self.typing_env, *ty) { + *ty = t; + } + } +} diff --git a/compiler/rustc_mir_transform/src/post_drop_elaboration.rs b/compiler/rustc_mir_transform/src/post_drop_elaboration.rs new file mode 100644 index 00000000000..75721d46076 --- /dev/null +++ b/compiler/rustc_mir_transform/src/post_drop_elaboration.rs @@ -0,0 +1,13 @@ +use rustc_const_eval::check_consts; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +use crate::MirLint; + +pub(super) struct CheckLiveDrops; + +impl<'tcx> MirLint<'tcx> for CheckLiveDrops { + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + check_consts::post_drop_elaboration::check_live_drops(tcx, body); + } +} diff --git a/compiler/rustc_mir_transform/src/prettify.rs b/compiler/rustc_mir_transform/src/prettify.rs new file mode 100644 index 00000000000..8217feff24e --- /dev/null +++ b/compiler/rustc_mir_transform/src/prettify.rs @@ -0,0 +1,158 @@ +//! These two passes provide no value to the compiler, so are off at every level. +//! +//! However, they can be enabled on the command line +//! (`-Zmir-enable-passes=+ReorderBasicBlocks,+ReorderLocals`) +//! to make the MIR easier to read for humans. + +use rustc_index::bit_set::DenseBitSet; +use rustc_index::{IndexSlice, IndexVec}; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_session::Session; + +/// Rearranges the basic blocks into a *reverse post-order*. 
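+/// (specifically, the order returned by `body.basic_blocks.reverse_postorder()`)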
+/// +/// Thus after this pass, all the successors of a block are later than it in the +/// `IndexVec`, unless that successor is a back-edge (such as from a loop). +pub(super) struct ReorderBasicBlocks; + +impl<'tcx> crate::MirPass<'tcx> for ReorderBasicBlocks { + fn is_enabled(&self, _session: &Session) -> bool { + false + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let rpo: IndexVec<BasicBlock, BasicBlock> = + body.basic_blocks.reverse_postorder().iter().copied().collect(); + if rpo.iter().is_sorted() { + return; + } + + let mut updater = BasicBlockUpdater { map: rpo.invert_bijective_mapping(), tcx }; + debug_assert_eq!(updater.map[START_BLOCK], START_BLOCK); + updater.visit_body(body); + + permute(body.basic_blocks.as_mut(), &updater.map); + } + + fn is_required(&self) -> bool { + false + } +} + +/// Rearranges the locals into *use* order. +/// +/// Thus after this pass, a local with a smaller [`Location`] where it was first +/// assigned or referenced will have a smaller number. +/// +/// (Does not reorder arguments nor the [`RETURN_PLACE`].) +pub(super) struct ReorderLocals; + +impl<'tcx> crate::MirPass<'tcx> for ReorderLocals { + fn is_enabled(&self, _session: &Session) -> bool { + false + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let mut finder = LocalFinder { + map: IndexVec::new(), + seen: DenseBitSet::new_empty(body.local_decls.len()), + }; + + // We can't reorder the return place or the arguments + for local in (0..=body.arg_count).map(Local::from_usize) { + finder.track(local); + } + + for (bb, bbd) in body.basic_blocks.iter_enumerated() { + finder.visit_basic_block_data(bb, bbd); + } + + // Track everything in case there are some locals that we never saw, + // such as in non-block things like debug info or in non-uses. + for local in body.local_decls.indices() { + finder.track(local); + } + + if finder.map.iter().is_sorted() { + return; + } + + let mut updater = LocalUpdater { map: finder.map.invert_bijective_mapping(), tcx }; + + for local in (0..=body.arg_count).map(Local::from_usize) { + debug_assert_eq!(updater.map[local], local); + } + + updater.visit_body_preserves_cfg(body); + + permute(&mut body.local_decls, &updater.map); + } + + fn is_required(&self) -> bool { + false + } +} + +fn permute<I: rustc_index::Idx + Ord, T>(data: &mut IndexVec<I, T>, map: &IndexSlice<I, I>) { + // FIXME: It would be nice to have a less-awkward way to apply permutations, + // but I don't know one that exists. `sort_by_cached_key` has logic for it + // internally, but not in a way that we're allowed to use here. 
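+ // E.g. with `map = [2, 0, 1]`, the element currently at index 0 ends up at
+ // index 2, since the pairs are sorted by their mapped index below.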
+ let mut enumerated: Vec<_> = std::mem::take(data).into_iter_enumerated().collect(); + enumerated.sort_by_key(|p| map[p.0]); + *data = enumerated.into_iter().map(|p| p.1).collect(); +} + +struct BasicBlockUpdater<'tcx> { + map: IndexVec<BasicBlock, BasicBlock>, + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for BasicBlockUpdater<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, _location: Location) { + terminator.successors_mut(|succ| *succ = self.map[*succ]); + } +} + +struct LocalFinder { + map: IndexVec<Local, Local>, + seen: DenseBitSet<Local>, +} + +impl LocalFinder { + fn track(&mut self, l: Local) { + if self.seen.insert(l) { + self.map.push(l); + } + } +} + +impl<'tcx> Visitor<'tcx> for LocalFinder { + fn visit_local(&mut self, l: Local, context: PlaceContext, _location: Location) { + // Exclude non-uses to keep `StorageLive` from controlling where we put + // a `Local`, since it might not actually be assigned until much later. + if context.is_use() { + self.track(l); + } + } +} + +struct LocalUpdater<'tcx> { + map: IndexVec<Local, Local>, + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for LocalUpdater<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, l: &mut Local, _: PlaceContext, _: Location) { + *l = self.map[*l]; + } +} diff --git a/compiler/rustc_mir_transform/src/promote_consts.rs b/compiler/rustc_mir_transform/src/promote_consts.rs new file mode 100644 index 00000000000..47d43830970 --- /dev/null +++ b/compiler/rustc_mir_transform/src/promote_consts.rs @@ -0,0 +1,1080 @@ +//! A pass that promotes borrows of constant rvalues. +//! +//! The rvalues considered constant are trees of temps, each with exactly one +//! initialization, and holding a constant value with no interior mutability. +//! They are placed into a new MIR constant body in `promoted` and the borrow +//! rvalue is replaced with a `Literal::Promoted` using the index into +//! `promoted` of that constant MIR. +//! +//! This pass assumes that every use is dominated by an initialization and can +//! otherwise silence errors, if move analysis runs after promotion on broken +//! MIR. + +use std::assert_matches::assert_matches; +use std::cell::Cell; +use std::{cmp, iter, mem}; + +use either::{Left, Right}; +use rustc_const_eval::check_consts::{ConstCx, qualifs}; +use rustc_data_structures::fx::FxHashSet; +use rustc_hir as hir; +use rustc_index::{IndexSlice, IndexVec}; +use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, GenericArgs, List, Ty, TyCtxt, TypeVisitableExt}; +use rustc_middle::{bug, mir, span_bug}; +use rustc_span::Span; +use rustc_span::source_map::Spanned; +use tracing::{debug, instrument}; + +/// A `MirPass` for promotion. +/// +/// Promotion is the extraction of promotable temps into separate MIR bodies so they can have +/// `'static` lifetime. +/// +/// After this pass is run, `promoted_fragments` will hold the MIR body corresponding to each +/// newly created `Constant`. +#[derive(Default)] +pub(super) struct PromoteTemps<'tcx> { + // Must use `Cell` because `run_pass` takes `&self`, not `&mut self`. + pub promoted_fragments: Cell<IndexVec<Promoted, Body<'tcx>>>, +} + +impl<'tcx> crate::MirPass<'tcx> for PromoteTemps<'tcx> { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // There's not really any point in promoting errorful MIR. 
+ // + // This does not include MIR that failed const-checking, which we still try to promote. + if let Err(_) = body.return_ty().error_reported() { + debug!("PromoteTemps: MIR had errors"); + return; + } + if body.source.promoted.is_some() { + return; + } + + let ccx = ConstCx::new(tcx, body); + let (mut temps, all_candidates) = collect_temps_and_candidates(&ccx); + + let promotable_candidates = validate_candidates(&ccx, &mut temps, all_candidates); + + let promoted = promote_candidates(body, tcx, temps, promotable_candidates); + self.promoted_fragments.set(promoted); + } + + fn is_required(&self) -> bool { + true + } +} + +/// State of a temporary during collection and promotion. +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +enum TempState { + /// No references to this temp. + Undefined, + /// One direct assignment and any number of direct uses. + /// A borrow of this temp is promotable if the assigned + /// value is qualified as constant. + Defined { location: Location, uses: usize, valid: Result<(), ()> }, + /// Any other combination of assignments/uses. + Unpromotable, + /// This temp was part of an rvalue which got extracted + /// during promotion and needs cleanup. + PromotedOut, +} + +/// A "root candidate" for promotion, which will become the +/// returned value in a promoted MIR, unless it's a subset +/// of a larger candidate. +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +struct Candidate { + location: Location, +} + +struct Collector<'a, 'tcx> { + ccx: &'a ConstCx<'a, 'tcx>, + temps: IndexVec<Local, TempState>, + candidates: Vec<Candidate>, +} + +impl<'tcx> Visitor<'tcx> for Collector<'_, 'tcx> { + #[instrument(level = "debug", skip(self))] + fn visit_local(&mut self, index: Local, context: PlaceContext, location: Location) { + // We're only interested in temporaries and the return place + match self.ccx.body.local_kind(index) { + LocalKind::Arg => return, + LocalKind::Temp if self.ccx.body.local_decls[index].is_user_variable() => return, + LocalKind::ReturnPointer | LocalKind::Temp => {} + } + + // Ignore drops, if the temp gets promoted, + // then it's constant and thus drop is noop. + // Non-uses are also irrelevant. + if context.is_drop() || !context.is_use() { + debug!(is_drop = context.is_drop(), is_use = context.is_use()); + return; + } + + let temp = &mut self.temps[index]; + debug!(?temp); + *temp = match *temp { + TempState::Undefined => match context { + PlaceContext::MutatingUse(MutatingUseContext::Store | MutatingUseContext::Call) => { + TempState::Defined { location, uses: 0, valid: Err(()) } + } + _ => TempState::Unpromotable, + }, + TempState::Defined { ref mut uses, .. } => { + // We always allow borrows, even mutable ones, as we need + // to promote mutable borrows of some ZSTs e.g., `&mut []`. + let allowed_use = match context { + PlaceContext::MutatingUse(MutatingUseContext::Borrow) + | PlaceContext::NonMutatingUse(_) => true, + PlaceContext::MutatingUse(_) | PlaceContext::NonUse(_) => false, + }; + debug!(?allowed_use); + if allowed_use { + *uses += 1; + return; + } + TempState::Unpromotable + } + TempState::Unpromotable | TempState::PromotedOut => TempState::Unpromotable, + }; + debug!(?temp); + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + self.super_rvalue(rvalue, location); + + if let Rvalue::Ref(..) 
= *rvalue { + self.candidates.push(Candidate { location }); + } + } +} + +fn collect_temps_and_candidates<'tcx>( + ccx: &ConstCx<'_, 'tcx>, +) -> (IndexVec<Local, TempState>, Vec<Candidate>) { + let mut collector = Collector { + temps: IndexVec::from_elem(TempState::Undefined, &ccx.body.local_decls), + candidates: vec![], + ccx, + }; + for (bb, data) in traversal::reverse_postorder(ccx.body) { + collector.visit_basic_block_data(bb, data); + } + (collector.temps, collector.candidates) +} + +/// Checks whether locals that appear in a promotion context (`Candidate`) are actually promotable. +/// +/// This wraps an `Item`, and has access to all fields of that `Item` via `Deref` coercion. +struct Validator<'a, 'tcx> { + ccx: &'a ConstCx<'a, 'tcx>, + temps: &'a mut IndexSlice<Local, TempState>, + /// For backwards compatibility, we are promoting function calls in `const`/`static` + /// initializers. But we want to avoid evaluating code that might panic and that otherwise would + /// not have been evaluated, so we only promote such calls in basic blocks that are guaranteed + /// to execute. In other words, we only promote such calls in basic blocks that are definitely + /// not dead code. Here we cache the result of computing that set of basic blocks. + promotion_safe_blocks: Option<FxHashSet<BasicBlock>>, +} + +impl<'a, 'tcx> std::ops::Deref for Validator<'a, 'tcx> { + type Target = ConstCx<'a, 'tcx>; + + fn deref(&self) -> &Self::Target { + self.ccx + } +} + +struct Unpromotable; + +impl<'tcx> Validator<'_, 'tcx> { + fn validate_candidate(&mut self, candidate: Candidate) -> Result<(), Unpromotable> { + let Left(statement) = self.body.stmt_at(candidate.location) else { bug!() }; + let Some((_, Rvalue::Ref(_, kind, place))) = statement.kind.as_assign() else { bug!() }; + + // We can only promote interior borrows of promotable temps (non-temps + // don't get promoted anyway). + self.validate_local(place.local)?; + + // The reference operation itself must be promotable. + // (Needs to come after `validate_local` to avoid ICEs.) + self.validate_ref(*kind, place)?; + + // We do not check all the projections (they do not get promoted anyway), + // but we do stay away from promoting anything involving a dereference. + if place.projection.contains(&ProjectionElem::Deref) { + return Err(Unpromotable); + } + + Ok(()) + } + + // FIXME(eddyb) maybe cache this? + fn qualif_local<Q: qualifs::Qualif>(&mut self, local: Local) -> bool { + let TempState::Defined { location: loc, .. } = self.temps[local] else { + return false; + }; + + let stmt_or_term = self.body.stmt_at(loc); + match stmt_or_term { + Left(statement) => { + let Some((_, rhs)) = statement.kind.as_assign() else { + span_bug!(statement.source_info.span, "{:?} is not an assignment", statement) + }; + qualifs::in_rvalue::<Q, _>(self.ccx, &mut |l| self.qualif_local::<Q>(l), rhs) + } + Right(terminator) => { + assert_matches!(terminator.kind, TerminatorKind::Call { .. }); + let return_ty = self.body.local_decls[local].ty; + Q::in_any_value_of_ty(self.ccx, return_ty) + } + } + } + + fn validate_local(&mut self, local: Local) -> Result<(), Unpromotable> { + let TempState::Defined { location: loc, uses, valid } = self.temps[local] else { + return Err(Unpromotable); + }; + + // We cannot promote things that need dropping, since the promoted value would not get + // dropped. 
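    // (For instance, a temporary holding a `Vec` or `String` would be rejected here:
    // the promoted constant would never run its destructor.)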
+ if self.qualif_local::<qualifs::NeedsDrop>(local) { + return Err(Unpromotable); + } + + if valid.is_ok() { + return Ok(()); + } + + let ok = { + let stmt_or_term = self.body.stmt_at(loc); + match stmt_or_term { + Left(statement) => { + let Some((_, rhs)) = statement.kind.as_assign() else { + span_bug!( + statement.source_info.span, + "{:?} is not an assignment", + statement + ) + }; + self.validate_rvalue(rhs) + } + Right(terminator) => match &terminator.kind { + TerminatorKind::Call { func, args, .. } => { + self.validate_call(func, args, loc.block) + } + TerminatorKind::Yield { .. } => Err(Unpromotable), + kind => { + span_bug!(terminator.source_info.span, "{:?} not promotable", kind); + } + }, + } + }; + + self.temps[local] = match ok { + Ok(()) => TempState::Defined { location: loc, uses, valid: Ok(()) }, + Err(_) => TempState::Unpromotable, + }; + + ok + } + + fn validate_place(&mut self, place: PlaceRef<'tcx>) -> Result<(), Unpromotable> { + let Some((place_base, elem)) = place.last_projection() else { + return self.validate_local(place.local); + }; + + // Validate topmost projection, then recurse. + match elem { + // Recurse directly. + ProjectionElem::ConstantIndex { .. } + | ProjectionElem::Subtype(_) + | ProjectionElem::Subslice { .. } + | ProjectionElem::UnwrapUnsafeBinder(_) => {} + + // Never recurse. + ProjectionElem::OpaqueCast(..) | ProjectionElem::Downcast(..) => { + return Err(Unpromotable); + } + + ProjectionElem::Deref => { + // When a static is used by-value, that gets desugared to `*STATIC_ADDR`, + // and we need to be able to promote this. So check if this deref matches + // that specific pattern. + + // We need to make sure this is a `Deref` of a local with no further projections. + // Discussion can be found at + // https://github.com/rust-lang/rust/pull/74945#discussion_r463063247 + if let Some(local) = place_base.as_local() + && let TempState::Defined { location, .. } = self.temps[local] + && let Left(def_stmt) = self.body.stmt_at(location) + && let Some((_, Rvalue::Use(Operand::Constant(c)))) = def_stmt.kind.as_assign() + && let Some(did) = c.check_static_ptr(self.tcx) + // Evaluating a promoted may not read statics except if it got + // promoted from a static (this is a CTFE check). So we + // can only promote static accesses inside statics. + && let Some(hir::ConstContext::Static(..)) = self.const_kind + && !self.tcx.is_thread_local_static(did) + { + // Recurse. + } else { + return Err(Unpromotable); + } + } + ProjectionElem::Index(local) => { + // Only accept if we can predict the index and are indexing an array. + if let TempState::Defined { location: loc, .. } = self.temps[local] + && let Left(statement) = self.body.stmt_at(loc) + && let Some((_, Rvalue::Use(Operand::Constant(c)))) = statement.kind.as_assign() + && let Some(idx) = c.const_.try_eval_target_usize(self.tcx, self.typing_env) + // Determine the type of the thing we are indexing. + && let ty::Array(_, len) = place_base.ty(self.body, self.tcx).ty.kind() + // It's an array; determine its length. + && let Some(len) = len.try_to_target_usize(self.tcx) + // If the index is in-bounds, go ahead. + && idx < len + { + self.validate_local(local)?; + // Recurse. + } else { + return Err(Unpromotable); + } + } + + ProjectionElem::Field(..) => { + let base_ty = place_base.ty(self.body, self.tcx).ty; + if base_ty.is_union() { + // No promotion of union field accesses. 
+ return Err(Unpromotable); + } + } + } + + self.validate_place(place_base) + } + + fn validate_operand(&mut self, operand: &Operand<'tcx>) -> Result<(), Unpromotable> { + match operand { + Operand::Copy(place) | Operand::Move(place) => self.validate_place(place.as_ref()), + + // The qualifs for a constant (e.g. `HasMutInterior`) are checked in + // `validate_rvalue` upon access. + Operand::Constant(c) => { + if let Some(def_id) = c.check_static_ptr(self.tcx) { + // Only allow statics (not consts) to refer to other statics. + // FIXME(eddyb) does this matter at all for promotion? + // FIXME(RalfJung) it makes little sense to not promote this in `fn`/`const fn`, + // and in `const` this cannot occur anyway. The only concern is that we might + // promote even `let x = &STATIC` which would be useless, but this applies to + // promotion inside statics as well. + let is_static = matches!(self.const_kind, Some(hir::ConstContext::Static(_))); + if !is_static { + return Err(Unpromotable); + } + + let is_thread_local = self.tcx.is_thread_local_static(def_id); + if is_thread_local { + return Err(Unpromotable); + } + } + + Ok(()) + } + } + } + + fn validate_ref(&mut self, kind: BorrowKind, place: &Place<'tcx>) -> Result<(), Unpromotable> { + match kind { + // Reject these borrow types just to be safe. + // FIXME(RalfJung): could we allow them? Should we? No point in it until we have a + // usecase. + BorrowKind::Fake(_) | BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture } => { + return Err(Unpromotable); + } + + BorrowKind::Shared => { + let has_mut_interior = self.qualif_local::<qualifs::HasMutInterior>(place.local); + if has_mut_interior { + return Err(Unpromotable); + } + } + + // FIXME: consider changing this to only promote &mut [] for default borrows, + // also forbidding two phase borrows + BorrowKind::Mut { kind: MutBorrowKind::Default | MutBorrowKind::TwoPhaseBorrow } => { + let ty = place.ty(self.body, self.tcx).ty; + + // In theory, any zero-sized value could be borrowed + // mutably without consequences. However, only &mut [] + // is allowed right now. + if let ty::Array(_, len) = ty.kind() { + match len.try_to_target_usize(self.tcx) { + Some(0) => {} + _ => return Err(Unpromotable), + } + } else { + return Err(Unpromotable); + } + } + } + + Ok(()) + } + + fn validate_rvalue(&mut self, rvalue: &Rvalue<'tcx>) -> Result<(), Unpromotable> { + match rvalue { + Rvalue::Use(operand) + | Rvalue::Repeat(operand, _) + | Rvalue::WrapUnsafeBinder(operand, _) => { + self.validate_operand(operand)?; + } + Rvalue::CopyForDeref(place) => { + let op = &Operand::Copy(*place); + self.validate_operand(op)? + } + + Rvalue::Discriminant(place) | Rvalue::Len(place) => { + self.validate_place(place.as_ref())? + } + + Rvalue::ThreadLocalRef(_) => return Err(Unpromotable), + + // ptr-to-int casts are not possible in consts and thus not promotable + Rvalue::Cast(CastKind::PointerExposeProvenance, _, _) => return Err(Unpromotable), + + // all other casts including int-to-ptr casts are fine, they just use the integer value + // at pointer type. + Rvalue::Cast(_, operand, _) => { + self.validate_operand(operand)?; + } + + Rvalue::NullaryOp(op, _) => match op { + NullOp::SizeOf => {} + NullOp::AlignOf => {} + NullOp::OffsetOf(_) => {} + NullOp::UbChecks => {} + NullOp::ContractChecks => {} + }, + + Rvalue::ShallowInitBox(_, _) => return Err(Unpromotable), + + Rvalue::UnaryOp(op, operand) => { + match op { + // These operations can never fail. 
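                    // (For `Neg`, MIR building emits any overflow check as a separate
                    // `Assert` terminator beforehand, so the raw unary rvalues themselves
                    // cannot fail to evaluate.)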
+ UnOp::Neg | UnOp::Not | UnOp::PtrMetadata => {} + } + + self.validate_operand(operand)?; + } + + Rvalue::BinaryOp(op, box (lhs, rhs)) => { + let op = *op; + let lhs_ty = lhs.ty(self.body, self.tcx); + + if let ty::RawPtr(_, _) | ty::FnPtr(..) = lhs_ty.kind() { + // Raw and fn pointer operations are not allowed inside consts and thus not + // promotable. + assert_matches!( + op, + BinOp::Eq + | BinOp::Ne + | BinOp::Le + | BinOp::Lt + | BinOp::Ge + | BinOp::Gt + | BinOp::Offset + ); + return Err(Unpromotable); + } + + match op { + BinOp::Div | BinOp::Rem => { + if lhs_ty.is_integral() { + let sz = lhs_ty.primitive_size(self.tcx); + // Integer division: the RHS must be a non-zero const. + let rhs_val = match rhs { + Operand::Constant(c) => { + c.const_.try_eval_scalar_int(self.tcx, self.typing_env) + } + _ => None, + }; + match rhs_val.map(|x| x.to_uint(sz)) { + // for the zero test, int vs uint does not matter + Some(x) if x != 0 => {} // okay + _ => return Err(Unpromotable), // value not known or 0 -- not okay + } + // Furthermore, for signed division, we also have to exclude `int::MIN / + // -1`. + if lhs_ty.is_signed() { + match rhs_val.map(|x| x.to_int(sz)) { + Some(-1) | None => { + // The RHS is -1 or unknown, so we have to be careful. + // But is the LHS int::MIN? + let lhs_val = match lhs { + Operand::Constant(c) => c + .const_ + .try_eval_scalar_int(self.tcx, self.typing_env), + _ => None, + }; + let lhs_min = sz.signed_int_min(); + match lhs_val.map(|x| x.to_int(sz)) { + // okay + Some(x) if x != lhs_min => {} + + // value not known or int::MIN -- not okay + _ => return Err(Unpromotable), + } + } + _ => {} + } + } + } + } + // The remaining operations can never fail. + BinOp::Eq + | BinOp::Ne + | BinOp::Le + | BinOp::Lt + | BinOp::Ge + | BinOp::Gt + | BinOp::Cmp + | BinOp::Offset + | BinOp::Add + | BinOp::AddUnchecked + | BinOp::AddWithOverflow + | BinOp::Sub + | BinOp::SubUnchecked + | BinOp::SubWithOverflow + | BinOp::Mul + | BinOp::MulUnchecked + | BinOp::MulWithOverflow + | BinOp::BitXor + | BinOp::BitAnd + | BinOp::BitOr + | BinOp::Shl + | BinOp::ShlUnchecked + | BinOp::Shr + | BinOp::ShrUnchecked => {} + } + + self.validate_operand(lhs)?; + self.validate_operand(rhs)?; + } + + Rvalue::RawPtr(_, place) => { + // We accept `&raw *`, i.e., raw reborrows -- creating a raw pointer is + // no problem, only using it is. + if let Some((place_base, ProjectionElem::Deref)) = place.as_ref().last_projection() + { + let base_ty = place_base.ty(self.body, self.tcx).ty; + if let ty::Ref(..) = base_ty.kind() { + return self.validate_place(place_base); + } + } + return Err(Unpromotable); + } + + Rvalue::Ref(_, kind, place) => { + // Special-case reborrows to be more like a copy of the reference. + let mut place_simplified = place.as_ref(); + if let Some((place_base, ProjectionElem::Deref)) = + place_simplified.last_projection() + { + let base_ty = place_base.ty(self.body, self.tcx).ty; + if let ty::Ref(..) = base_ty.kind() { + place_simplified = place_base; + } + } + + self.validate_place(place_simplified)?; + + // Check that the reference is fine (using the original place!). + // (Needs to come after `validate_place` to avoid ICEs.) + self.validate_ref(*kind, place)?; + } + + Rvalue::Aggregate(_, operands) => { + for o in operands { + self.validate_operand(o)?; + } + } + } + + Ok(()) + } + + /// Computes the sets of blocks of this MIR that are definitely going to be executed + /// if the function returns successfully. That makes it safe to promote calls in them + /// that might fail. 
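    // A minimal sketch of the same walk on a toy CFG, kept in a comment because it is
    // illustrative only (the `Toy*` names are hypothetical, not part of this pass):
    // starting from the entry block, keep following the single successor that is reached
    // on successful execution; once a terminator with several possible continuations is
    // hit, stop, since later blocks may be dead code.
    //
    //     use std::collections::HashSet;
    //
    //     enum ToyTerminator {
    //         // Unconditional jump, or the *success* target of a call/drop/assert.
    //         AlwaysReaches(usize),
    //         // Switch, return, unreachable, ...: no single guaranteed successor.
    //         Other,
    //     }
    //
    //     fn toy_safe_blocks(terminators: &[ToyTerminator]) -> HashSet<usize> {
    //         let mut safe = HashSet::new();
    //         let mut current = 0; // entry block
    //         // `insert` returning `false` means we looped back onto the chain.
    //         while safe.insert(current) {
    //             match &terminators[current] {
    //                 ToyTerminator::AlwaysReaches(next) => current = *next,
    //                 ToyTerminator::Other => break,
    //             }
    //         }
    //         safe
    //     }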
+ fn promotion_safe_blocks(body: &mir::Body<'tcx>) -> FxHashSet<BasicBlock> { + let mut safe_blocks = FxHashSet::default(); + let mut safe_block = START_BLOCK; + loop { + safe_blocks.insert(safe_block); + // Let's see if we can find another safe block. + safe_block = match body.basic_blocks[safe_block].terminator().kind { + TerminatorKind::Goto { target } => target, + TerminatorKind::Call { target: Some(target), .. } + | TerminatorKind::Drop { target, .. } => { + // This calls a function or the destructor. `target` does not get executed if + // the callee loops or panics. But in both cases the const already fails to + // evaluate, so we are fine considering `target` a safe block for promotion. + target + } + TerminatorKind::Assert { target, .. } => { + // Similar to above, we only consider successful execution. + target + } + _ => { + // No next safe block. + break; + } + }; + } + safe_blocks + } + + /// Returns whether the block is "safe" for promotion, which means it cannot be dead code. + /// We use this to avoid promoting operations that can fail in dead code. + fn is_promotion_safe_block(&mut self, block: BasicBlock) -> bool { + let body = self.body; + let safe_blocks = + self.promotion_safe_blocks.get_or_insert_with(|| Self::promotion_safe_blocks(body)); + safe_blocks.contains(&block) + } + + fn validate_call( + &mut self, + callee: &Operand<'tcx>, + args: &[Spanned<Operand<'tcx>>], + block: BasicBlock, + ) -> Result<(), Unpromotable> { + // Validate the operands. If they fail, there's no question -- we cannot promote. + self.validate_operand(callee)?; + for arg in args { + self.validate_operand(&arg.node)?; + } + + // Functions marked `#[rustc_promotable]` are explicitly allowed to be promoted, so we can + // accept them at this point. + let fn_ty = callee.ty(self.body, self.tcx); + if let ty::FnDef(def_id, _) = *fn_ty.kind() { + if self.tcx.is_promotable_const_fn(def_id) { + return Ok(()); + } + } + + // Ideally, we'd stop here and reject the rest. + // But for backward compatibility, we have to accept some promotion in const/static + // initializers. Inline consts are explicitly excluded, they are more recent so we have no + // backwards compatibility reason to allow more promotion inside of them. + let promote_all_fn = matches!( + self.const_kind, + Some(hir::ConstContext::Static(_) | hir::ConstContext::Const { inline: false }) + ); + if !promote_all_fn { + return Err(Unpromotable); + } + // Make sure the callee is a `const fn`. + let is_const_fn = match *fn_ty.kind() { + ty::FnDef(def_id, _) => self.tcx.is_const_fn(def_id), + _ => false, + }; + if !is_const_fn { + return Err(Unpromotable); + } + // The problem is, this may promote calls to functions that panic. + // We don't want to introduce compilation errors if there's a panic in a call in dead code. + // So we ensure that this is not dead code. + if !self.is_promotion_safe_block(block) { + return Err(Unpromotable); + } + // This passed all checks, so let's accept. 
+ Ok(()) + } +} + +fn validate_candidates( + ccx: &ConstCx<'_, '_>, + temps: &mut IndexSlice<Local, TempState>, + mut candidates: Vec<Candidate>, +) -> Vec<Candidate> { + let mut validator = Validator { ccx, temps, promotion_safe_blocks: None }; + + candidates.retain(|&candidate| validator.validate_candidate(candidate).is_ok()); + candidates +} + +struct Promoter<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + source: &'a mut Body<'tcx>, + promoted: Body<'tcx>, + temps: &'a mut IndexVec<Local, TempState>, + extra_statements: &'a mut Vec<(Location, Statement<'tcx>)>, + + /// Used to assemble the required_consts list while building the promoted. + required_consts: Vec<ConstOperand<'tcx>>, + + /// If true, all nested temps are also kept in the + /// source MIR, not moved to the promoted MIR. + keep_original: bool, + + /// If true, add the new const (the promoted) to the required_consts of the parent MIR. + /// This is initially false and then set by the visitor when it encounters a `Call` terminator. + add_to_required: bool, +} + +impl<'a, 'tcx> Promoter<'a, 'tcx> { + fn new_block(&mut self) -> BasicBlock { + let span = self.promoted.span; + self.promoted.basic_blocks_mut().push(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + source_info: SourceInfo::outermost(span), + kind: TerminatorKind::Return, + }), + is_cleanup: false, + }) + } + + fn assign(&mut self, dest: Local, rvalue: Rvalue<'tcx>, span: Span) { + let last = self.promoted.basic_blocks.last_index().unwrap(); + let data = &mut self.promoted[last]; + data.statements.push(Statement { + source_info: SourceInfo::outermost(span), + kind: StatementKind::Assign(Box::new((Place::from(dest), rvalue))), + }); + } + + fn is_temp_kind(&self, local: Local) -> bool { + self.source.local_kind(local) == LocalKind::Temp + } + + /// Copies the initialization of this temp to the + /// promoted MIR, recursing through temps. + fn promote_temp(&mut self, temp: Local) -> Local { + let old_keep_original = self.keep_original; + let loc = match self.temps[temp] { + TempState::Defined { location, uses, .. } if uses > 0 => { + if uses > 1 { + self.keep_original = true; + } + location + } + state => { + span_bug!(self.promoted.span, "{:?} not promotable: {:?}", temp, state); + } + }; + if !self.keep_original { + self.temps[temp] = TempState::PromotedOut; + } + + let num_stmts = self.source[loc.block].statements.len(); + let new_temp = self.promoted.local_decls.push(LocalDecl::new( + self.source.local_decls[temp].ty, + self.source.local_decls[temp].source_info.span, + )); + + debug!("promote({:?} @ {:?}/{:?}, {:?})", temp, loc, num_stmts, self.keep_original); + + // First, take the Rvalue or Call out of the source MIR, + // or duplicate it, depending on keep_original. 
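        // (A `Location` whose `statement_index` equals the number of statements refers to
        // the block's terminator, so the `else` branch below is the `Call` case.)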
+ if loc.statement_index < num_stmts { + let (mut rvalue, source_info) = { + let statement = &mut self.source[loc.block].statements[loc.statement_index]; + let StatementKind::Assign(box (_, rhs)) = &mut statement.kind else { + span_bug!(statement.source_info.span, "{:?} is not an assignment", statement); + }; + + ( + if self.keep_original { + rhs.clone() + } else { + let unit = Rvalue::Use(Operand::Constant(Box::new(ConstOperand { + span: statement.source_info.span, + user_ty: None, + const_: Const::zero_sized(self.tcx.types.unit), + }))); + mem::replace(rhs, unit) + }, + statement.source_info, + ) + }; + + self.visit_rvalue(&mut rvalue, loc); + self.assign(new_temp, rvalue, source_info.span); + } else { + let terminator = if self.keep_original { + self.source[loc.block].terminator().clone() + } else { + let terminator = self.source[loc.block].terminator_mut(); + let target = match &terminator.kind { + TerminatorKind::Call { target: Some(target), .. } => *target, + kind => { + span_bug!(terminator.source_info.span, "{:?} not promotable", kind); + } + }; + Terminator { + source_info: terminator.source_info, + kind: mem::replace(&mut terminator.kind, TerminatorKind::Goto { target }), + } + }; + + match terminator.kind { + TerminatorKind::Call { + mut func, mut args, call_source: desugar, fn_span, .. + } => { + // This promoted involves a function call, so it may fail to evaluate. Let's + // make sure it is added to `required_consts` so that failure cannot get lost. + self.add_to_required = true; + + self.visit_operand(&mut func, loc); + for arg in &mut args { + self.visit_operand(&mut arg.node, loc); + } + + let last = self.promoted.basic_blocks.last_index().unwrap(); + let new_target = self.new_block(); + + *self.promoted[last].terminator_mut() = Terminator { + kind: TerminatorKind::Call { + func, + args, + unwind: UnwindAction::Continue, + destination: Place::from(new_temp), + target: Some(new_target), + call_source: desugar, + fn_span, + }, + source_info: SourceInfo::outermost(terminator.source_info.span), + ..terminator + }; + } + kind => { + span_bug!(terminator.source_info.span, "{:?} not promotable", kind); + } + }; + }; + + self.keep_original = old_keep_original; + new_temp + } + + fn promote_candidate( + mut self, + candidate: Candidate, + next_promoted_index: Promoted, + ) -> Body<'tcx> { + let def = self.source.source.def_id(); + let (mut rvalue, promoted_op) = { + let promoted = &mut self.promoted; + let tcx = self.tcx; + let mut promoted_operand = |ty, span| { + promoted.span = span; + promoted.local_decls[RETURN_PLACE] = LocalDecl::new(ty, span); + let args = tcx.erase_regions(GenericArgs::identity_for_item(tcx, def)); + let uneval = + mir::UnevaluatedConst { def, args, promoted: Some(next_promoted_index) }; + + ConstOperand { span, user_ty: None, const_: Const::Unevaluated(uneval, ty) } + }; + + let blocks = self.source.basic_blocks.as_mut(); + let local_decls = &mut self.source.local_decls; + let loc = candidate.location; + let statement = &mut blocks[loc.block].statements[loc.statement_index]; + let StatementKind::Assign(box (_, Rvalue::Ref(region, borrow_kind, place))) = + &mut statement.kind + else { + bug!() + }; + + // Use the underlying local for this (necessarily interior) borrow. 
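            // (The promoted body will return a reference to this local's type; the original
            // borrow is rewritten below to dereference a fresh temp that holds that promoted
            // reference.)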
+ debug_assert!(region.is_erased()); + let ty = local_decls[place.local].ty; + let span = statement.source_info.span; + + let ref_ty = + Ty::new_ref(tcx, tcx.lifetimes.re_erased, ty, borrow_kind.to_mutbl_lossy()); + + let mut projection = vec![PlaceElem::Deref]; + projection.extend(place.projection); + place.projection = tcx.mk_place_elems(&projection); + + // Create a temp to hold the promoted reference. + // This is because `*r` requires `r` to be a local, + // otherwise we would use the `promoted` directly. + let mut promoted_ref = LocalDecl::new(ref_ty, span); + promoted_ref.source_info = statement.source_info; + let promoted_ref = local_decls.push(promoted_ref); + assert_eq!(self.temps.push(TempState::Unpromotable), promoted_ref); + + let promoted_operand = promoted_operand(ref_ty, span); + let promoted_ref_statement = Statement { + source_info: statement.source_info, + kind: StatementKind::Assign(Box::new(( + Place::from(promoted_ref), + Rvalue::Use(Operand::Constant(Box::new(promoted_operand))), + ))), + }; + self.extra_statements.push((loc, promoted_ref_statement)); + + ( + Rvalue::Ref( + tcx.lifetimes.re_erased, + *borrow_kind, + Place { + local: mem::replace(&mut place.local, promoted_ref), + projection: List::empty(), + }, + ), + promoted_operand, + ) + }; + + assert_eq!(self.new_block(), START_BLOCK); + self.visit_rvalue( + &mut rvalue, + Location { block: START_BLOCK, statement_index: usize::MAX }, + ); + + let span = self.promoted.span; + self.assign(RETURN_PLACE, rvalue, span); + + // Now that we did promotion, we know whether we'll want to add this to `required_consts` of + // the surrounding MIR body. + if self.add_to_required { + self.source.required_consts.as_mut().unwrap().push(promoted_op); + } + + self.promoted.set_required_consts(self.required_consts); + + self.promoted + } +} + +/// Replaces all temporaries with their promoted counterparts. +impl<'a, 'tcx> MutVisitor<'tcx> for Promoter<'a, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { + if self.is_temp_kind(*local) { + *local = self.promote_temp(*local); + } + } + + fn visit_const_operand(&mut self, constant: &mut ConstOperand<'tcx>, _location: Location) { + if constant.const_.is_required_const() { + self.required_consts.push(*constant); + } + + // Skipping `super_constant` as the visitor is otherwise only looking for locals. + } +} + +fn promote_candidates<'tcx>( + body: &mut Body<'tcx>, + tcx: TyCtxt<'tcx>, + mut temps: IndexVec<Local, TempState>, + candidates: Vec<Candidate>, +) -> IndexVec<Promoted, Body<'tcx>> { + // Visit candidates in reverse, in case they're nested. + debug!(promote_candidates = ?candidates); + + // eagerly fail fast + if candidates.is_empty() { + return IndexVec::new(); + } + + let mut promotions = IndexVec::new(); + + let mut extra_statements = vec![]; + for candidate in candidates.into_iter().rev() { + let Location { block, statement_index } = candidate.location; + if let StatementKind::Assign(box (place, _)) = &body[block].statements[statement_index].kind + { + if let Some(local) = place.as_local() { + if temps[local] == TempState::PromotedOut { + // Already promoted. + continue; + } + } + } + + // Declare return place local so that `mir::Body::new` doesn't complain. 
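        // (It is created with type `!` here and overwritten with the real return type in
        // `promote_candidate`.)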
+ let initial_locals = iter::once(LocalDecl::new(tcx.types.never, body.span)).collect(); + + let mut scope = body.source_scopes[body.source_info(candidate.location).scope].clone(); + scope.parent_scope = None; + + let mut promoted = Body::new( + body.source, // `promoted` gets filled in below + IndexVec::new(), + IndexVec::from_elem_n(scope, 1), + initial_locals, + IndexVec::new(), + 0, + vec![], + body.span, + None, + body.tainted_by_errors, + ); + promoted.phase = MirPhase::Analysis(AnalysisPhase::Initial); + + let promoter = Promoter { + promoted, + tcx, + source: body, + temps: &mut temps, + extra_statements: &mut extra_statements, + keep_original: false, + add_to_required: false, + required_consts: Vec::new(), + }; + + let mut promoted = promoter.promote_candidate(candidate, promotions.next_index()); + promoted.source.promoted = Some(promotions.next_index()); + promotions.push(promoted); + } + + // Insert each of `extra_statements` before its indicated location, which + // has to be done in reverse location order, to not invalidate the rest. + extra_statements.sort_by_key(|&(loc, _)| cmp::Reverse(loc)); + for (loc, statement) in extra_statements { + body[loc.block].statements.insert(loc.statement_index, statement); + } + + // Eliminate assignments to, and drops of promoted temps. + let promoted = |index: Local| temps[index] == TempState::PromotedOut; + for block in body.basic_blocks_mut() { + block.statements.retain(|statement| match &statement.kind { + StatementKind::Assign(box (place, _)) => { + if let Some(index) = place.as_local() { + !promoted(index) + } else { + true + } + } + StatementKind::StorageLive(index) | StatementKind::StorageDead(index) => { + !promoted(*index) + } + _ => true, + }); + let terminator = block.terminator_mut(); + if let TerminatorKind::Drop { place, target, .. } = &terminator.kind { + if let Some(index) = place.as_local() { + if promoted(index) { + terminator.kind = TerminatorKind::Goto { target: *target }; + } + } + } + } + + promotions +} diff --git a/compiler/rustc_mir_transform/src/ref_prop.rs b/compiler/rustc_mir_transform/src/ref_prop.rs new file mode 100644 index 00000000000..368d5340ac3 --- /dev/null +++ b/compiler/rustc_mir_transform/src/ref_prop.rs @@ -0,0 +1,418 @@ +use std::borrow::Cow; + +use rustc_data_structures::fx::FxHashSet; +use rustc_index::IndexVec; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::bug; +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_mir_dataflow::Analysis; +use rustc_mir_dataflow::impls::{MaybeStorageDead, always_storage_live_locals}; +use tracing::{debug, instrument}; + +use crate::ssa::{SsaLocals, StorageLiveLocals}; + +/// Propagate references using SSA analysis. +/// +/// MIR building may produce a lot of borrow-dereference patterns. +/// +/// This pass aims to transform the following pattern: +/// _1 = &raw? mut? PLACE; +/// _3 = *_1; +/// _4 = &raw? mut? *_1; +/// +/// Into +/// _1 = &raw? mut? PLACE; +/// _3 = PLACE; +/// _4 = &raw? mut? PLACE; +/// +/// where `PLACE` is a direct or an indirect place expression. +/// +/// There are 3 properties that need to be upheld for this transformation to be legal: +/// - place stability: `PLACE` must refer to the same memory wherever it appears; +/// - pointer liveness: we must not introduce dereferences of dangling pointers; +/// - `&mut` borrow uniqueness. 
+/// +/// # Stability +/// +/// If `PLACE` is an indirect projection, if its of the form `(*LOCAL).PROJECTIONS` where: +/// - `LOCAL` is SSA; +/// - all projections in `PROJECTIONS` have a stable offset (no dereference and no indexing). +/// +/// If `PLACE` is a direct projection of a local, we consider it as constant if: +/// - the local is always live, or it has a single `StorageLive`; +/// - all projections have a stable offset. +/// +/// # Liveness +/// +/// When performing an instantiation, we must take care not to introduce uses of dangling locals. +/// To ensure this, we walk the body with the `MaybeStorageDead` dataflow analysis: +/// - if we want to replace `*x` by reborrow `*y` and `y` may be dead, we allow replacement and +/// mark storage statements on `y` for removal; +/// - if we want to replace `*x` by non-reborrow `y` and `y` must be live, we allow replacement; +/// - if we want to replace `*x` by non-reborrow `y` and `y` may be dead, we do not replace. +/// +/// # Uniqueness +/// +/// For `&mut` borrows, we also need to preserve the uniqueness property: +/// we must avoid creating a state where we interleave uses of `*_1` and `_2`. +/// To do it, we only perform full instantiation of mutable borrows: +/// we replace either all or none of the occurrences of `*_1`. +/// +/// Some care has to be taken when `_1` is copied in other locals. +/// _1 = &raw? mut? _2; +/// _3 = *_1; +/// _4 = _1 +/// _5 = *_4 +/// In such cases, fully instantiating `_1` means fully instantiating all of the copies. +/// +/// For immutable borrows, we do not need to preserve such uniqueness property, +/// so we perform all the possible instantiations without removing the `_1 = &_2` statement. +pub(super) struct ReferencePropagation; + +impl<'tcx> crate::MirPass<'tcx> for ReferencePropagation { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 2 + } + + #[instrument(level = "trace", skip(self, tcx, body))] + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!(def_id = ?body.source.def_id()); + while propagate_ssa(tcx, body) {} + } + + fn is_required(&self) -> bool { + false + } +} + +fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { + let typing_env = body.typing_env(tcx); + let ssa = SsaLocals::new(tcx, body, typing_env); + + let mut replacer = compute_replacement(tcx, body, &ssa); + debug!(?replacer.targets); + debug!(?replacer.allowed_replacements); + debug!(?replacer.storage_to_remove); + + replacer.visit_body_preserves_cfg(body); + + if replacer.any_replacement { + crate::simplify::remove_unused_definitions(body); + } + + replacer.any_replacement +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum Value<'tcx> { + /// Not a pointer, or we can't know. + Unknown, + /// We know the value to be a pointer to this place. + /// The boolean indicates whether the reference is mutable, subject the uniqueness rule. + Pointer(Place<'tcx>, bool), +} + +/// For each local, save the place corresponding to `*local`. +#[instrument(level = "trace", skip(tcx, body, ssa))] +fn compute_replacement<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + ssa: &SsaLocals, +) -> Replacer<'tcx> { + let always_live_locals = always_storage_live_locals(body); + + // Compute which locals have a single `StorageLive` statement ever. + let storage_live = StorageLiveLocals::new(body, &always_live_locals); + + // Compute `MaybeStorageDead` dataflow to check that we only replace when the pointee is + // definitely live. 
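    // (The resulting cursor is queried later, in `can_perform_opt`, at each location
    // where a replacement is considered.)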
+ let mut maybe_dead = MaybeStorageDead::new(Cow::Owned(always_live_locals)) + .iterate_to_fixpoint(tcx, body, None) + .into_results_cursor(body); + + // Map for each local to the pointee. + let mut targets = IndexVec::from_elem(Value::Unknown, &body.local_decls); + // Set of locals for which we will remove their storage statement. This is useful for + // reborrowed references. + let mut storage_to_remove = DenseBitSet::new_empty(body.local_decls.len()); + + let fully_replacable_locals = fully_replacable_locals(ssa); + + // Returns true iff we can use `place` as a pointee. + // + // Note that we only need to verify that there is a single `StorageLive` statement, and we do + // not need to verify that it dominates all uses of that local. + // + // Consider the three statements: + // SL : StorageLive(a) + // DEF: b = &raw? mut? a + // USE: stuff that uses *b + // + // First, we recall that DEF is checked to dominate USE. Now imagine for the sake of + // contradiction there is a DEF -> SL -> USE path. Consider two cases: + // + // - DEF dominates SL. We always have UB the first time control flow reaches DEF, + // because the storage of `a` is dead. Since DEF dominates USE, that means we cannot + // reach USE and so our optimization is ok. + // + // - DEF does not dominate SL. Then there is a `START_BLOCK -> SL` path not including DEF. + // But we can extend this path to USE, meaning there is also a `START_BLOCK -> USE` path not + // including DEF. This violates the DEF dominates USE condition, and so is impossible. + let is_constant_place = |place: Place<'_>| { + // We only allow `Deref` as the first projection, to avoid surprises. + if place.projection.first() == Some(&PlaceElem::Deref) { + // `place == (*some_local).xxx`, it is constant only if `some_local` is constant. + // We approximate constness using SSAness. + ssa.is_ssa(place.local) && place.projection[1..].iter().all(PlaceElem::is_stable_offset) + } else { + storage_live.has_single_storage(place.local) + && place.projection[..].iter().all(PlaceElem::is_stable_offset) + } + }; + + let mut can_perform_opt = |target: Place<'tcx>, loc: Location| { + if target.projection.first() == Some(&PlaceElem::Deref) { + // We are creating a reborrow. As `place.local` is a reference, removing the storage + // statements should not make it much harder for LLVM to optimize. + storage_to_remove.insert(target.local); + true + } else { + // This is a proper dereference. We can only allow it if `target` is live. + maybe_dead.seek_after_primary_effect(loc); + let maybe_dead = maybe_dead.get().contains(target.local); + !maybe_dead + } + }; + + for (local, rvalue, location) in ssa.assignments(body) { + debug!(?local); + + // Only visit if we have something to do. + let Value::Unknown = targets[local] else { bug!() }; + + let ty = body.local_decls[local].ty; + + // If this is not a reference or pointer, do nothing. + if !ty.is_any_ptr() { + debug!("not a reference or pointer"); + continue; + } + + // Whether the current local is subject to the uniqueness rule. + let needs_unique = ty.is_mutable_ptr(); + + // If this a mutable reference that we cannot fully replace, mark it as unknown. + if needs_unique && !fully_replacable_locals.contains(local) { + debug!("not fully replaceable"); + continue; + } + + debug!(?rvalue); + match rvalue { + // This is a copy, just use the value we have in store for the previous one. + // As we are visiting in `assignment_order`, ie. reverse postorder, `rhs` should + // have been visited before. 
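            // (E.g. with `_1 = &_2; _4 = _1;` from the module docs, visiting `_4` here can
            // reuse the pointee already recorded for `_1` when the pointer is immutable, or
            // else record `*_1` itself as the pointee.)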
+ Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) + | Rvalue::CopyForDeref(place) => { + if let Some(rhs) = place.as_local() + && ssa.is_ssa(rhs) + { + let target = targets[rhs]; + // Only see through immutable reference and pointers, as we do not know yet if + // mutable references are fully replaced. + if !needs_unique && matches!(target, Value::Pointer(..)) { + targets[local] = target; + } else { + targets[local] = + Value::Pointer(tcx.mk_place_deref(rhs.into()), needs_unique); + } + } + } + Rvalue::Ref(_, _, place) | Rvalue::RawPtr(_, place) => { + let mut place = *place; + // Try to see through `place` in order to collapse reborrow chains. + if place.projection.first() == Some(&PlaceElem::Deref) + && let Value::Pointer(target, inner_needs_unique) = targets[place.local] + // Only see through immutable reference and pointers, as we do not know yet if + // mutable references are fully replaced. + && !inner_needs_unique + // Only collapse chain if the pointee is definitely live. + && can_perform_opt(target, location) + { + place = target.project_deeper(&place.projection[1..], tcx); + } + assert_ne!(place.local, local); + if is_constant_place(place) { + targets[local] = Value::Pointer(place, needs_unique); + } + } + // We do not know what to do, so keep as not-a-pointer. + _ => {} + } + } + + debug!(?targets); + + let mut finder = + ReplacementFinder { targets, can_perform_opt, allowed_replacements: FxHashSet::default() }; + let reachable_blocks = traversal::reachable_as_bitset(body); + for (bb, bbdata) in body.basic_blocks.iter_enumerated() { + // Only visit reachable blocks as we rely on dataflow. + if reachable_blocks.contains(bb) { + finder.visit_basic_block_data(bb, bbdata); + } + } + + let allowed_replacements = finder.allowed_replacements; + return Replacer { + tcx, + targets: finder.targets, + storage_to_remove, + allowed_replacements, + any_replacement: false, + }; + + struct ReplacementFinder<'tcx, F> { + targets: IndexVec<Local, Value<'tcx>>, + can_perform_opt: F, + allowed_replacements: FxHashSet<(Local, Location)>, + } + + impl<'tcx, F> Visitor<'tcx> for ReplacementFinder<'tcx, F> + where + F: FnMut(Place<'tcx>, Location) -> bool, + { + fn visit_place(&mut self, place: &Place<'tcx>, ctxt: PlaceContext, loc: Location) { + if matches!(ctxt, PlaceContext::NonUse(_)) { + // There is no need to check liveness for non-uses. + return; + } + + if place.projection.first() != Some(&PlaceElem::Deref) { + // This is not a dereference, nothing to do. + return; + } + + let mut place = place.as_ref(); + loop { + if let Value::Pointer(target, needs_unique) = self.targets[place.local] { + let perform_opt = (self.can_perform_opt)(target, loc); + debug!(?place, ?target, ?needs_unique, ?perform_opt); + + // This a reborrow chain, recursively allow the replacement. + // + // This also allows to detect cases where `target.local` is not replacable, + // and mark it as such. + if let &[PlaceElem::Deref] = &target.projection[..] { + assert!(perform_opt); + self.allowed_replacements.insert((target.local, loc)); + place.local = target.local; + continue; + } else if perform_opt { + self.allowed_replacements.insert((target.local, loc)); + } else if needs_unique { + // This mutable reference is not fully replacable, so drop it. + self.targets[place.local] = Value::Unknown; + } + } + + break; + } + } + } +} + +/// Compute the set of locals that can be fully replaced. 
+/// +/// We consider a local to be replacable iff it's only used in a `Deref` projection `*_local` or +/// non-use position (like storage statements and debuginfo). +fn fully_replacable_locals(ssa: &SsaLocals) -> DenseBitSet<Local> { + let mut replacable = DenseBitSet::new_empty(ssa.num_locals()); + + // First pass: for each local, whether its uses can be fully replaced. + for local in ssa.locals() { + if ssa.num_direct_uses(local) == 0 { + replacable.insert(local); + } + } + + // Second pass: a local can only be fully replaced if all its copies can. + ssa.meet_copy_equivalence(&mut replacable); + + replacable +} + +/// Utility to help performing substitution of `*pattern` by `target`. +struct Replacer<'tcx> { + tcx: TyCtxt<'tcx>, + targets: IndexVec<Local, Value<'tcx>>, + storage_to_remove: DenseBitSet<Local>, + allowed_replacements: FxHashSet<(Local, Location)>, + any_replacement: bool, +} + +impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_var_debug_info(&mut self, debuginfo: &mut VarDebugInfo<'tcx>) { + // If the debuginfo is a pointer to another place: + // - if it's a reborrow, see through it; + // - if it's a direct borrow, increase `debuginfo.references`. + while let VarDebugInfoContents::Place(ref mut place) = debuginfo.value + && place.projection.is_empty() + && let Value::Pointer(target, _) = self.targets[place.local] + && target.projection.iter().all(|p| p.can_use_in_debuginfo()) + { + if let Some((&PlaceElem::Deref, rest)) = target.projection.split_last() { + *place = Place::from(target.local).project_deeper(rest, self.tcx); + self.any_replacement = true; + } else { + break; + } + } + + // Simplify eventual projections left inside `debuginfo`. + self.super_var_debug_info(debuginfo); + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, ctxt: PlaceContext, loc: Location) { + loop { + if place.projection.first() != Some(&PlaceElem::Deref) { + return; + } + + let Value::Pointer(target, _) = self.targets[place.local] else { return }; + + let perform_opt = match ctxt { + PlaceContext::NonUse(NonUseContext::VarDebugInfo) => { + target.projection.iter().all(|p| p.can_use_in_debuginfo()) + } + PlaceContext::NonUse(_) => true, + _ => self.allowed_replacements.contains(&(target.local, loc)), + }; + + if !perform_opt { + return; + } + + *place = target.project_deeper(&place.projection[1..], self.tcx); + self.any_replacement = true; + } + } + + fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, loc: Location) { + match stmt.kind { + StatementKind::StorageLive(l) | StatementKind::StorageDead(l) + if self.storage_to_remove.contains(l) => + { + stmt.make_nop(); + } + // Do not remove assignments as they may still be useful for debuginfo. + _ => self.super_statement(stmt, loc), + } + } +} diff --git a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs new file mode 100644 index 00000000000..797056ad52d --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs @@ -0,0 +1,145 @@ +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_target::spec::PanicStrategy; +use tracing::debug; + +use crate::patch::MirPatch; + +/// A pass that removes noop landing pads and replaces jumps to them with +/// `UnwindAction::Continue`. This is important because otherwise LLVM generates +/// terrible code for these. 
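// A hedged, source-level intuition (the function names below are hypothetical): if
// unwinding out of a call would only have to run storage bookkeeping -- no destructors --
// then any cleanup block generated for it is a noop landing pad, and this pass redirects
// the unwind edge to `UnwindAction::Continue` instead of jumping to it.
//
//     fn may_panic(x: u32) -> u32 {
//         x.checked_add(1).expect("overflow")
//     }
//
//     fn caller() -> u32 {
//         // No local here has a destructor, so a landing pad for the call below
//         // (if one is emitted at all) only contains storage markers and a resume.
//         let tmp = u32::MAX;
//         may_panic(tmp)
//     }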
+pub(super) struct RemoveNoopLandingPads; + +impl<'tcx> crate::MirPass<'tcx> for RemoveNoopLandingPads { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.panic_strategy() != PanicStrategy::Abort + } + + fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let def_id = body.source.def_id(); + debug!(?def_id); + + // Skip the pass if there are no blocks with a resume terminator. + let has_resume = body + .basic_blocks + .iter_enumerated() + .any(|(_bb, block)| matches!(block.terminator().kind, TerminatorKind::UnwindResume)); + if !has_resume { + debug!("remove_noop_landing_pads: no resume block in MIR"); + return; + } + + // make sure there's a resume block without any statements + let resume_block = { + let mut patch = MirPatch::new(body); + let resume_block = patch.resume_block(); + patch.apply(body); + resume_block + }; + debug!("remove_noop_landing_pads: resume block is {:?}", resume_block); + + let mut jumps_folded = 0; + let mut landing_pads_removed = 0; + let mut nop_landing_pads = DenseBitSet::new_empty(body.basic_blocks.len()); + + // This is a post-order traversal, so that if A post-dominates B + // then A will be visited before B. + let postorder: Vec<_> = traversal::postorder(body).map(|(bb, _)| bb).collect(); + for bb in postorder { + debug!(" processing {:?}", bb); + if let Some(unwind) = body[bb].terminator_mut().unwind_mut() { + if let UnwindAction::Cleanup(unwind_bb) = *unwind { + if nop_landing_pads.contains(unwind_bb) { + debug!(" removing noop landing pad"); + landing_pads_removed += 1; + *unwind = UnwindAction::Continue; + } + } + } + + body[bb].terminator_mut().successors_mut(|target| { + if *target != resume_block && nop_landing_pads.contains(*target) { + debug!(" folding noop jump to {:?} to resume block", target); + *target = resume_block; + jumps_folded += 1; + } + }); + + let is_nop_landing_pad = self.is_nop_landing_pad(bb, body, &nop_landing_pads); + if is_nop_landing_pad { + nop_landing_pads.insert(bb); + } + debug!(" is_nop_landing_pad({:?}) = {}", bb, is_nop_landing_pad); + } + + debug!("removed {:?} jumps and {:?} landing pads", jumps_folded, landing_pads_removed); + } + + fn is_required(&self) -> bool { + true + } +} + +impl RemoveNoopLandingPads { + fn is_nop_landing_pad( + &self, + bb: BasicBlock, + body: &Body<'_>, + nop_landing_pads: &DenseBitSet<BasicBlock>, + ) -> bool { + for stmt in &body[bb].statements { + match &stmt.kind { + StatementKind::FakeRead(..) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::PlaceMention(..) + | StatementKind::AscribeUserType(..) + | StatementKind::Coverage(..) + | StatementKind::ConstEvalCounter + | StatementKind::BackwardIncompatibleDropHint { .. } + | StatementKind::Nop => { + // These are all noops in a landing pad + } + + StatementKind::Assign(box (place, Rvalue::Use(_) | Rvalue::Discriminant(_))) => { + if place.as_local().is_some() { + // Writing to a local (e.g., a drop flag) does not + // turn a landing pad to a non-nop + } else { + return false; + } + } + + StatementKind::Assign { .. } + | StatementKind::SetDiscriminant { .. } + | StatementKind::Deinit(..) + | StatementKind::Intrinsic(..) + | StatementKind::Retag { .. } => { + return false; + } + } + } + + let terminator = body[bb].terminator(); + match terminator.kind { + TerminatorKind::Goto { .. } + | TerminatorKind::UnwindResume + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. 
} => { + terminator.successors().all(|succ| nop_landing_pads.contains(succ)) + } + TerminatorKind::CoroutineDrop + | TerminatorKind::Yield { .. } + | TerminatorKind::Return + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Unreachable + | TerminatorKind::Call { .. } + | TerminatorKind::TailCall { .. } + | TerminatorKind::Assert { .. } + | TerminatorKind::Drop { .. } + | TerminatorKind::InlineAsm { .. } => false, + } + } +} diff --git a/compiler/rustc_mir_transform/src/remove_place_mention.rs b/compiler/rustc_mir_transform/src/remove_place_mention.rs new file mode 100644 index 00000000000..cb598ceb4df --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_place_mention.rs @@ -0,0 +1,27 @@ +//! This pass removes `PlaceMention` statement, which has no effect at codegen. + +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use tracing::trace; + +pub(super) struct RemovePlaceMention; + +impl<'tcx> crate::MirPass<'tcx> for RemovePlaceMention { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + !sess.opts.unstable_opts.mir_preserve_ub + } + + fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("Running RemovePlaceMention on {:?}", body.source); + for data in body.basic_blocks.as_mut_preserves_cfg() { + data.statements.retain(|statement| match statement.kind { + StatementKind::PlaceMention(..) | StatementKind::Nop => false, + _ => true, + }) + } + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/remove_storage_markers.rs b/compiler/rustc_mir_transform/src/remove_storage_markers.rs new file mode 100644 index 00000000000..1ae33c00968 --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_storage_markers.rs @@ -0,0 +1,29 @@ +//! This pass removes storage markers if they won't be emitted during codegen. + +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use tracing::trace; + +pub(super) struct RemoveStorageMarkers; + +impl<'tcx> crate::MirPass<'tcx> for RemoveStorageMarkers { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 && !sess.emit_lifetime_markers() + } + + fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("Running RemoveStorageMarkers on {:?}", body.source); + for data in body.basic_blocks.as_mut_preserves_cfg() { + data.statements.retain(|statement| match statement.kind { + StatementKind::StorageLive(..) + | StatementKind::StorageDead(..) + | StatementKind::Nop => false, + _ => true, + }) + } + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/remove_uninit_drops.rs b/compiler/rustc_mir_transform/src/remove_uninit_drops.rs new file mode 100644 index 00000000000..9044a88295c --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_uninit_drops.rs @@ -0,0 +1,154 @@ +use rustc_abi::FieldIdx; +use rustc_index::bit_set::MixedBitSet; +use rustc_middle::mir::{Body, TerminatorKind}; +use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, VariantDef}; +use rustc_mir_dataflow::impls::MaybeInitializedPlaces; +use rustc_mir_dataflow::move_paths::{LookupResult, MoveData, MovePathIndex}; +use rustc_mir_dataflow::{Analysis, MaybeReachable, move_path_children_matching}; + +/// Removes `Drop` terminators whose target is known to be uninitialized at +/// that point. 
+/// +/// This is redundant with drop elaboration, but we need to do it prior to const-checking, and +/// running const-checking after drop elaboration makes it optimization dependent, causing issues +/// like [#90770]. +/// +/// [#90770]: https://github.com/rust-lang/rust/issues/90770 +pub(super) struct RemoveUninitDrops; + +impl<'tcx> crate::MirPass<'tcx> for RemoveUninitDrops { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let typing_env = body.typing_env(tcx); + let move_data = MoveData::gather_moves(body, tcx, |ty| ty.needs_drop(tcx, typing_env)); + + let mut maybe_inits = MaybeInitializedPlaces::new(tcx, body, &move_data) + .iterate_to_fixpoint(tcx, body, Some("remove_uninit_drops")) + .into_results_cursor(body); + + let mut to_remove = vec![]; + for (bb, block) in body.basic_blocks.iter_enumerated() { + let terminator = block.terminator(); + let TerminatorKind::Drop { place, .. } = &terminator.kind else { continue }; + + maybe_inits.seek_before_primary_effect(body.terminator_loc(bb)); + let MaybeReachable::Reachable(maybe_inits) = maybe_inits.get() else { continue }; + + // If there's no move path for the dropped place, it's probably a `Deref`. Let it alone. + let LookupResult::Exact(mpi) = move_data.rev_lookup.find(place.as_ref()) else { + continue; + }; + + let should_keep = is_needs_drop_and_init( + tcx, + typing_env, + maybe_inits, + &move_data, + place.ty(body, tcx).ty, + mpi, + ); + if !should_keep { + to_remove.push(bb) + } + } + + for bb in to_remove { + let block = &mut body.basic_blocks_mut()[bb]; + + let TerminatorKind::Drop { target, .. } = &block.terminator().kind else { + unreachable!() + }; + + // Replace block terminator with `Goto`. + block.terminator_mut().kind = TerminatorKind::Goto { target: *target }; + } + } + + fn is_required(&self) -> bool { + true + } +} + +fn is_needs_drop_and_init<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + maybe_inits: &MixedBitSet<MovePathIndex>, + move_data: &MoveData<'tcx>, + ty: Ty<'tcx>, + mpi: MovePathIndex, +) -> bool { + // No need to look deeper if the root is definitely uninit or if it has no `Drop` impl. + if !maybe_inits.contains(mpi) || !ty.needs_drop(tcx, typing_env) { + return false; + } + + let field_needs_drop_and_init = |(f, f_ty, mpi)| { + let child = move_path_children_matching(move_data, mpi, |x| x.is_field_to(f)); + let Some(mpi) = child else { + return Ty::needs_drop(f_ty, tcx, typing_env); + }; + + is_needs_drop_and_init(tcx, typing_env, maybe_inits, move_data, f_ty, mpi) + }; + + // This pass is only needed for const-checking, so it doesn't handle as many cases as + // `DropCtxt::open_drop`, since they aren't relevant in a const-context. + match ty.kind() { + ty::Adt(adt, args) => { + let dont_elaborate = adt.is_union() || adt.is_manually_drop() || adt.has_dtor(tcx); + if dont_elaborate { + return true; + } + + // Look at all our fields, or if we are an enum all our variants and their fields. + // + // If a field's projection *is not* present in `MoveData`, it has the same + // initializedness as its parent (maybe init). + // + // If its projection *is* present in `MoveData`, then the field may have been moved + // from separate from its parent. Recurse. + adt.variants().iter_enumerated().any(|(vid, variant)| { + // Enums have multiple variants, which are discriminated with a `Downcast` + // projection. Structs have a single variant, and don't use a `Downcast` + // projection. 
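+                // (Illustration, using a hypothetical example: for an `Option<String>`
+                // whose `Some` payload was moved out, that payload shows up under a
+                // `Downcast(Some)` child of the parent move path, whereas a struct field
+                // like `s.0` appears as a plain `Field(0)` child with no downcast.)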
+ let mpi = if adt.is_enum() { + let downcast = + move_path_children_matching(move_data, mpi, |x| x.is_downcast_to(vid)); + let Some(dc_mpi) = downcast else { + return variant_needs_drop(tcx, typing_env, args, variant); + }; + + dc_mpi + } else { + mpi + }; + + variant + .fields + .iter() + .enumerate() + .map(|(f, field)| (FieldIdx::from_usize(f), field.ty(tcx, args), mpi)) + .any(field_needs_drop_and_init) + }) + } + + ty::Tuple(fields) => fields + .iter() + .enumerate() + .map(|(f, f_ty)| (FieldIdx::from_usize(f), f_ty, mpi)) + .any(field_needs_drop_and_init), + + _ => true, + } +} + +fn variant_needs_drop<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + args: GenericArgsRef<'tcx>, + variant: &VariantDef, +) -> bool { + variant.fields.iter().any(|field| { + let f_ty = field.ty(tcx, args); + f_ty.needs_drop(tcx, typing_env) + }) +} diff --git a/compiler/rustc_mir_transform/src/remove_unneeded_drops.rs b/compiler/rustc_mir_transform/src/remove_unneeded_drops.rs new file mode 100644 index 00000000000..43f80508e4a --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_unneeded_drops.rs @@ -0,0 +1,45 @@ +//! This pass replaces a drop of a type that does not need dropping, with a goto. +//! +//! When the MIR is built, we check `needs_drop` before emitting a `Drop` for a place. This pass is +//! useful because (unlike MIR building) it runs after type checking, so it can make use of +//! `TypingMode::PostAnalysis` to provide more precise type information, especially about opaque +//! types. + +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use tracing::{debug, trace}; + +use super::simplify::simplify_cfg; + +pub(super) struct RemoveUnneededDrops; + +impl<'tcx> crate::MirPass<'tcx> for RemoveUnneededDrops { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("Running RemoveUnneededDrops on {:?}", body.source); + + let typing_env = body.typing_env(tcx); + let mut should_simplify = false; + for block in body.basic_blocks.as_mut() { + let terminator = block.terminator_mut(); + if let TerminatorKind::Drop { place, target, .. } = terminator.kind { + let ty = place.ty(&body.local_decls, tcx); + if ty.ty.needs_drop(tcx, typing_env) { + continue; + } + debug!("SUCCESS: replacing `drop` with goto({:?})", target); + terminator.kind = TerminatorKind::Goto { target }; + should_simplify = true; + } + } + + // if we applied optimizations, we potentially have some cfg to cleanup to + // make it easier for further passes + if should_simplify { + simplify_cfg(tcx, body); + } + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/remove_zsts.rs b/compiler/rustc_mir_transform/src/remove_zsts.rs new file mode 100644 index 00000000000..c4dc8638b26 --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_zsts.rs @@ -0,0 +1,149 @@ +//! Removes operations on ZST places, and convert ZST operands to constants. + +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; + +pub(super) struct RemoveZsts; + +impl<'tcx> crate::MirPass<'tcx> for RemoveZsts { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // Avoid query cycles (coroutines require optimized MIR for layout). 
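+        // (Roughly: computing `layout_of` for a coroutine consults its optimized MIR,
+        // and this pass runs while that MIR is being produced, so a layout query here
+        // could cycle back onto this very body.)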
+ if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() { + return; + } + + let typing_env = body.typing_env(tcx); + let local_decls = &body.local_decls; + let mut replacer = Replacer { tcx, typing_env, local_decls }; + for var_debug_info in &mut body.var_debug_info { + replacer.visit_var_debug_info(var_debug_info); + } + for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { + replacer.visit_basic_block_data(bb, data); + } + } + + fn is_required(&self) -> bool { + true + } +} + +struct Replacer<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + local_decls: &'a LocalDecls<'tcx>, +} + +/// A cheap, approximate check to avoid unnecessary `layout_of` calls. +/// +/// `Some(true)` is definitely ZST; `Some(false)` is definitely *not* ZST. +/// +/// `None` may or may not be, and must check `layout_of` to be sure. +fn trivially_zst<'tcx>(ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> Option<bool> { + match ty.kind() { + // definitely ZST + ty::FnDef(..) | ty::Never => Some(true), + ty::Tuple(fields) if fields.is_empty() => Some(true), + ty::Array(_ty, len) if let Some(0) = len.try_to_target_usize(tcx) => Some(true), + // clearly not ZST + ty::Bool + | ty::Char + | ty::Int(..) + | ty::Uint(..) + | ty::Float(..) + | ty::RawPtr(..) + | ty::Ref(..) + | ty::FnPtr(..) => Some(false), + ty::Coroutine(def_id, _) => { + // For async_drop_in_place::{closure} this is load bearing, not just a perf fix, + // because we don't want to compute the layout before mir analysis is done + if tcx.is_async_drop_in_place_coroutine(*def_id) { Some(false) } else { None } + } + // check `layout_of` to see (including unreachable things we won't actually see) + _ => None, + } +} + +impl<'tcx> Replacer<'_, 'tcx> { + fn known_to_be_zst(&self, ty: Ty<'tcx>) -> bool { + if let Some(is_zst) = trivially_zst(ty, self.tcx) { + is_zst + } else { + self.tcx + .layout_of(self.typing_env.as_query_input(ty)) + .is_ok_and(|layout| layout.is_zst()) + } + } + + fn make_zst(&self, ty: Ty<'tcx>) -> ConstOperand<'tcx> { + debug_assert!(self.known_to_be_zst(ty)); + ConstOperand { + span: rustc_span::DUMMY_SP, + user_ty: None, + const_: Const::Val(ConstValue::ZeroSized, ty), + } + } +} + +impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) { + match var_debug_info.value { + VarDebugInfoContents::Const(_) => {} + VarDebugInfoContents::Place(place) => { + let place_ty = place.ty(self.local_decls, self.tcx).ty; + if self.known_to_be_zst(place_ty) { + var_debug_info.value = VarDebugInfoContents::Const(self.make_zst(place_ty)) + } + } + } + } + + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, _: Location) { + if let Operand::Constant(_) = operand { + return; + } + let op_ty = operand.ty(self.local_decls, self.tcx); + if self.known_to_be_zst(op_ty) { + *operand = Operand::Constant(Box::new(self.make_zst(op_ty))) + } + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, loc: Location) { + let place_for_ty = match statement.kind { + StatementKind::Assign(box (place, ref rvalue)) => { + rvalue.is_safe_to_remove().then_some(place) + } + StatementKind::Deinit(box place) + | StatementKind::SetDiscriminant { box place, variant_index: _ } + | StatementKind::AscribeUserType(box (place, _), _) + | StatementKind::Retag(_, box place) + | StatementKind::PlaceMention(box place) + | StatementKind::FakeRead(box (_, place)) => Some(place), + 
StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + Some(local.into()) + } + StatementKind::Coverage(_) + | StatementKind::Intrinsic(_) + | StatementKind::Nop + | StatementKind::BackwardIncompatibleDropHint { .. } + | StatementKind::ConstEvalCounter => None, + }; + if let Some(place_for_ty) = place_for_ty + && let ty = place_for_ty.ty(self.local_decls, self.tcx).ty + && self.known_to_be_zst(ty) + { + statement.make_nop(); + } else { + self.super_statement(statement, loc); + } + } +} diff --git a/compiler/rustc_mir_transform/src/required_consts.rs b/compiler/rustc_mir_transform/src/required_consts.rs new file mode 100644 index 00000000000..b418ede42f0 --- /dev/null +++ b/compiler/rustc_mir_transform/src/required_consts.rs @@ -0,0 +1,24 @@ +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::{Body, ConstOperand, Location, traversal}; + +pub(super) struct RequiredConstsVisitor<'tcx> { + required_consts: Vec<ConstOperand<'tcx>>, +} + +impl<'tcx> RequiredConstsVisitor<'tcx> { + pub(super) fn compute_required_consts(body: &mut Body<'tcx>) { + let mut visitor = RequiredConstsVisitor { required_consts: Vec::new() }; + for (bb, bb_data) in traversal::reverse_postorder(&body) { + visitor.visit_basic_block_data(bb, bb_data); + } + body.set_required_consts(visitor.required_consts); + } +} + +impl<'tcx> Visitor<'tcx> for RequiredConstsVisitor<'tcx> { + fn visit_const_operand(&mut self, constant: &ConstOperand<'tcx>, _: Location) { + if constant.const_.is_required_const() { + self.required_consts.push(*constant); + } + } +} diff --git a/compiler/rustc_mir_transform/src/sanity_check.rs b/compiler/rustc_mir_transform/src/sanity_check.rs new file mode 100644 index 00000000000..c9445d18162 --- /dev/null +++ b/compiler/rustc_mir_transform/src/sanity_check.rs @@ -0,0 +1,11 @@ +use rustc_middle::mir::Body; +use rustc_middle::ty::TyCtxt; +use rustc_mir_dataflow::rustc_peek::sanity_check; + +pub(super) struct SanityCheck; + +impl<'tcx> crate::MirLint<'tcx> for SanityCheck { + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + sanity_check(tcx, body); + } +} diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs new file mode 100644 index 00000000000..6d45bbc6e16 --- /dev/null +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -0,0 +1,1261 @@ +use std::assert_matches::assert_matches; +use std::{fmt, iter}; + +use rustc_abi::{ExternAbi, FIRST_VARIANT, FieldIdx, VariantIdx}; +use rustc_hir as hir; +use rustc_hir::def_id::DefId; +use rustc_hir::lang_items::LangItem; +use rustc_index::{Idx, IndexVec}; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext}; +use rustc_middle::mir::*; +use rustc_middle::query::Providers; +use rustc_middle::ty::{ + self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, GenericArgs, Ty, TyCtxt, +}; +use rustc_middle::{bug, span_bug}; +use rustc_span::source_map::{Spanned, dummy_spanned}; +use rustc_span::{DUMMY_SP, Span}; +use tracing::{debug, instrument}; + +use crate::elaborate_drop::{DropElaborator, DropFlagMode, DropStyle, Unwind, elaborate_drop}; +use crate::patch::MirPatch; +use crate::{ + abort_unwinding_calls, add_call_guards, add_moves_for_packed_drops, deref_separator, inline, + instsimplify, mentioned_items, pass_manager as pm, remove_noop_landing_pads, + run_optimization_passes, simplify, +}; + +mod async_destructor_ctor; + +pub(super) fn provide(providers: &mut Providers) { + providers.mir_shims = make_shim; +} + +// Replace Pin<&mut ImplCoroutine> accesses (_1.0) into Pin<&mut 
ProxyCoroutine> acceses +struct FixProxyFutureDropVisitor<'tcx> { + tcx: TyCtxt<'tcx>, + replace_to: Local, +} + +impl<'tcx> MutVisitor<'tcx> for FixProxyFutureDropVisitor<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place( + &mut self, + place: &mut Place<'tcx>, + _context: PlaceContext, + _location: Location, + ) { + if place.local == Local::from_u32(1) { + if place.projection.len() == 1 { + assert!(matches!( + place.projection.first(), + Some(ProjectionElem::Field(FieldIdx::ZERO, _)) + )); + *place = Place::from(self.replace_to); + } else if place.projection.len() == 2 { + assert!(matches!(place.projection[0], ProjectionElem::Field(FieldIdx::ZERO, _))); + assert!(matches!(place.projection[1], ProjectionElem::Deref)); + *place = + Place::from(self.replace_to).project_deeper(&[ProjectionElem::Deref], self.tcx); + } + } + } +} + +fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceKind<'tcx>) -> Body<'tcx> { + debug!("make_shim({:?})", instance); + + let mut result = match instance { + ty::InstanceKind::Item(..) => bug!("item {:?} passed to make_shim", instance), + ty::InstanceKind::VTableShim(def_id) => { + let adjustment = Adjustment::Deref { source: DerefSource::MutPtr }; + build_call_shim(tcx, instance, Some(adjustment), CallKind::Direct(def_id)) + } + ty::InstanceKind::FnPtrShim(def_id, ty) => { + let trait_ = tcx.trait_of_item(def_id).unwrap(); + // Supports `Fn` or `async Fn` traits. + let adjustment = match tcx + .fn_trait_kind_from_def_id(trait_) + .or_else(|| tcx.async_fn_trait_kind_from_def_id(trait_)) + { + Some(ty::ClosureKind::FnOnce) => Adjustment::Identity, + Some(ty::ClosureKind::Fn) => Adjustment::Deref { source: DerefSource::ImmRef }, + Some(ty::ClosureKind::FnMut) => Adjustment::Deref { source: DerefSource::MutRef }, + None => bug!("fn pointer {:?} is not an fn", ty), + }; + + build_call_shim(tcx, instance, Some(adjustment), CallKind::Indirect(ty)) + } + // We are generating a call back to our def-id, which the + // codegen backend knows to turn to an actual call, be it + // a virtual call, or a direct call to a function for which + // indirect calls must be codegen'd differently than direct ones + // (such as `#[track_caller]`). + ty::InstanceKind::ReifyShim(def_id, _) => { + build_call_shim(tcx, instance, None, CallKind::Direct(def_id)) + } + ty::InstanceKind::ClosureOnceShim { call_once: _, track_caller: _ } => { + let fn_mut = tcx.require_lang_item(LangItem::FnMut, DUMMY_SP); + let call_mut = tcx + .associated_items(fn_mut) + .in_definition_order() + .find(|it| it.is_fn()) + .unwrap() + .def_id; + + build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut)) + } + + ty::InstanceKind::ConstructCoroutineInClosureShim { + coroutine_closure_def_id, + receiver_by_ref, + } => build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id, receiver_by_ref), + + ty::InstanceKind::DropGlue(def_id, ty) => { + // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end + // of this function. Is this intentional? + if let Some(&ty::Coroutine(coroutine_def_id, args)) = ty.map(Ty::kind) { + let coroutine_body = tcx.optimized_mir(coroutine_def_id); + + let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind() + else { + bug!() + }; + + // If this is a regular coroutine, grab its drop shim. If this is a coroutine + // that comes from a coroutine-closure, and the kind ty differs from the "maximum" + // kind that it supports, then grab the appropriate drop shim. 
This ensures that + // the future returned by `<[coroutine-closure] as AsyncFnOnce>::call_once` will + // drop the coroutine-closure's upvars. + let body = if id_args.as_coroutine().kind_ty() == args.as_coroutine().kind_ty() { + coroutine_body.coroutine_drop().unwrap() + } else { + assert_eq!( + args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(), + ty::ClosureKind::FnOnce + ); + tcx.optimized_mir(tcx.coroutine_by_move_body_def_id(coroutine_def_id)) + .coroutine_drop() + .unwrap() + }; + + let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args); + debug!("make_shim({:?}) = {:?}", instance, body); + + pm::run_passes( + tcx, + &mut body, + &[ + &mentioned_items::MentionedItems, + &abort_unwinding_calls::AbortUnwindingCalls, + &add_call_guards::CriticalCallEdges, + ], + Some(MirPhase::Runtime(RuntimePhase::Optimized)), + pm::Optimizations::Allowed, + ); + + return body; + } + + build_drop_shim(tcx, def_id, ty) + } + ty::InstanceKind::ThreadLocalShim(..) => build_thread_local_shim(tcx, instance), + ty::InstanceKind::CloneShim(def_id, ty) => build_clone_shim(tcx, def_id, ty), + ty::InstanceKind::FnPtrAddrShim(def_id, ty) => build_fn_ptr_addr_shim(tcx, def_id, ty), + ty::InstanceKind::FutureDropPollShim(def_id, proxy_ty, impl_ty) => { + let mut body = + async_destructor_ctor::build_future_drop_poll_shim(tcx, def_id, proxy_ty, impl_ty); + + pm::run_passes( + tcx, + &mut body, + &[ + &mentioned_items::MentionedItems, + &abort_unwinding_calls::AbortUnwindingCalls, + &add_call_guards::CriticalCallEdges, + ], + Some(MirPhase::Runtime(RuntimePhase::PostCleanup)), + pm::Optimizations::Allowed, + ); + run_optimization_passes(tcx, &mut body); + debug!("make_shim({:?}) = {:?}", instance, body); + return body; + } + ty::InstanceKind::AsyncDropGlue(def_id, ty) => { + let mut body = async_destructor_ctor::build_async_drop_shim(tcx, def_id, ty); + + // Main pass required here is StateTransform to convert sync drop ladder + // into coroutine. + // Others are minimal passes as for sync drop glue shim + pm::run_passes( + tcx, + &mut body, + &[ + &mentioned_items::MentionedItems, + &abort_unwinding_calls::AbortUnwindingCalls, + &add_call_guards::CriticalCallEdges, + &simplify::SimplifyCfg::MakeShim, + &crate::coroutine::StateTransform, + ], + Some(MirPhase::Runtime(RuntimePhase::PostCleanup)), + pm::Optimizations::Allowed, + ); + run_optimization_passes(tcx, &mut body); + debug!("make_shim({:?}) = {:?}", instance, body); + return body; + } + + ty::InstanceKind::AsyncDropGlueCtorShim(def_id, ty) => { + let body = async_destructor_ctor::build_async_destructor_ctor_shim(tcx, def_id, ty); + debug!("make_shim({:?}) = {:?}", instance, body); + return body; + } + ty::InstanceKind::Virtual(..) => { + bug!("InstanceKind::Virtual ({:?}) is for direct calls only", instance) + } + ty::InstanceKind::Intrinsic(_) => { + bug!("creating shims from intrinsics ({:?}) is unsupported", instance) + } + }; + debug!("make_shim({:?}) = untransformed {:?}", instance, result); + + // We don't validate MIR here because the shims may generate code that's + // only valid in a `PostAnalysis` param-env. However, since we do initial + // validation with the MirBuilt phase, which uses a user-facing param-env. + // This causes validation errors when TAITs are involved. 
+ pm::run_passes_no_validate( + tcx, + &mut result, + &[ + &mentioned_items::MentionedItems, + &add_moves_for_packed_drops::AddMovesForPackedDrops, + &deref_separator::Derefer, + &remove_noop_landing_pads::RemoveNoopLandingPads, + &simplify::SimplifyCfg::MakeShim, + &instsimplify::InstSimplify::BeforeInline, + // Perform inlining of `#[rustc_force_inline]`-annotated callees. + &inline::ForceInline, + &abort_unwinding_calls::AbortUnwindingCalls, + &add_call_guards::CriticalCallEdges, + ], + Some(MirPhase::Runtime(RuntimePhase::Optimized)), + ); + + debug!("make_shim({:?}) = {:?}", instance, result); + + result +} + +#[derive(Copy, Clone, Debug, PartialEq)] +enum DerefSource { + /// `fn shim(&self) { inner(*self )}`. + ImmRef, + /// `fn shim(&mut self) { inner(*self )}`. + MutRef, + /// `fn shim(*mut self) { inner(*self )}`. + MutPtr, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +enum Adjustment { + /// Pass the receiver as-is. + Identity, + + /// We get passed a reference or a raw pointer to `self` and call the target with `*self`. + /// + /// This either copies `self` (if `Self: Copy`, eg. for function items), or moves out of it + /// (for `VTableShim`, which effectively is passed `&own Self`). + Deref { source: DerefSource }, + + /// We get passed `self: Self` and call the target with `&mut self`. + /// + /// In this case we need to ensure that the `Self` is dropped after the call, as the callee + /// won't do it for us. + RefMut, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +enum CallKind<'tcx> { + /// Call the `FnPtr` that was passed as the receiver. + Indirect(Ty<'tcx>), + + /// Call a known `FnDef`. + Direct(DefId), +} + +fn local_decls_for_sig<'tcx>( + sig: &ty::FnSig<'tcx>, + span: Span, +) -> IndexVec<Local, LocalDecl<'tcx>> { + iter::once(LocalDecl::new(sig.output(), span)) + .chain(sig.inputs().iter().map(|ity| LocalDecl::new(*ity, span).immutable())) + .collect() +} + +fn dropee_emit_retag<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + mut dropee_ptr: Place<'tcx>, + span: Span, +) -> Place<'tcx> { + if tcx.sess.opts.unstable_opts.mir_emit_retag { + let source_info = SourceInfo::outermost(span); + // We want to treat the function argument as if it was passed by `&mut`. As such, we + // generate + // ``` + // temp = &mut *arg; + // Retag(temp, FnEntry) + // ``` + // It's important that we do this first, before anything that depends on `dropee_ptr` + // has been put into the body. 
+ let reborrow = Rvalue::Ref( + tcx.lifetimes.re_erased, + BorrowKind::Mut { kind: MutBorrowKind::Default }, + tcx.mk_place_deref(dropee_ptr), + ); + let ref_ty = reborrow.ty(body.local_decls(), tcx); + dropee_ptr = body.local_decls.push(LocalDecl::new(ref_ty, span)).into(); + let new_statements = [ + StatementKind::Assign(Box::new((dropee_ptr, reborrow))), + StatementKind::Retag(RetagKind::FnEntry, Box::new(dropee_ptr)), + ]; + for s in new_statements { + body.basic_blocks_mut()[START_BLOCK] + .statements + .push(Statement { source_info, kind: s }); + } + } + dropee_ptr +} + +fn build_drop_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, ty: Option<Ty<'tcx>>) -> Body<'tcx> { + debug!("build_drop_shim(def_id={:?}, ty={:?})", def_id, ty); + + assert!(!matches!(ty, Some(ty) if ty.is_coroutine())); + + let args = if let Some(ty) = ty { + tcx.mk_args(&[ty.into()]) + } else { + GenericArgs::identity_for_item(tcx, def_id) + }; + let sig = tcx.fn_sig(def_id).instantiate(tcx, args); + let sig = tcx.instantiate_bound_regions_with_erased(sig); + let span = tcx.def_span(def_id); + + let source_info = SourceInfo::outermost(span); + + let return_block = BasicBlock::new(1); + let mut blocks = IndexVec::with_capacity(2); + let block = |blocks: &mut IndexVec<_, _>, kind| { + blocks.push(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { source_info, kind }), + is_cleanup: false, + }) + }; + block(&mut blocks, TerminatorKind::Goto { target: return_block }); + block(&mut blocks, TerminatorKind::Return); + + let source = MirSource::from_instance(ty::InstanceKind::DropGlue(def_id, ty)); + let mut body = + new_body(source, blocks, local_decls_for_sig(&sig, span), sig.inputs().len(), span); + + // The first argument (index 0), but add 1 for the return value. + let dropee_ptr = Place::from(Local::new(1 + 0)); + let dropee_ptr = dropee_emit_retag(tcx, &mut body, dropee_ptr, span); + + if ty.is_some() { + let patch = { + let typing_env = ty::TypingEnv::post_analysis(tcx, def_id); + let mut elaborator = DropShimElaborator { + body: &body, + patch: MirPatch::new(&body), + tcx, + typing_env, + produce_async_drops: false, + }; + let dropee = tcx.mk_place_deref(dropee_ptr); + let resume_block = elaborator.patch.resume_block(); + elaborate_drop( + &mut elaborator, + source_info, + dropee, + (), + return_block, + Unwind::To(resume_block), + START_BLOCK, + None, + ); + elaborator.patch + }; + patch.apply(&mut body); + } + + body +} + +fn new_body<'tcx>( + source: MirSource<'tcx>, + basic_blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>>, + local_decls: IndexVec<Local, LocalDecl<'tcx>>, + arg_count: usize, + span: Span, +) -> Body<'tcx> { + let mut body = Body::new( + source, + basic_blocks, + IndexVec::from_elem_n( + SourceScopeData { + span, + parent_scope: None, + inlined: None, + inlined_parent_scope: None, + local_data: ClearCrossCrate::Clear, + }, + 1, + ), + local_decls, + IndexVec::new(), + arg_count, + vec![], + span, + None, + // FIXME(compiler-errors): is this correct? + None, + ); + // Shims do not directly mention any consts. 
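+    // (So we can record an empty list up front rather than scanning the body with
+    // `RequiredConstsVisitor`.)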
+ body.set_required_consts(Vec::new()); + body +} + +pub(super) struct DropShimElaborator<'a, 'tcx> { + pub body: &'a Body<'tcx>, + pub patch: MirPatch<'tcx>, + pub tcx: TyCtxt<'tcx>, + pub typing_env: ty::TypingEnv<'tcx>, + pub produce_async_drops: bool, +} + +impl fmt::Debug for DropShimElaborator<'_, '_> { + fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + Ok(()) + } +} + +impl<'a, 'tcx> DropElaborator<'a, 'tcx> for DropShimElaborator<'a, 'tcx> { + type Path = (); + + fn patch_ref(&self) -> &MirPatch<'tcx> { + &self.patch + } + fn patch(&mut self) -> &mut MirPatch<'tcx> { + &mut self.patch + } + fn body(&self) -> &'a Body<'tcx> { + self.body + } + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + fn typing_env(&self) -> ty::TypingEnv<'tcx> { + self.typing_env + } + + fn terminator_loc(&self, bb: BasicBlock) -> Location { + self.patch.terminator_loc(self.body, bb) + } + fn allow_async_drops(&self) -> bool { + self.produce_async_drops + } + + fn drop_style(&self, _path: Self::Path, mode: DropFlagMode) -> DropStyle { + match mode { + DropFlagMode::Shallow => { + // Drops for the contained fields are "shallow" and "static" - they will simply call + // the field's own drop glue. + DropStyle::Static + } + DropFlagMode::Deep => { + // The top-level drop is "deep" and "open" - it will be elaborated to a drop ladder + // dropping each field contained in the value. + DropStyle::Open + } + } + } + + fn get_drop_flag(&mut self, _path: Self::Path) -> Option<Operand<'tcx>> { + None + } + + fn clear_drop_flag(&mut self, _location: Location, _path: Self::Path, _mode: DropFlagMode) {} + + fn field_subpath(&self, _path: Self::Path, _field: FieldIdx) -> Option<Self::Path> { + None + } + fn deref_subpath(&self, _path: Self::Path) -> Option<Self::Path> { + None + } + fn downcast_subpath(&self, _path: Self::Path, _variant: VariantIdx) -> Option<Self::Path> { + Some(()) + } + fn array_subpath(&self, _path: Self::Path, _index: u64, _size: u64) -> Option<Self::Path> { + None + } +} + +fn build_thread_local_shim<'tcx>( + tcx: TyCtxt<'tcx>, + instance: ty::InstanceKind<'tcx>, +) -> Body<'tcx> { + let def_id = instance.def_id(); + + let span = tcx.def_span(def_id); + let source_info = SourceInfo::outermost(span); + + let blocks = IndexVec::from_raw(vec![BasicBlockData { + statements: vec![Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::ThreadLocalRef(def_id), + ))), + }], + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }]); + + new_body( + MirSource::from_instance(instance), + blocks, + IndexVec::from_raw(vec![LocalDecl::new(tcx.thread_local_ptr_ty(def_id), span)]), + 0, + span, + ) +} + +/// Builds a `Clone::clone` shim for `self_ty`. Here, `def_id` is `Clone::clone`. +fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Body<'tcx> { + debug!("build_clone_shim(def_id={:?})", def_id); + + let mut builder = CloneShimBuilder::new(tcx, def_id, self_ty); + + let dest = Place::return_place(); + let src = tcx.mk_place_deref(Place::from(Local::new(1 + 0))); + + match self_ty.kind() { + ty::FnDef(..) | ty::FnPtr(..) => builder.copy_shim(), + ty::Closure(_, args) => builder.tuple_like_shim(dest, src, args.as_closure().upvar_tys()), + ty::CoroutineClosure(_, args) => { + builder.tuple_like_shim(dest, src, args.as_coroutine_closure().upvar_tys()) + } + ty::Tuple(..) 
=> builder.tuple_like_shim(dest, src, self_ty.tuple_fields()), + ty::Coroutine(coroutine_def_id, args) => { + assert_eq!(tcx.coroutine_movability(*coroutine_def_id), hir::Movability::Movable); + builder.coroutine_shim(dest, src, *coroutine_def_id, args.as_coroutine()) + } + _ => bug!("clone shim for `{:?}` which is not `Copy` and is not an aggregate", self_ty), + }; + + builder.into_mir() +} + +struct CloneShimBuilder<'tcx> { + tcx: TyCtxt<'tcx>, + def_id: DefId, + local_decls: IndexVec<Local, LocalDecl<'tcx>>, + blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>>, + span: Span, + sig: ty::FnSig<'tcx>, +} + +impl<'tcx> CloneShimBuilder<'tcx> { + fn new(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Self { + // we must instantiate the self_ty because it's + // otherwise going to be TySelf and we can't index + // or access fields of a Place of type TySelf. + let sig = tcx.fn_sig(def_id).instantiate(tcx, &[self_ty.into()]); + let sig = tcx.instantiate_bound_regions_with_erased(sig); + let span = tcx.def_span(def_id); + + CloneShimBuilder { + tcx, + def_id, + local_decls: local_decls_for_sig(&sig, span), + blocks: IndexVec::new(), + span, + sig, + } + } + + fn into_mir(self) -> Body<'tcx> { + let source = MirSource::from_instance(ty::InstanceKind::CloneShim( + self.def_id, + self.sig.inputs_and_output[0], + )); + new_body(source, self.blocks, self.local_decls, self.sig.inputs().len(), self.span) + } + + fn source_info(&self) -> SourceInfo { + SourceInfo::outermost(self.span) + } + + fn block( + &mut self, + statements: Vec<Statement<'tcx>>, + kind: TerminatorKind<'tcx>, + is_cleanup: bool, + ) -> BasicBlock { + let source_info = self.source_info(); + self.blocks.push(BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind }), + is_cleanup, + }) + } + + /// Gives the index of an upcoming BasicBlock, with an offset. 
+ /// offset=0 will give you the index of the next BasicBlock, + /// offset=1 will give the index of the next-to-next block, + /// offset=-1 will give you the index of the last-created block + fn block_index_offset(&self, offset: usize) -> BasicBlock { + BasicBlock::new(self.blocks.len() + offset) + } + + fn make_statement(&self, kind: StatementKind<'tcx>) -> Statement<'tcx> { + Statement { source_info: self.source_info(), kind } + } + + fn copy_shim(&mut self) { + let rcvr = self.tcx.mk_place_deref(Place::from(Local::new(1 + 0))); + let ret_statement = self.make_statement(StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Use(Operand::Copy(rcvr)), + )))); + self.block(vec![ret_statement], TerminatorKind::Return, false); + } + + fn make_place(&mut self, mutability: Mutability, ty: Ty<'tcx>) -> Place<'tcx> { + let span = self.span; + let mut local = LocalDecl::new(ty, span); + if mutability.is_not() { + local = local.immutable(); + } + Place::from(self.local_decls.push(local)) + } + + fn make_clone_call( + &mut self, + dest: Place<'tcx>, + src: Place<'tcx>, + ty: Ty<'tcx>, + next: BasicBlock, + cleanup: BasicBlock, + ) { + let tcx = self.tcx; + + // `func == Clone::clone(&ty) -> ty` + let func_ty = Ty::new_fn_def(tcx, self.def_id, [ty]); + let func = Operand::Constant(Box::new(ConstOperand { + span: self.span, + user_ty: None, + const_: Const::zero_sized(func_ty), + })); + + let ref_loc = + self.make_place(Mutability::Not, Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, ty)); + + // `let ref_loc: &ty = &src;` + let statement = self.make_statement(StatementKind::Assign(Box::new(( + ref_loc, + Rvalue::Ref(tcx.lifetimes.re_erased, BorrowKind::Shared, src), + )))); + + // `let loc = Clone::clone(ref_loc);` + self.block( + vec![statement], + TerminatorKind::Call { + func, + args: [Spanned { node: Operand::Move(ref_loc), span: DUMMY_SP }].into(), + destination: dest, + target: Some(next), + unwind: UnwindAction::Cleanup(cleanup), + call_source: CallSource::Normal, + fn_span: self.span, + }, + false, + ); + } + + fn clone_fields<I>( + &mut self, + dest: Place<'tcx>, + src: Place<'tcx>, + target: BasicBlock, + mut unwind: BasicBlock, + tys: I, + ) -> BasicBlock + where + I: IntoIterator<Item = Ty<'tcx>>, + { + // For an iterator of length n, create 2*n + 1 blocks. + for (i, ity) in tys.into_iter().enumerate() { + // Each iteration creates two blocks, referred to here as block 2*i and block 2*i + 1. + // + // Block 2*i attempts to clone the field. If successful it branches to 2*i + 2 (the + // next clone block). If unsuccessful it branches to the previous unwind block, which + // is initially the `unwind` argument passed to this function. + // + // Block 2*i + 1 is the unwind block for this iteration. It drops the cloned value + // created by block 2*i. We store this block in `unwind` so that the next clone block + // will unwind to it if cloning fails. 
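+            //
+            // Illustration (two fields, with `U` the `unwind` block passed in):
+            //
+            //   BB_a: dest.0 = Clone::clone(&src.0)  -> BB_c on success, unwind -> U
+            //   BB_b: drop(dest.0)                    (cleanup) -> U
+            //   BB_c: dest.1 = Clone::clone(&src.1)  -> BB_e on success, unwind -> BB_b
+            //   BB_d: drop(dest.1)                    (cleanup) -> BB_b
+            //   BB_e: goto -> target
+            //
+            // with BB_d returned as the final unwind block.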
+ + let field = FieldIdx::new(i); + let src_field = self.tcx.mk_place_field(src, field, ity); + + let dest_field = self.tcx.mk_place_field(dest, field, ity); + + let next_unwind = self.block_index_offset(1); + let next_block = self.block_index_offset(2); + self.make_clone_call(dest_field, src_field, ity, next_block, unwind); + self.block( + vec![], + TerminatorKind::Drop { + place: dest_field, + target: unwind, + unwind: UnwindAction::Terminate(UnwindTerminateReason::InCleanup), + replace: false, + drop: None, + async_fut: None, + }, + /* is_cleanup */ true, + ); + unwind = next_unwind; + } + // If all clones succeed then we end up here. + self.block(vec![], TerminatorKind::Goto { target }, false); + unwind + } + + fn tuple_like_shim<I>(&mut self, dest: Place<'tcx>, src: Place<'tcx>, tys: I) + where + I: IntoIterator<Item = Ty<'tcx>>, + { + self.block(vec![], TerminatorKind::Goto { target: self.block_index_offset(3) }, false); + let unwind = self.block(vec![], TerminatorKind::UnwindResume, true); + let target = self.block(vec![], TerminatorKind::Return, false); + + let _final_cleanup_block = self.clone_fields(dest, src, target, unwind, tys); + } + + fn coroutine_shim( + &mut self, + dest: Place<'tcx>, + src: Place<'tcx>, + coroutine_def_id: DefId, + args: CoroutineArgs<TyCtxt<'tcx>>, + ) { + self.block(vec![], TerminatorKind::Goto { target: self.block_index_offset(3) }, false); + let unwind = self.block(vec![], TerminatorKind::UnwindResume, true); + // This will get overwritten with a switch once we know the target blocks + let switch = self.block(vec![], TerminatorKind::Unreachable, false); + let unwind = self.clone_fields(dest, src, switch, unwind, args.upvar_tys()); + let target = self.block(vec![], TerminatorKind::Return, false); + let unreachable = self.block(vec![], TerminatorKind::Unreachable, false); + let mut cases = Vec::with_capacity(args.state_tys(coroutine_def_id, self.tcx).count()); + for (index, state_tys) in args.state_tys(coroutine_def_id, self.tcx).enumerate() { + let variant_index = VariantIdx::new(index); + let dest = self.tcx.mk_place_downcast_unnamed(dest, variant_index); + let src = self.tcx.mk_place_downcast_unnamed(src, variant_index); + let clone_block = self.block_index_offset(1); + let start_block = self.block( + vec![self.make_statement(StatementKind::SetDiscriminant { + place: Box::new(Place::return_place()), + variant_index, + })], + TerminatorKind::Goto { target: clone_block }, + false, + ); + cases.push((index as u128, start_block)); + let _final_cleanup_block = self.clone_fields(dest, src, target, unwind, state_tys); + } + let discr_ty = args.discr_ty(self.tcx); + let temp = self.make_place(Mutability::Mut, discr_ty); + let rvalue = Rvalue::Discriminant(src); + let statement = self.make_statement(StatementKind::Assign(Box::new((temp, rvalue)))); + match &mut self.blocks[switch] { + BasicBlockData { statements, terminator: Some(Terminator { kind, .. }), .. } => { + statements.push(statement); + *kind = TerminatorKind::SwitchInt { + discr: Operand::Move(temp), + targets: SwitchTargets::new(cases.into_iter(), unreachable), + }; + } + BasicBlockData { terminator: None, .. } => unreachable!(), + } + } +} + +/// Builds a "call" shim for `instance`. The shim calls the function specified by `call_kind`, +/// first adjusting its first argument according to `rcvr_adjustment`. 
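+///
+/// As an illustration (pseudo-code, not the literal MIR that is built), the
+/// `ClosureOnceShim` produced here behaves roughly like:
+///
+/// ```ignore (pseudo-impl)
+/// fn call_once(mut closure: C, args: Args) -> C::Output {
+///     let result = FnMut::call_mut(&mut closure, args);
+///     drop(closure);
+///     result
+/// }
+/// ```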
+#[instrument(level = "debug", skip(tcx), ret)] +fn build_call_shim<'tcx>( + tcx: TyCtxt<'tcx>, + instance: ty::InstanceKind<'tcx>, + rcvr_adjustment: Option<Adjustment>, + call_kind: CallKind<'tcx>, +) -> Body<'tcx> { + // `FnPtrShim` contains the fn pointer type that a call shim is being built for - this is used + // to instantiate into the signature of the shim. It is not necessary for users of this + // MIR body to perform further instantiations (see `InstanceKind::has_polymorphic_mir_body`). + let (sig_args, untuple_args) = if let ty::InstanceKind::FnPtrShim(_, ty) = instance { + let sig = tcx.instantiate_bound_regions_with_erased(ty.fn_sig(tcx)); + + let untuple_args = sig.inputs(); + + // Create substitutions for the `Self` and `Args` generic parameters of the shim body. + let arg_tup = Ty::new_tup(tcx, untuple_args); + + (Some([ty.into(), arg_tup.into()]), Some(untuple_args)) + } else { + (None, None) + }; + + let def_id = instance.def_id(); + + let sig = tcx.fn_sig(def_id); + let sig = sig.map_bound(|sig| tcx.instantiate_bound_regions_with_erased(sig)); + + assert_eq!(sig_args.is_some(), !instance.has_polymorphic_mir_body()); + let mut sig = if let Some(sig_args) = sig_args { + sig.instantiate(tcx, &sig_args) + } else { + sig.instantiate_identity() + }; + + if let CallKind::Indirect(fnty) = call_kind { + // `sig` determines our local decls, and thus the callee type in the `Call` terminator. This + // can only be an `FnDef` or `FnPtr`, but currently will be `Self` since the types come from + // the implemented `FnX` trait. + + // Apply the opposite adjustment to the MIR input. + let mut inputs_and_output = sig.inputs_and_output.to_vec(); + + // Initial signature is `fn(&? Self, Args) -> Self::Output` where `Args` is a tuple of the + // fn arguments. `Self` may be passed via (im)mutable reference or by-value. + assert_eq!(inputs_and_output.len(), 3); + + // `Self` is always the original fn type `ty`. The MIR call terminator is only defined for + // `FnDef` and `FnPtr` callees, not the `Self` type param. + let self_arg = &mut inputs_and_output[0]; + *self_arg = match rcvr_adjustment.unwrap() { + Adjustment::Identity => fnty, + Adjustment::Deref { source } => match source { + DerefSource::ImmRef => Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, fnty), + DerefSource::MutRef => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, fnty), + DerefSource::MutPtr => Ty::new_mut_ptr(tcx, fnty), + }, + Adjustment::RefMut => bug!("`RefMut` is never used with indirect calls: {instance:?}"), + }; + sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output); + } + + // FIXME: Avoid having to adjust the signature both here and in + // `fn_sig_for_fn_abi`. + if let ty::InstanceKind::VTableShim(..) = instance { + // Modify fn(self, ...) to fn(self: *mut Self, ...) 
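+        // (That is, the `VTableShim` signature takes its receiver as a raw pointer; the
+        // shim body built above then uses `Adjustment::Deref { source: DerefSource::MutPtr }`
+        // to call the target with `*self`.)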
+ let mut inputs_and_output = sig.inputs_and_output.to_vec(); + let self_arg = &mut inputs_and_output[0]; + debug_assert!(tcx.generics_of(def_id).has_self && *self_arg == tcx.types.self_param); + *self_arg = Ty::new_mut_ptr(tcx, *self_arg); + sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output); + } + + let span = tcx.def_span(def_id); + + debug!(?sig); + + let mut local_decls = local_decls_for_sig(&sig, span); + let source_info = SourceInfo::outermost(span); + + let destination = Place::return_place(); + + let rcvr_place = || { + assert!(rcvr_adjustment.is_some()); + Place::from(Local::new(1)) + }; + let mut statements = vec![]; + + let rcvr = rcvr_adjustment.map(|rcvr_adjustment| match rcvr_adjustment { + Adjustment::Identity => Operand::Move(rcvr_place()), + Adjustment::Deref { source: _ } => Operand::Move(tcx.mk_place_deref(rcvr_place())), + Adjustment::RefMut => { + // let rcvr = &mut rcvr; + let ref_rcvr = local_decls.push( + LocalDecl::new( + Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, sig.inputs()[0]), + span, + ) + .immutable(), + ); + let borrow_kind = BorrowKind::Mut { kind: MutBorrowKind::Default }; + statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::from(ref_rcvr), + Rvalue::Ref(tcx.lifetimes.re_erased, borrow_kind, rcvr_place()), + ))), + }); + Operand::Move(Place::from(ref_rcvr)) + } + }); + + let (callee, mut args) = match call_kind { + // `FnPtr` call has no receiver. Args are untupled below. + CallKind::Indirect(_) => (rcvr.unwrap(), vec![]), + + // `FnDef` call with optional receiver. + CallKind::Direct(def_id) => { + let ty = tcx.type_of(def_id).instantiate_identity(); + ( + Operand::Constant(Box::new(ConstOperand { + span, + user_ty: None, + const_: Const::zero_sized(ty), + })), + rcvr.into_iter().collect::<Vec<_>>(), + ) + } + }; + + let mut arg_range = 0..sig.inputs().len(); + + // Take the `self` ("receiver") argument out of the range (it's adjusted above). + if rcvr_adjustment.is_some() { + arg_range.start += 1; + } + + // Take the last argument, if we need to untuple it (handled below). + if untuple_args.is_some() { + arg_range.end -= 1; + } + + // Pass all of the non-special arguments directly. + args.extend(arg_range.map(|i| Operand::Move(Place::from(Local::new(1 + i))))); + + // Untuple the last argument, if we have to. 
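+    // (Illustratively: for a shimmed `fn(A, B)` pointer, the shim's last argument is the
+    // tuple `(A, B)`, and it is forwarded as `move _n.0, move _n.1` instead of as one
+    // tuple value, where `_n` stands for that last argument's local.)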
+ if let Some(untuple_args) = untuple_args { + let tuple_arg = Local::new(1 + (sig.inputs().len() - 1)); + args.extend(untuple_args.iter().enumerate().map(|(i, ity)| { + Operand::Move(tcx.mk_place_field(Place::from(tuple_arg), FieldIdx::new(i), *ity)) + })); + } + + let n_blocks = if let Some(Adjustment::RefMut) = rcvr_adjustment { 5 } else { 2 }; + let mut blocks = IndexVec::with_capacity(n_blocks); + let block = |blocks: &mut IndexVec<_, _>, statements, kind, is_cleanup| { + blocks.push(BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind }), + is_cleanup, + }) + }; + + // BB #0 + let args = args.into_iter().map(|a| Spanned { node: a, span: DUMMY_SP }).collect(); + block( + &mut blocks, + statements, + TerminatorKind::Call { + func: callee, + args, + destination, + target: Some(BasicBlock::new(1)), + unwind: if let Some(Adjustment::RefMut) = rcvr_adjustment { + UnwindAction::Cleanup(BasicBlock::new(3)) + } else { + UnwindAction::Continue + }, + call_source: CallSource::Misc, + fn_span: span, + }, + false, + ); + + if let Some(Adjustment::RefMut) = rcvr_adjustment { + // BB #1 - drop for Self + block( + &mut blocks, + vec![], + TerminatorKind::Drop { + place: rcvr_place(), + target: BasicBlock::new(2), + unwind: UnwindAction::Continue, + replace: false, + drop: None, + async_fut: None, + }, + false, + ); + } + // BB #1/#2 - return + let stmts = vec![]; + block(&mut blocks, stmts, TerminatorKind::Return, false); + if let Some(Adjustment::RefMut) = rcvr_adjustment { + // BB #3 - drop if closure panics + block( + &mut blocks, + vec![], + TerminatorKind::Drop { + place: rcvr_place(), + target: BasicBlock::new(4), + unwind: UnwindAction::Terminate(UnwindTerminateReason::InCleanup), + replace: false, + drop: None, + async_fut: None, + }, + /* is_cleanup */ true, + ); + + // BB #4 - resume + block(&mut blocks, vec![], TerminatorKind::UnwindResume, true); + } + + let mut body = + new_body(MirSource::from_instance(instance), blocks, local_decls, sig.inputs().len(), span); + + if let ExternAbi::RustCall = sig.abi { + body.spread_arg = Some(Local::new(sig.inputs().len())); + } + + body +} + +pub(super) fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> { + debug_assert!(tcx.is_constructor(ctor_id)); + + let typing_env = ty::TypingEnv::post_analysis(tcx, ctor_id); + + // Normalize the sig. 
+ let sig = tcx + .fn_sig(ctor_id) + .instantiate_identity() + .no_bound_vars() + .expect("LBR in ADT constructor signature"); + let sig = tcx.normalize_erasing_regions(typing_env, sig); + + let ty::Adt(adt_def, args) = sig.output().kind() else { + bug!("unexpected type for ADT ctor {:?}", sig.output()); + }; + + debug!("build_ctor: ctor_id={:?} sig={:?}", ctor_id, sig); + + let span = tcx.def_span(ctor_id); + + let local_decls = local_decls_for_sig(&sig, span); + + let source_info = SourceInfo::outermost(span); + + let variant_index = + if adt_def.is_enum() { adt_def.variant_index_with_ctor_id(ctor_id) } else { FIRST_VARIANT }; + + // Generate the following MIR: + // + // (return as Variant).field0 = arg0; + // (return as Variant).field1 = arg1; + // + // return; + debug!("build_ctor: variant_index={:?}", variant_index); + + let kind = AggregateKind::Adt(adt_def.did(), variant_index, args, None, None); + let variant = adt_def.variant(variant_index); + let statement = Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate( + Box::new(kind), + (0..variant.fields.len()) + .map(|idx| Operand::Move(Place::from(Local::new(idx + 1)))) + .collect(), + ), + ))), + source_info, + }; + + let start_block = BasicBlockData { + statements: vec![statement], + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }; + + let source = MirSource::item(ctor_id); + let mut body = new_body( + source, + IndexVec::from_elem_n(start_block, 1), + local_decls, + sig.inputs().len(), + span, + ); + // A constructor doesn't mention any other items (and we don't run the usual optimization passes + // so this would otherwise not get filled). + body.set_mentioned_items(Vec::new()); + + crate::pass_manager::dump_mir_for_phase_change(tcx, &body); + + body +} + +/// ```ignore (pseudo-impl) +/// impl FnPtr for fn(u32) { +/// fn addr(self) -> usize { +/// self as usize +/// } +/// } +/// ``` +fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Body<'tcx> { + assert_matches!(self_ty.kind(), ty::FnPtr(..), "expected fn ptr, found {self_ty}"); + let span = tcx.def_span(def_id); + let Some(sig) = tcx.fn_sig(def_id).instantiate(tcx, &[self_ty.into()]).no_bound_vars() else { + span_bug!(span, "FnPtr::addr with bound vars for `{self_ty}`"); + }; + let locals = local_decls_for_sig(&sig, span); + + let source_info = SourceInfo::outermost(span); + // FIXME: use `expose_provenance` once we figure out whether function pointers have meaningful + // provenance. 
+ let rvalue = Rvalue::Cast( + CastKind::FnPtrToPtr, + Operand::Move(Place::from(Local::new(1))), + Ty::new_imm_ptr(tcx, tcx.types.unit), + ); + let stmt = Statement { + source_info, + kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))), + }; + let statements = vec![stmt]; + let start_block = BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }; + let source = MirSource::from_instance(ty::InstanceKind::FnPtrAddrShim(def_id, self_ty)); + new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span) +} + +fn build_construct_coroutine_by_move_shim<'tcx>( + tcx: TyCtxt<'tcx>, + coroutine_closure_def_id: DefId, + receiver_by_ref: bool, +) -> Body<'tcx> { + let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); + let mut self_local: Place<'tcx> = Local::from_usize(1).into(); + let ty::CoroutineClosure(_, args) = *self_ty.kind() else { + bug!(); + }; + + // We use `&Self` here because we only need to emit an ABI-compatible shim body, + // rather than match the signature exactly (which might take `&mut self` instead). + // + // We adjust the `self_local` to be a deref since we want to copy fields out of + // a reference to the closure. + if receiver_by_ref { + self_local = tcx.mk_place_deref(self_local); + self_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, self_ty); + } + + let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { + tcx.mk_fn_sig( + [self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()), + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.as_coroutine_closure().parent_args(), + tcx.coroutine_for_closure(coroutine_closure_def_id), + ty::ClosureKind::FnOnce, + tcx.lifetimes.re_erased, + args.as_coroutine_closure().tupled_upvars_ty(), + args.as_coroutine_closure().coroutine_captures_by_ref_ty(), + ), + sig.c_variadic, + sig.safety, + sig.abi, + ) + }); + let sig = tcx.liberate_late_bound_regions(coroutine_closure_def_id, poly_sig); + let ty::Coroutine(coroutine_def_id, coroutine_args) = *sig.output().kind() else { + bug!(); + }; + + let span = tcx.def_span(coroutine_closure_def_id); + let locals = local_decls_for_sig(&sig, span); + + let mut fields = vec![]; + + // Move all of the closure args. + for idx in 1..sig.inputs().len() { + fields.push(Operand::Move(Local::from_usize(idx + 1).into())); + } + + for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() { + if receiver_by_ref { + // The only situation where it's possible is when we capture immuatable references, + // since those don't need to be reborrowed with the closure's env lifetime. Since + // references are always `Copy`, just emit a copy. + if !matches!(ty.kind(), ty::Ref(_, _, hir::Mutability::Not)) { + // This copy is only sound if it's a `&T`. This may be + // reachable e.g. when eagerly computing the `Fn` instance + // of an async closure that doesn't borrowck. 
+ tcx.dcx().delayed_bug(format!( + "field should be captured by immutable ref if we have \ + an `Fn` instance, but it was: {ty}" + )); + } + fields.push(Operand::Copy(tcx.mk_place_field( + self_local, + FieldIdx::from_usize(idx), + ty, + ))); + } else { + fields.push(Operand::Move(tcx.mk_place_field( + self_local, + FieldIdx::from_usize(idx), + ty, + ))); + } + } + + let source_info = SourceInfo::outermost(span); + let rvalue = Rvalue::Aggregate( + Box::new(AggregateKind::Coroutine(coroutine_def_id, coroutine_args)), + IndexVec::from_raw(fields), + ); + let stmt = Statement { + source_info, + kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))), + }; + let statements = vec![stmt]; + let start_block = BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }; + + let source = MirSource::from_instance(ty::InstanceKind::ConstructCoroutineInClosureShim { + coroutine_closure_def_id, + receiver_by_ref, + }); + + let body = + new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span); + dump_mir( + tcx, + false, + if receiver_by_ref { "coroutine_closure_by_ref" } else { "coroutine_closure_by_move" }, + &0, + &body, + |_, _| Ok(()), + ); + + body +} diff --git a/compiler/rustc_mir_transform/src/shim/async_destructor_ctor.rs b/compiler/rustc_mir_transform/src/shim/async_destructor_ctor.rs new file mode 100644 index 00000000000..fd7b7362cd9 --- /dev/null +++ b/compiler/rustc_mir_transform/src/shim/async_destructor_ctor.rs @@ -0,0 +1,431 @@ +use rustc_hir::def_id::DefId; +use rustc_hir::lang_items::LangItem; +use rustc_hir::{CoroutineDesugaring, CoroutineKind, CoroutineSource, Safety}; +use rustc_index::{Idx, IndexVec}; +use rustc_middle::mir::{ + BasicBlock, BasicBlockData, Body, Local, LocalDecl, MirSource, Operand, Place, Rvalue, + SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, +}; +use rustc_middle::ty::{self, EarlyBinder, Ty, TyCtxt, TypeVisitableExt}; + +use super::*; +use crate::patch::MirPatch; + +pub(super) fn build_async_destructor_ctor_shim<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: DefId, + ty: Ty<'tcx>, +) -> Body<'tcx> { + debug!("build_async_destructor_ctor_shim(def_id={:?}, ty={:?})", def_id, ty); + debug_assert_eq!(Some(def_id), tcx.lang_items().async_drop_in_place_fn()); + let generic_body = tcx.optimized_mir(def_id); + let args = tcx.mk_args(&[ty.into()]); + let mut body = EarlyBinder::bind(generic_body.clone()).instantiate(tcx, args); + + // Minimal shim passes except MentionedItems, + // it causes error "mentioned_items for DefId(...async_drop_in_place...) 
have already been set + pm::run_passes( + tcx, + &mut body, + &[ + &simplify::SimplifyCfg::MakeShim, + &abort_unwinding_calls::AbortUnwindingCalls, + &add_call_guards::CriticalCallEdges, + ], + None, + pm::Optimizations::Allowed, + ); + body +} + +// build_drop_shim analog for async drop glue (for generated coroutine poll function) +pub(super) fn build_async_drop_shim<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: DefId, + ty: Ty<'tcx>, +) -> Body<'tcx> { + debug!("build_async_drop_shim(def_id={:?}, ty={:?})", def_id, ty); + let ty::Coroutine(_, parent_args) = ty.kind() else { + bug!(); + }; + let typing_env = ty::TypingEnv::fully_monomorphized(); + + let drop_ty = parent_args.first().unwrap().expect_ty(); + let drop_ptr_ty = Ty::new_mut_ptr(tcx, drop_ty); + + assert!(tcx.is_coroutine(def_id)); + let coroutine_kind = tcx.coroutine_kind(def_id).unwrap(); + + assert!(matches!( + coroutine_kind, + CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Fn) + )); + + let needs_async_drop = drop_ty.needs_async_drop(tcx, typing_env); + let needs_sync_drop = !needs_async_drop && drop_ty.needs_drop(tcx, typing_env); + + let resume_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, DUMMY_SP)); + let resume_ty = Ty::new_adt(tcx, resume_adt, ty::List::empty()); + + let fn_sig = ty::Binder::dummy(tcx.mk_fn_sig( + [ty, resume_ty], + tcx.types.unit, + false, + Safety::Safe, + ExternAbi::Rust, + )); + let sig = tcx.instantiate_bound_regions_with_erased(fn_sig); + + assert!(!drop_ty.is_coroutine()); + let span = tcx.def_span(def_id); + let source_info = SourceInfo::outermost(span); + + // The first argument (index 0), but add 1 for the return value. + let coroutine_layout = Place::from(Local::new(1 + 0)); + let coroutine_layout_dropee = + tcx.mk_place_field(coroutine_layout, FieldIdx::new(0), drop_ptr_ty); + + let return_block = BasicBlock::new(1); + let mut blocks = IndexVec::with_capacity(2); + let block = |blocks: &mut IndexVec<_, _>, kind| { + blocks.push(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { source_info, kind }), + is_cleanup: false, + }) + }; + block( + &mut blocks, + if needs_sync_drop { + TerminatorKind::Drop { + place: tcx.mk_place_deref(coroutine_layout_dropee), + target: return_block, + unwind: UnwindAction::Continue, + replace: false, + drop: None, + async_fut: None, + } + } else { + TerminatorKind::Goto { target: return_block } + }, + ); + block(&mut blocks, TerminatorKind::Return); + + let source = MirSource::from_instance(ty::InstanceKind::AsyncDropGlue(def_id, ty)); + let mut body = + new_body(source, blocks, local_decls_for_sig(&sig, span), sig.inputs().len(), span); + + body.coroutine = Some(Box::new(CoroutineInfo::initial( + coroutine_kind, + parent_args.as_coroutine().yield_ty(), + parent_args.as_coroutine().resume_ty(), + ))); + body.phase = MirPhase::Runtime(RuntimePhase::Initial); + if !needs_async_drop || drop_ty.references_error() { + // Returning noop body for types without `need async drop` + // (or sync Drop in case of !`need async drop` && `need drop`). + // And also for error types. 
+ return body; + } + + let mut dropee_ptr = Place::from(body.local_decls.push(LocalDecl::new(drop_ptr_ty, span))); + let st_kind = StatementKind::Assign(Box::new(( + dropee_ptr, + Rvalue::Use(Operand::Move(coroutine_layout_dropee)), + ))); + body.basic_blocks_mut()[START_BLOCK].statements.push(Statement { source_info, kind: st_kind }); + dropee_ptr = dropee_emit_retag(tcx, &mut body, dropee_ptr, span); + + let dropline = body.basic_blocks.last_index(); + + let patch = { + let mut elaborator = DropShimElaborator { + body: &body, + patch: MirPatch::new(&body), + tcx, + typing_env, + produce_async_drops: true, + }; + let dropee = tcx.mk_place_deref(dropee_ptr); + let resume_block = elaborator.patch.resume_block(); + elaborate_drop( + &mut elaborator, + source_info, + dropee, + (), + return_block, + Unwind::To(resume_block), + START_BLOCK, + dropline, + ); + elaborator.patch + }; + patch.apply(&mut body); + + body +} + +// * For async drop a "normal" coroutine: +// `async_drop_in_place<T>::{closure}.poll()` is converted into `T.future_drop_poll()`. +// Every coroutine has its `poll` (calculate yourself a little further) +// and its `future_drop_poll` (drop yourself a little further). +// +// * For async drop of "async drop coroutine" (`async_drop_in_place<T>::{closure}`): +// Correct drop of such coroutine means normal execution of nested async drop. +// async_drop(async_drop(T))::future_drop_poll() => async_drop(T)::poll(). +pub(super) fn build_future_drop_poll_shim<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: DefId, + proxy_ty: Ty<'tcx>, + impl_ty: Ty<'tcx>, +) -> Body<'tcx> { + let instance = ty::InstanceKind::FutureDropPollShim(def_id, proxy_ty, impl_ty); + let ty::Coroutine(coroutine_def_id, _) = impl_ty.kind() else { + bug!("build_future_drop_poll_shim not for coroutine impl type: ({:?})", instance); + }; + + let span = tcx.def_span(def_id); + + if tcx.is_async_drop_in_place_coroutine(*coroutine_def_id) { + build_adrop_for_adrop_shim(tcx, proxy_ty, impl_ty, span, instance) + } else { + build_adrop_for_coroutine_shim(tcx, proxy_ty, impl_ty, span, instance) + } +} + +// For async drop a "normal" coroutine: +// `async_drop_in_place<T>::{closure}.poll()` is converted into `T.future_drop_poll()`. +// Every coroutine has its `poll` (calculate yourself a little further) +// and its `future_drop_poll` (drop yourself a little further). 
+fn build_adrop_for_coroutine_shim<'tcx>( + tcx: TyCtxt<'tcx>, + proxy_ty: Ty<'tcx>, + impl_ty: Ty<'tcx>, + span: Span, + instance: ty::InstanceKind<'tcx>, +) -> Body<'tcx> { + let ty::Coroutine(coroutine_def_id, impl_args) = impl_ty.kind() else { + bug!("build_adrop_for_coroutine_shim not for coroutine impl type: ({:?})", instance); + }; + let proxy_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, proxy_ty); + // taking _1.0 (impl from Pin) + let pin_proxy_layout_local = Local::new(1); + let source_info = SourceInfo::outermost(span); + // converting `(_1: Pin<&mut CorLayout>, _2: &mut Context<'_>) -> Poll<()>` + // into `(_1: Pin<&mut ProxyLayout>, _2: &mut Context<'_>) -> Poll<()>` + // let mut _x: &mut CorLayout = &*_1.0.0; + // Replace old _1.0 accesses into _x accesses; + let body = tcx.optimized_mir(*coroutine_def_id).future_drop_poll().unwrap(); + let mut body: Body<'tcx> = EarlyBinder::bind(body.clone()).instantiate(tcx, impl_args); + body.source.instance = instance; + body.phase = MirPhase::Runtime(RuntimePhase::Initial); + body.var_debug_info.clear(); + let pin_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Pin, span)); + let args = tcx.mk_args(&[proxy_ref.into()]); + let pin_proxy_ref = Ty::new_adt(tcx, pin_adt_ref, args); + + let cor_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, impl_ty); + + let proxy_ref_local = body.local_decls.push(LocalDecl::new(proxy_ref, span)); + let cor_ref_local = body.local_decls.push(LocalDecl::new(cor_ref, span)); + + FixProxyFutureDropVisitor { tcx, replace_to: cor_ref_local }.visit_body(&mut body); + // Now changing first arg from Pin<&mut ImplCoroutine> to Pin<&mut ProxyCoroutine> + body.local_decls[pin_proxy_layout_local] = LocalDecl::new(pin_proxy_ref, span); + + { + let mut idx: usize = 0; + // _proxy = _1.0 : Pin<&ProxyLayout> ==> &ProxyLayout + let proxy_ref_place = Place::from(pin_proxy_layout_local) + .project_deeper(&[PlaceElem::Field(FieldIdx::ZERO, proxy_ref)], tcx); + body.basic_blocks_mut()[START_BLOCK].statements.insert( + idx, + Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::from(proxy_ref_local), + Rvalue::CopyForDeref(proxy_ref_place), + ))), + }, + ); + idx += 1; + let mut cor_ptr_local = proxy_ref_local; + proxy_ty.find_async_drop_impl_coroutine(tcx, |ty| { + if ty != proxy_ty { + let ty_ptr = Ty::new_mut_ptr(tcx, ty); + let impl_ptr_place = Place::from(cor_ptr_local).project_deeper( + &[PlaceElem::Deref, PlaceElem::Field(FieldIdx::ZERO, ty_ptr)], + tcx, + ); + cor_ptr_local = body.local_decls.push(LocalDecl::new(ty_ptr, span)); + // _cor_ptr = _proxy.0.0 (... .0) + body.basic_blocks_mut()[START_BLOCK].statements.insert( + idx, + Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::from(cor_ptr_local), + Rvalue::CopyForDeref(impl_ptr_place), + ))), + }, + ); + idx += 1; + } + }); + + // _cor_ref = &*cor_ptr + let reborrow = Rvalue::Ref( + tcx.lifetimes.re_erased, + BorrowKind::Mut { kind: MutBorrowKind::Default }, + tcx.mk_place_deref(Place::from(cor_ptr_local)), + ); + body.basic_blocks_mut()[START_BLOCK].statements.insert( + idx, + Statement { + source_info, + kind: StatementKind::Assign(Box::new((Place::from(cor_ref_local), reborrow))), + }, + ); + } + body +} + +// When dropping async drop coroutine, we continue its execution. 
+// async_drop(async_drop(T))::future_drop_poll() => async_drop(T)::poll() +fn build_adrop_for_adrop_shim<'tcx>( + tcx: TyCtxt<'tcx>, + proxy_ty: Ty<'tcx>, + impl_ty: Ty<'tcx>, + span: Span, + instance: ty::InstanceKind<'tcx>, +) -> Body<'tcx> { + let source_info = SourceInfo::outermost(span); + let proxy_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, proxy_ty); + // taking _1.0 (impl from Pin) + let pin_proxy_layout_local = Local::new(1); + let proxy_ref_place = Place::from(pin_proxy_layout_local) + .project_deeper(&[PlaceElem::Field(FieldIdx::ZERO, proxy_ref)], tcx); + let cor_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, impl_ty); + + // ret_ty = `Poll<()>` + let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, span)); + let ret_ty = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()])); + // env_ty = `Pin<&mut proxy_ty>` + let pin_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Pin, span)); + let env_ty = Ty::new_adt(tcx, pin_adt_ref, tcx.mk_args(&[proxy_ref.into()])); + // sig = `fn (Pin<&mut proxy_ty>, &mut Context) -> Poll<()>` + let sig = tcx.mk_fn_sig( + [env_ty, Ty::new_task_context(tcx)], + ret_ty, + false, + hir::Safety::Safe, + ExternAbi::Rust, + ); + // This function will be called with pinned proxy coroutine layout. + // We need to extract `Arg0.0` to get proxy layout, and then get `.0` + // further to receive impl coroutine (may be needed) + let mut locals = local_decls_for_sig(&sig, span); + let mut blocks = IndexVec::with_capacity(3); + + let proxy_ref_local = locals.push(LocalDecl::new(proxy_ref, span)); + + let call_bb = BasicBlock::new(1); + let return_bb = BasicBlock::new(2); + + let mut statements = Vec::new(); + + statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::from(proxy_ref_local), + Rvalue::CopyForDeref(proxy_ref_place), + ))), + }); + + let mut cor_ptr_local = proxy_ref_local; + proxy_ty.find_async_drop_impl_coroutine(tcx, |ty| { + if ty != proxy_ty { + let ty_ptr = Ty::new_mut_ptr(tcx, ty); + let impl_ptr_place = Place::from(cor_ptr_local) + .project_deeper(&[PlaceElem::Deref, PlaceElem::Field(FieldIdx::ZERO, ty_ptr)], tcx); + cor_ptr_local = locals.push(LocalDecl::new(ty_ptr, span)); + // _cor_ptr = _proxy.0.0 (... 
.0) + statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::from(cor_ptr_local), + Rvalue::CopyForDeref(impl_ptr_place), + ))), + }); + } + }); + + // convert impl coroutine ptr into ref + let reborrow = Rvalue::Ref( + tcx.lifetimes.re_erased, + BorrowKind::Mut { kind: MutBorrowKind::Default }, + tcx.mk_place_deref(Place::from(cor_ptr_local)), + ); + let cor_ref_place = Place::from(locals.push(LocalDecl::new(cor_ref, span))); + statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((cor_ref_place, reborrow))), + }); + + // cor_pin_ty = `Pin<&mut cor_ref>` + let cor_pin_ty = Ty::new_adt(tcx, pin_adt_ref, tcx.mk_args(&[cor_ref.into()])); + let cor_pin_place = Place::from(locals.push(LocalDecl::new(cor_pin_ty, span))); + + let pin_fn = tcx.require_lang_item(LangItem::PinNewUnchecked, span); + // call Pin<FutTy>::new_unchecked(&mut impl_cor) + blocks.push(BasicBlockData { + statements, + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::Call { + func: Operand::function_handle(tcx, pin_fn, [cor_ref.into()], span), + args: [dummy_spanned(Operand::Move(cor_ref_place))].into(), + destination: cor_pin_place, + target: Some(call_bb), + unwind: UnwindAction::Continue, + call_source: CallSource::Misc, + fn_span: span, + }, + }), + is_cleanup: false, + }); + // When dropping async drop coroutine, we continue its execution: + // we call impl::poll (impl_layout, ctx) + let poll_fn = tcx.require_lang_item(LangItem::FuturePoll, span); + let resume_ctx = Place::from(Local::new(2)); + blocks.push(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::Call { + func: Operand::function_handle(tcx, poll_fn, [impl_ty.into()], span), + args: [ + dummy_spanned(Operand::Move(cor_pin_place)), + dummy_spanned(Operand::Move(resume_ctx)), + ] + .into(), + destination: Place::return_place(), + target: Some(return_bb), + unwind: UnwindAction::Continue, + call_source: CallSource::Misc, + fn_span: span, + }, + }), + is_cleanup: false, + }); + blocks.push(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }); + + let source = MirSource::from_instance(instance); + let mut body = new_body(source, blocks, locals, sig.inputs().len(), span); + body.phase = MirPhase::Runtime(RuntimePhase::Initial); + return body; +} diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs new file mode 100644 index 00000000000..db933da6413 --- /dev/null +++ b/compiler/rustc_mir_transform/src/simplify.rs @@ -0,0 +1,630 @@ +//! A number of passes which remove various redundancies in the CFG. +//! +//! The `SimplifyCfg` pass gets rid of unnecessary blocks in the CFG, whereas the `SimplifyLocals` +//! gets rid of all the unnecessary local variable declarations. +//! +//! The `SimplifyLocals` pass is kinda expensive and therefore not very suitable to be run often. +//! Most of the passes should not care or be impacted in meaningful ways due to extra locals +//! either, so running the pass once, right before codegen, should suffice. +//! +//! On the other side of the spectrum, the `SimplifyCfg` pass is considerably cheap to run, thus +//! one should run it after every pass which may modify CFG in significant ways. This pass must +//! also be run before any analysis passes because it removes dead blocks, and some of these can be +//! ill-typed. +//! +//! 
The cause of this typing issue is typeck allowing most blocks whose end is not reachable have +//! an arbitrary return type, rather than having the usual () return type (as a note, typeck's +//! notion of reachability is in fact slightly weaker than MIR CFG reachability - see #31617). A +//! standard example of the situation is: +//! +//! ```rust +//! fn example() { +//! let _a: char = { return; }; +//! } +//! ``` +//! +//! Here the block (`{ return; }`) has the return type `char`, rather than `()`, but the MIR we +//! naively generate still contains the `_a = ()` write in the unreachable block "after" the +//! return. +//! +//! **WARNING**: This is one of the few optimizations that runs on built and analysis MIR, and +//! so its effects may affect the type-checking, borrow-checking, and other analysis of MIR. +//! We must be extremely careful to only apply optimizations that preserve UB and all +//! non-determinism, since changes here can affect which programs compile in an insta-stable way. +//! The normal logic that a program with UB can be changed to do anything does not apply to +//! pre-"runtime" MIR! + +use rustc_index::{Idx, IndexSlice, IndexVec}; +use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_span::DUMMY_SP; +use smallvec::SmallVec; +use tracing::{debug, trace}; + +pub(super) enum SimplifyCfg { + Initial, + PromoteConsts, + RemoveFalseEdges, + /// Runs at the beginning of "analysis to runtime" lowering, *before* drop elaboration. + PostAnalysis, + /// Runs at the end of "analysis to runtime" lowering, *after* drop elaboration. + /// This is before the main optimization passes on runtime MIR kick in. + PreOptimizations, + Final, + MakeShim, + AfterUnreachableEnumBranching, +} + +impl SimplifyCfg { + fn name(&self) -> &'static str { + match self { + SimplifyCfg::Initial => "SimplifyCfg-initial", + SimplifyCfg::PromoteConsts => "SimplifyCfg-promote-consts", + SimplifyCfg::RemoveFalseEdges => "SimplifyCfg-remove-false-edges", + SimplifyCfg::PostAnalysis => "SimplifyCfg-post-analysis", + SimplifyCfg::PreOptimizations => "SimplifyCfg-pre-optimizations", + SimplifyCfg::Final => "SimplifyCfg-final", + SimplifyCfg::MakeShim => "SimplifyCfg-make_shim", + SimplifyCfg::AfterUnreachableEnumBranching => { + "SimplifyCfg-after-unreachable-enum-branching" + } + } + } +} + +pub(super) fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + if CfgSimplifier::new(tcx, body).simplify() { + // `simplify` returns that it changed something. We must invalidate the CFG caches as they + // are not consistent with the modified CFG any more. 
+ body.basic_blocks.invalidate_cfg_cache(); + } + remove_dead_blocks(body); + + // FIXME: Should probably be moved into some kind of pass manager + body.basic_blocks.as_mut_preserves_cfg().shrink_to_fit(); +} + +impl<'tcx> crate::MirPass<'tcx> for SimplifyCfg { + fn name(&self) -> &'static str { + self.name() + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!("SimplifyCfg({:?}) - simplifying {:?}", self.name(), body.source); + simplify_cfg(tcx, body); + } + + fn is_required(&self) -> bool { + false + } +} + +struct CfgSimplifier<'a, 'tcx> { + preserve_switch_reads: bool, + basic_blocks: &'a mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + pred_count: IndexVec<BasicBlock, u32>, +} + +impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> { + fn new(tcx: TyCtxt<'tcx>, body: &'a mut Body<'tcx>) -> Self { + let mut pred_count = IndexVec::from_elem(0u32, &body.basic_blocks); + + // we can't use mir.predecessors() here because that counts + // dead blocks, which we don't want to. + pred_count[START_BLOCK] = 1; + + for (_, data) in traversal::preorder(body) { + if let Some(ref term) = data.terminator { + for tgt in term.successors() { + pred_count[tgt] += 1; + } + } + } + + // Preserve `SwitchInt` reads on built and analysis MIR, or if `-Zmir-preserve-ub`. + let preserve_switch_reads = matches!(body.phase, MirPhase::Built | MirPhase::Analysis(_)) + || tcx.sess.opts.unstable_opts.mir_preserve_ub; + // Do not clear caches yet. The caller to `simplify` will do it if anything changed. + let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); + + CfgSimplifier { preserve_switch_reads, basic_blocks, pred_count } + } + + /// Returns whether we actually simplified anything. In that case, the caller *must* invalidate + /// the CFG caches of the MIR body. + #[must_use] + fn simplify(mut self) -> bool { + self.strip_nops(); + + // Vec of the blocks that should be merged. We store the indices here, instead of the + // statements itself to avoid moving the (relatively) large statements twice. 
+ // We do not push the statements directly into the target block (`bb`) as that is slower + // due to additional reallocations + let mut merged_blocks = Vec::new(); + let mut outer_changed = false; + loop { + let mut changed = false; + + for bb in self.basic_blocks.indices() { + if self.pred_count[bb] == 0 { + continue; + } + + debug!("simplifying {:?}", bb); + + let mut terminator = + self.basic_blocks[bb].terminator.take().expect("invalid terminator state"); + + terminator + .successors_mut(|successor| self.collapse_goto_chain(successor, &mut changed)); + + let mut inner_changed = true; + merged_blocks.clear(); + while inner_changed { + inner_changed = false; + inner_changed |= self.simplify_branch(&mut terminator); + inner_changed |= self.merge_successor(&mut merged_blocks, &mut terminator); + changed |= inner_changed; + } + + let statements_to_merge = + merged_blocks.iter().map(|&i| self.basic_blocks[i].statements.len()).sum(); + + if statements_to_merge > 0 { + let mut statements = std::mem::take(&mut self.basic_blocks[bb].statements); + statements.reserve(statements_to_merge); + for &from in &merged_blocks { + statements.append(&mut self.basic_blocks[from].statements); + } + self.basic_blocks[bb].statements = statements; + } + + self.basic_blocks[bb].terminator = Some(terminator); + } + + if !changed { + break; + } + + outer_changed = true; + } + + outer_changed + } + + /// This function will return `None` if + /// * the block has statements + /// * the block has a terminator other than `goto` + /// * the block has no terminator (meaning some other part of the current optimization stole it) + fn take_terminator_if_simple_goto(&mut self, bb: BasicBlock) -> Option<Terminator<'tcx>> { + match self.basic_blocks[bb] { + BasicBlockData { + ref statements, + terminator: + ref mut terminator @ Some(Terminator { kind: TerminatorKind::Goto { .. }, .. }), + .. + } if statements.is_empty() => terminator.take(), + // if `terminator` is None, this means we are in a loop. In that + // case, let all the loop collapse to its entry. + _ => None, + } + } + + /// Collapse a goto chain starting from `start` + fn collapse_goto_chain(&mut self, start: &mut BasicBlock, changed: &mut bool) { + // Using `SmallVec` here, because in some logs on libcore oli-obk saw many single-element + // goto chains. We should probably benchmark different sizes. + let mut terminators: SmallVec<[_; 1]> = Default::default(); + let mut current = *start; + while let Some(terminator) = self.take_terminator_if_simple_goto(current) { + let Terminator { kind: TerminatorKind::Goto { target }, .. } = terminator else { + unreachable!(); + }; + terminators.push((current, terminator)); + current = target; + } + let last = current; + *start = last; + while let Some((current, mut terminator)) = terminators.pop() { + let Terminator { kind: TerminatorKind::Goto { ref mut target }, .. } = terminator + else { + unreachable!(); + }; + *changed |= *target != last; + *target = last; + debug!("collapsing goto chain from {:?} to {:?}", current, target); + + if self.pred_count[current] == 1 { + // This is the last reference to current, so the pred-count to + // to target is moved into the current block. 
+ self.pred_count[current] = 0; + } else { + self.pred_count[*target] += 1; + self.pred_count[current] -= 1; + } + self.basic_blocks[current].terminator = Some(terminator); + } + } + + // merge a block with 1 `goto` predecessor to its parent + fn merge_successor( + &mut self, + merged_blocks: &mut Vec<BasicBlock>, + terminator: &mut Terminator<'tcx>, + ) -> bool { + let target = match terminator.kind { + TerminatorKind::Goto { target } if self.pred_count[target] == 1 => target, + _ => return false, + }; + + debug!("merging block {:?} into {:?}", target, terminator); + *terminator = match self.basic_blocks[target].terminator.take() { + Some(terminator) => terminator, + None => { + // unreachable loop - this should not be possible, as we + // don't strand blocks, but handle it correctly. + return false; + } + }; + + merged_blocks.push(target); + self.pred_count[target] = 0; + + true + } + + // turn a branch with all successors identical to a goto + fn simplify_branch(&mut self, terminator: &mut Terminator<'tcx>) -> bool { + // Removing a `SwitchInt` terminator may remove reads that result in UB, + // so we must not apply this optimization before borrowck or when + // `-Zmir-preserve-ub` is set. + if self.preserve_switch_reads { + return false; + } + + let TerminatorKind::SwitchInt { .. } = terminator.kind else { + return false; + }; + + let first_succ = { + if let Some(first_succ) = terminator.successors().next() { + if terminator.successors().all(|s| s == first_succ) { + let count = terminator.successors().count(); + self.pred_count[first_succ] -= (count - 1) as u32; + first_succ + } else { + return false; + } + } else { + return false; + } + }; + + debug!("simplifying branch {:?}", terminator); + terminator.kind = TerminatorKind::Goto { target: first_succ }; + true + } + + fn strip_nops(&mut self) { + for blk in self.basic_blocks.iter_mut() { + blk.statements.retain(|stmt| !matches!(stmt.kind, StatementKind::Nop)) + } + } +} + +pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>) { + if let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind { + let otherwise = targets.otherwise(); + if targets.iter().any(|t| t.1 == otherwise) { + *targets = SwitchTargets::new( + targets.iter().filter(|t| t.1 != otherwise), + targets.otherwise(), + ); + } + } +} + +pub(super) fn remove_dead_blocks(body: &mut Body<'_>) { + let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| { + // CfgSimplifier::simplify leaves behind some unreachable basic blocks without a + // terminator. Those blocks will be deleted by remove_dead_blocks, but we run just + // before then so we need to handle missing terminators. + // We also need to prevent confusing cleanup and non-cleanup blocks. In practice we + // don't emit empty unreachable cleanup blocks, so this simple check suffices. 
+ bbdata.terminator.is_some() && bbdata.is_empty_unreachable() && !bbdata.is_cleanup + }; + + let reachable = traversal::reachable_as_bitset(body); + let empty_unreachable_blocks = body + .basic_blocks + .iter_enumerated() + .filter(|(bb, bbdata)| should_deduplicate_unreachable(bbdata) && reachable.contains(*bb)) + .count(); + + let num_blocks = body.basic_blocks.len(); + if num_blocks == reachable.count() && empty_unreachable_blocks <= 1 { + return; + } + + let basic_blocks = body.basic_blocks.as_mut(); + + let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect(); + let mut orig_index = 0; + let mut used_index = 0; + let mut kept_unreachable = None; + let mut deduplicated_unreachable = false; + basic_blocks.raw.retain(|bbdata| { + let orig_bb = BasicBlock::new(orig_index); + if !reachable.contains(orig_bb) { + orig_index += 1; + return false; + } + + let used_bb = BasicBlock::new(used_index); + if should_deduplicate_unreachable(bbdata) { + let kept_unreachable = *kept_unreachable.get_or_insert(used_bb); + if kept_unreachable != used_bb { + replacements[orig_index] = kept_unreachable; + deduplicated_unreachable = true; + orig_index += 1; + return false; + } + } + + replacements[orig_index] = used_bb; + used_index += 1; + orig_index += 1; + true + }); + + // If we deduplicated unreachable blocks we erase their source_info as we + // can no longer attribute their code to a particular location in the + // source. + if deduplicated_unreachable { + basic_blocks[kept_unreachable.unwrap()].terminator_mut().source_info = + SourceInfo { span: DUMMY_SP, scope: OUTERMOST_SOURCE_SCOPE }; + } + + for block in basic_blocks { + block.terminator_mut().successors_mut(|target| *target = replacements[target.index()]); + } +} + +pub(super) enum SimplifyLocals { + BeforeConstProp, + AfterGVN, + Final, +} + +impl<'tcx> crate::MirPass<'tcx> for SimplifyLocals { + fn name(&self) -> &'static str { + match &self { + SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop", + SimplifyLocals::AfterGVN => "SimplifyLocals-after-value-numbering", + SimplifyLocals::Final => "SimplifyLocals-final", + } + } + + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("running SimplifyLocals on {:?}", body.source); + + // First, we're going to get a count of *actual* uses for every `Local`. + let mut used_locals = UsedLocals::new(body); + + // Next, we're going to remove any `Local` with zero actual uses. When we remove those + // `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals` + // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from + // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a + // fixedpoint where there are no more unused locals. + remove_unused_definitions_helper(&mut used_locals, body); + + // Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the + // `Local`s. + let map = make_local_map(&mut body.local_decls, &used_locals); + + // Only bother running the `LocalUpdater` if we actually found locals to remove. 
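+        // Hypothetical sketch of the map's shape (added for illustration): with locals
+        // `_0.._3` where only `_2` is unused, `make_local_map` returns
+        //   [Some(_0), Some(_1), None, Some(_2)]
+        // i.e. `_2` is dropped and `_3` is renumbered to `_2` by the `LocalUpdater` below.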
+ if map.iter().any(Option::is_none) { + // Update references to all vars and tmps now + let mut updater = LocalUpdater { map, tcx }; + updater.visit_body_preserves_cfg(body); + + body.local_decls.shrink_to_fit(); + } + } + + fn is_required(&self) -> bool { + false + } +} + +pub(super) fn remove_unused_definitions<'tcx>(body: &mut Body<'tcx>) { + // First, we're going to get a count of *actual* uses for every `Local`. + let mut used_locals = UsedLocals::new(body); + + // Next, we're going to remove any `Local` with zero actual uses. When we remove those + // `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals` + // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from + // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a + // fixedpoint where there are no more unused locals. + remove_unused_definitions_helper(&mut used_locals, body); +} + +/// Construct the mapping while swapping out unused stuff out from the `vec`. +fn make_local_map<V>( + local_decls: &mut IndexVec<Local, V>, + used_locals: &UsedLocals, +) -> IndexVec<Local, Option<Local>> { + let mut map: IndexVec<Local, Option<Local>> = IndexVec::from_elem(None, local_decls); + let mut used = Local::ZERO; + + for alive_index in local_decls.indices() { + // `is_used` treats the `RETURN_PLACE` and arguments as used. + if !used_locals.is_used(alive_index) { + continue; + } + + map[alive_index] = Some(used); + if alive_index != used { + local_decls.swap(alive_index, used); + } + used.increment_by(1); + } + local_decls.truncate(used.index()); + map +} + +/// Keeps track of used & unused locals. +struct UsedLocals { + increment: bool, + arg_count: u32, + use_count: IndexVec<Local, u32>, +} + +impl UsedLocals { + /// Determines which locals are used & unused in the given body. + fn new(body: &Body<'_>) -> Self { + let mut this = Self { + increment: true, + arg_count: body.arg_count.try_into().unwrap(), + use_count: IndexVec::from_elem(0, &body.local_decls), + }; + this.visit_body(body); + this + } + + /// Checks if local is used. + /// + /// Return place and arguments are always considered used. + fn is_used(&self, local: Local) -> bool { + trace!("is_used({:?}): use_count: {:?}", local, self.use_count[local]); + local.as_u32() <= self.arg_count || self.use_count[local] != 0 + } + + /// Updates the use counts to reflect the removal of given statement. + fn statement_removed(&mut self, statement: &Statement<'_>) { + self.increment = false; + + // The location of the statement is irrelevant. + let location = Location::START; + self.visit_statement(statement, location); + } + + /// Visits a left-hand side of an assignment. + fn visit_lhs(&mut self, place: &Place<'_>, location: Location) { + if place.is_indirect() { + // A use, not a definition. + self.visit_place(place, PlaceContext::MutatingUse(MutatingUseContext::Store), location); + } else { + // A definition. The base local itself is not visited, so this occurrence is not counted + // toward its use count. There might be other locals still, used in an indexing + // projection. + self.super_projection( + place.as_ref(), + PlaceContext::MutatingUse(MutatingUseContext::Projection), + location, + ); + } + } +} + +impl<'tcx> Visitor<'tcx> for UsedLocals { + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match statement.kind { + StatementKind::Intrinsic(..) + | StatementKind::Retag(..) + | StatementKind::Coverage(..) + | StatementKind::FakeRead(..) 
+ | StatementKind::PlaceMention(..) + | StatementKind::AscribeUserType(..) => { + self.super_statement(statement, location); + } + + StatementKind::ConstEvalCounter | StatementKind::Nop => {} + + StatementKind::StorageLive(_local) | StatementKind::StorageDead(_local) => {} + + StatementKind::Assign(box (ref place, ref rvalue)) => { + if rvalue.is_safe_to_remove() { + self.visit_lhs(place, location); + self.visit_rvalue(rvalue, location); + } else { + self.super_statement(statement, location); + } + } + + StatementKind::SetDiscriminant { ref place, variant_index: _ } + | StatementKind::Deinit(ref place) + | StatementKind::BackwardIncompatibleDropHint { ref place, reason: _ } => { + self.visit_lhs(place, location); + } + } + } + + fn visit_local(&mut self, local: Local, _ctx: PlaceContext, _location: Location) { + if self.increment { + self.use_count[local] += 1; + } else { + assert_ne!(self.use_count[local], 0); + self.use_count[local] -= 1; + } + } +} + +/// Removes unused definitions. Updates the used locals to reflect the changes made. +fn remove_unused_definitions_helper(used_locals: &mut UsedLocals, body: &mut Body<'_>) { + // The use counts are updated as we remove the statements. A local might become unused + // during the retain operation, leading to a temporary inconsistency (storage statements or + // definitions referencing the local might remain). For correctness it is crucial that this + // computation reaches a fixed point. + + let mut modified = true; + while modified { + modified = false; + + for data in body.basic_blocks.as_mut_preserves_cfg() { + // Remove unnecessary StorageLive and StorageDead annotations. + data.statements.retain(|statement| { + let keep = match &statement.kind { + StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + used_locals.is_used(*local) + } + StatementKind::Assign(box (place, _)) => used_locals.is_used(place.local), + + StatementKind::SetDiscriminant { place, .. } + | StatementKind::BackwardIncompatibleDropHint { place, reason: _ } + | StatementKind::Deinit(place) => used_locals.is_used(place.local), + StatementKind::Nop => false, + _ => true, + }; + + if !keep { + trace!("removing statement {:?}", statement); + modified = true; + used_locals.statement_removed(statement); + } + + keep + }); + } + } +} + +struct LocalUpdater<'tcx> { + map: IndexVec<Local, Option<Local>>, + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for LocalUpdater<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, l: &mut Local, _: PlaceContext, _: Location) { + *l = self.map[*l].unwrap(); + } +} diff --git a/compiler/rustc_mir_transform/src/simplify_branches.rs b/compiler/rustc_mir_transform/src/simplify_branches.rs new file mode 100644 index 00000000000..886f4d6e509 --- /dev/null +++ b/compiler/rustc_mir_transform/src/simplify_branches.rs @@ -0,0 +1,67 @@ +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use tracing::trace; + +pub(super) enum SimplifyConstCondition { + AfterConstProp, + Final, +} + +/// A pass that replaces a branch with a goto when its condition is known. 
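+///
+/// For example (an illustrative sketch, not taken from the original comments): a terminator
+/// such as `switchInt(const 1_u8) -> [0: bb1, otherwise: bb2]` becomes `goto -> bb2`, and an
+/// `assume(const false)` statement turns its whole block into `unreachable`.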
+impl<'tcx> crate::MirPass<'tcx> for SimplifyConstCondition { + fn name(&self) -> &'static str { + match self { + SimplifyConstCondition::AfterConstProp => "SimplifyConstCondition-after-const-prop", + SimplifyConstCondition::Final => "SimplifyConstCondition-final", + } + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("Running SimplifyConstCondition on {:?}", body.source); + let typing_env = body.typing_env(tcx); + 'blocks: for block in body.basic_blocks_mut() { + for stmt in block.statements.iter_mut() { + // Simplify `assume` of a known value: either a NOP or unreachable. + if let StatementKind::Intrinsic(box ref intrinsic) = stmt.kind + && let NonDivergingIntrinsic::Assume(discr) = intrinsic + && let Operand::Constant(c) = discr + && let Some(constant) = c.const_.try_eval_bool(tcx, typing_env) + { + if constant { + stmt.make_nop(); + } else { + block.statements.clear(); + block.terminator_mut().kind = TerminatorKind::Unreachable; + continue 'blocks; + } + } + } + + let terminator = block.terminator_mut(); + terminator.kind = match terminator.kind { + TerminatorKind::SwitchInt { + discr: Operand::Constant(ref c), ref targets, .. + } => { + let constant = c.const_.try_eval_bits(tcx, typing_env); + if let Some(constant) = constant { + let target = targets.target_for_value(constant); + TerminatorKind::Goto { target } + } else { + continue; + } + } + TerminatorKind::Assert { + target, cond: Operand::Constant(ref c), expected, .. + } => match c.const_.try_eval_bool(tcx, typing_env) { + Some(v) if v == expected => TerminatorKind::Goto { target }, + _ => continue, + }, + _ => continue, + }; + } + } + + fn is_required(&self) -> bool { + false + } +} diff --git a/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs new file mode 100644 index 00000000000..bd008230731 --- /dev/null +++ b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs @@ -0,0 +1,249 @@ +use std::iter; + +use rustc_middle::bug; +use rustc_middle::mir::interpret::Scalar; +use rustc_middle::mir::{ + BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement, StatementKind, SwitchTargets, + TerminatorKind, +}; +use rustc_middle::ty::{Ty, TyCtxt}; +use tracing::trace; + +/// Pass to convert `if` conditions on integrals into switches on the integral. 
+/// For an example, it turns something like +/// +/// ```ignore (MIR) +/// _3 = Eq(move _4, const 43i32); +/// StorageDead(_4); +/// switchInt(_3) -> [false: bb2, otherwise: bb3]; +/// ``` +/// +/// into: +/// +/// ```ignore (MIR) +/// switchInt(_4) -> [43i32: bb3, otherwise: bb2]; +/// ``` +pub(super) struct SimplifyComparisonIntegral; + +impl<'tcx> crate::MirPass<'tcx> for SimplifyComparisonIntegral { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("Running SimplifyComparisonIntegral on {:?}", body.source); + + let helper = OptimizationFinder { body }; + let opts = helper.find_optimizations(); + let mut storage_deads_to_insert = vec![]; + let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![]; + let typing_env = body.typing_env(tcx); + for opt in opts { + trace!("SUCCESS: Applying {:?}", opt); + // replace terminator with a switchInt that switches on the integer directly + let bbs = &mut body.basic_blocks_mut(); + let bb = &mut bbs[opt.bb_idx]; + let new_value = match opt.branch_value_scalar { + Scalar::Int(int) => { + let layout = tcx + .layout_of(typing_env.as_query_input(opt.branch_value_ty)) + .expect("if we have an evaluated constant we must know the layout"); + int.to_bits(layout.size) + } + Scalar::Ptr(..) => continue, + }; + const FALSE: u128 = 0; + + let mut new_targets = opt.targets; + let first_value = new_targets.iter().next().unwrap().0; + let first_is_false_target = first_value == FALSE; + match opt.op { + BinOp::Eq => { + // if the assignment was Eq we want the true case to be first + if first_is_false_target { + new_targets.all_targets_mut().swap(0, 1); + } + } + BinOp::Ne => { + // if the assignment was Ne we want the false case to be first + if !first_is_false_target { + new_targets.all_targets_mut().swap(0, 1); + } + } + _ => unreachable!(), + } + + // delete comparison statement if it the value being switched on was moved, which means + // it can not be user later on + if opt.can_remove_bin_op_stmt { + bb.statements[opt.bin_op_stmt_idx].make_nop(); + } else { + // if the integer being compared to a const integral is being moved into the + // comparison, e.g `_2 = Eq(move _3, const 'x');` + // we want to avoid making a double move later on in the switchInt on _3. + // So to avoid `switchInt(move _3) -> ['x': bb2, otherwise: bb1];`, + // we convert the move in the comparison statement to a copy. 
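+                // A sketch of the rewrite (illustrative, not from the original comment):
+                //   _2 = Eq(move _3, const 'x');  ==>  _2 = Eq(copy _3, const 'x');
+                // so the new `switchInt(move _3)` terminator is the only move of `_3`.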
+ + // unwrap is safe as we know this statement is an assign + let (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap(); + + use Operand::*; + match rhs { + Rvalue::BinaryOp(_, box (left @ Move(_), Constant(_))) => { + *left = Copy(opt.to_switch_on); + } + Rvalue::BinaryOp(_, box (Constant(_), right @ Move(_))) => { + *right = Copy(opt.to_switch_on); + } + _ => (), + } + } + + let terminator = bb.terminator(); + + // remove StorageDead (if it exists) being used in the assign of the comparison + for (stmt_idx, stmt) in bb.statements.iter().enumerate() { + if !matches!( + stmt.kind, + StatementKind::StorageDead(local) if local == opt.to_switch_on.local + ) { + continue; + } + storage_deads_to_remove.push((stmt_idx, opt.bb_idx)); + // if we have StorageDeads to remove then make sure to insert them at the top of + // each target + for bb_idx in new_targets.all_targets() { + storage_deads_to_insert.push(( + *bb_idx, + Statement { + source_info: terminator.source_info, + kind: StatementKind::StorageDead(opt.to_switch_on.local), + }, + )); + } + } + + let [bb_cond, bb_otherwise] = match new_targets.all_targets() { + [a, b] => [*a, *b], + e => bug!("expected 2 switch targets, got: {:?}", e), + }; + + let targets = SwitchTargets::new(iter::once((new_value, bb_cond)), bb_otherwise); + + let terminator = bb.terminator_mut(); + terminator.kind = + TerminatorKind::SwitchInt { discr: Operand::Move(opt.to_switch_on), targets }; + } + + for (idx, bb_idx) in storage_deads_to_remove { + body.basic_blocks_mut()[bb_idx].statements[idx].make_nop(); + } + + for (idx, stmt) in storage_deads_to_insert { + body.basic_blocks_mut()[idx].statements.insert(0, stmt); + } + } + + fn is_required(&self) -> bool { + false + } +} + +struct OptimizationFinder<'a, 'tcx> { + body: &'a Body<'tcx>, +} + +impl<'tcx> OptimizationFinder<'_, 'tcx> { + fn find_optimizations(&self) -> Vec<OptimizationInfo<'tcx>> { + self.body + .basic_blocks + .iter_enumerated() + .filter_map(|(bb_idx, bb)| { + // find switch + let (place_switched_on, targets, place_switched_on_moved) = + match &bb.terminator().kind { + rustc_middle::mir::TerminatorKind::SwitchInt { discr, targets, .. } => { + Some((discr.place()?, targets, discr.is_move())) + } + _ => None, + }?; + + // find the statement that assigns the place being switched on + bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| { + match &stmt.kind { + rustc_middle::mir::StatementKind::Assign(box (lhs, rhs)) + if *lhs == place_switched_on => + { + match rhs { + Rvalue::BinaryOp( + op @ (BinOp::Eq | BinOp::Ne), + box (left, right), + ) => { + let (branch_value_scalar, branch_value_ty, to_switch_on) = + find_branch_value_info(left, right)?; + + Some(OptimizationInfo { + bin_op_stmt_idx: stmt_idx, + bb_idx, + can_remove_bin_op_stmt: place_switched_on_moved, + to_switch_on, + branch_value_scalar, + branch_value_ty, + op: *op, + targets: targets.clone(), + }) + } + _ => None, + } + } + _ => None, + } + }) + }) + .collect() + } +} + +fn find_branch_value_info<'tcx>( + left: &Operand<'tcx>, + right: &Operand<'tcx>, +) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> { + // check that either left or right is a constant. 
+ // if any are, we can use the other to switch on, and the constant as a value in a switch + use Operand::*; + match (left, right) { + (Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on)) + | (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => { + let branch_value_ty = branch_value.const_.ty(); + // we only want to apply this optimization if we are matching on integrals (and chars), + // as it is not possible to switch on floats + if !branch_value_ty.is_integral() && !branch_value_ty.is_char() { + return None; + }; + let branch_value_scalar = branch_value.const_.try_to_scalar()?; + Some((branch_value_scalar, branch_value_ty, *to_switch_on)) + } + _ => None, + } +} + +#[derive(Debug)] +struct OptimizationInfo<'tcx> { + /// Basic block to apply the optimization + bb_idx: BasicBlock, + /// Statement index of Eq/Ne assignment that can be removed. None if the assignment can not be + /// removed - i.e the statement is used later on + bin_op_stmt_idx: usize, + /// Can remove Eq/Ne assignment + can_remove_bin_op_stmt: bool, + /// Place that needs to be switched on. This place is of type integral + to_switch_on: Place<'tcx>, + /// Constant to use in switch target value + branch_value_scalar: Scalar, + /// Type of the constant value + branch_value_ty: Ty<'tcx>, + /// Either Eq or Ne + op: BinOp, + /// Current targets used in the switch + targets: SwitchTargets, +} diff --git a/compiler/rustc_mir_transform/src/single_use_consts.rs b/compiler/rustc_mir_transform/src/single_use_consts.rs new file mode 100644 index 00000000000..02caa92ad3f --- /dev/null +++ b/compiler/rustc_mir_transform/src/single_use_consts.rs @@ -0,0 +1,205 @@ +use rustc_index::IndexVec; +use rustc_index::bit_set::DenseBitSet; +use rustc_middle::bug; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +/// Various parts of MIR building introduce temporaries that are commonly not needed. +/// +/// Notably, `if CONST` and `match CONST` end up being used-once temporaries, which +/// obfuscates the structure for other passes and codegen, which would like to always +/// be able to just see the constant directly. +/// +/// At higher optimization levels fancier passes like GVN will take care of this +/// in a more general fashion, but this handles the easy cases so can run in debug. +/// +/// This only removes constants with a single-use because re-evaluating constants +/// isn't always an improvement, especially for large ones. +/// +/// It also removes *never*-used constants, since it had all the information +/// needed to do that too, including updating the debug info. 
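+///
+/// A hypothetical before/after sketch (added for illustration, not part of the original docs):
+///
+/// ```ignore (MIR)
+/// _2 = const true;
+/// switchInt(move _2) -> [0: bb2, otherwise: bb1];
+/// ```
+///
+/// becomes
+///
+/// ```ignore (MIR)
+/// switchInt(const true) -> [0: bb2, otherwise: bb1];
+/// ```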
+pub(super) struct SingleUseConsts; + +impl<'tcx> crate::MirPass<'tcx> for SingleUseConsts { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let mut finder = SingleUseConstsFinder { + ineligible_locals: DenseBitSet::new_empty(body.local_decls.len()), + locations: IndexVec::from_elem(LocationPair::new(), &body.local_decls), + locals_in_debug_info: DenseBitSet::new_empty(body.local_decls.len()), + }; + + finder.ineligible_locals.insert_range(..=Local::from_usize(body.arg_count)); + + finder.visit_body(body); + + for (local, locations) in finder.locations.iter_enumerated() { + if finder.ineligible_locals.contains(local) { + continue; + } + + let Some(init_loc) = locations.init_loc else { + continue; + }; + + // We're only changing an operand, not the terminator kinds or successors + let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); + let init_statement_kind = std::mem::replace( + &mut basic_blocks[init_loc.block].statements[init_loc.statement_index].kind, + StatementKind::Nop, + ); + let StatementKind::Assign(place_and_rvalue) = init_statement_kind else { + bug!("No longer an assign?"); + }; + let (place, rvalue) = *place_and_rvalue; + assert_eq!(place.as_local(), Some(local)); + let Rvalue::Use(operand) = rvalue else { bug!("No longer a use?") }; + + let mut replacer = LocalReplacer { tcx, local, operand: Some(operand) }; + + if finder.locals_in_debug_info.contains(local) { + for var_debug_info in &mut body.var_debug_info { + replacer.visit_var_debug_info(var_debug_info); + } + } + + let Some(use_loc) = locations.use_loc else { continue }; + + let use_block = &mut basic_blocks[use_loc.block]; + if let Some(use_statement) = use_block.statements.get_mut(use_loc.statement_index) { + replacer.visit_statement(use_statement, use_loc); + } else { + replacer.visit_terminator(use_block.terminator_mut(), use_loc); + } + + if replacer.operand.is_some() { + bug!( + "operand wasn't used replacing local {local:?} with locations {locations:?} in body {body:#?}" + ); + } + } + } + + fn is_required(&self) -> bool { + true + } +} + +#[derive(Copy, Clone, Debug)] +struct LocationPair { + init_loc: Option<Location>, + use_loc: Option<Location>, +} + +impl LocationPair { + fn new() -> Self { + Self { init_loc: None, use_loc: None } + } +} + +struct SingleUseConstsFinder { + ineligible_locals: DenseBitSet<Local>, + locations: IndexVec<Local, LocationPair>, + locals_in_debug_info: DenseBitSet<Local>, +} + +impl<'tcx> Visitor<'tcx> for SingleUseConstsFinder { + fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) { + if let Some(local) = place.as_local() + && let Rvalue::Use(operand) = rvalue + && let Operand::Constant(_) = operand + { + let locations = &mut self.locations[local]; + if locations.init_loc.is_some() { + self.ineligible_locals.insert(local); + } else { + locations.init_loc = Some(location); + } + } else { + self.super_assign(place, rvalue, location); + } + } + + fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { + if let Some(place) = operand.place() + && let Some(local) = place.as_local() + { + let locations = &mut self.locations[local]; + if locations.use_loc.is_some() { + self.ineligible_locals.insert(local); + } else { + locations.use_loc = Some(location); + } + } else { + self.super_operand(operand, location); + } + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match 
&statement.kind { + // Storage markers are irrelevant to this. + StatementKind::StorageLive(_) | StatementKind::StorageDead(_) => {} + _ => self.super_statement(statement, location), + } + } + + fn visit_var_debug_info(&mut self, var_debug_info: &VarDebugInfo<'tcx>) { + if let VarDebugInfoContents::Place(place) = &var_debug_info.value + && let Some(local) = place.as_local() + { + self.locals_in_debug_info.insert(local); + } else { + self.super_var_debug_info(var_debug_info); + } + } + + fn visit_local(&mut self, local: Local, _context: PlaceContext, _location: Location) { + // If there's any path that gets here, rather than being understood elsewhere, + // then we'd better not do anything with this local. + self.ineligible_locals.insert(local); + } +} + +struct LocalReplacer<'tcx> { + tcx: TyCtxt<'tcx>, + local: Local, + operand: Option<Operand<'tcx>>, +} + +impl<'tcx> MutVisitor<'tcx> for LocalReplacer<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, _location: Location) { + if let Operand::Copy(place) | Operand::Move(place) = operand + && let Some(local) = place.as_local() + && local == self.local + { + *operand = self.operand.take().unwrap_or_else(|| { + bug!("there was a second use of the operand"); + }); + } + } + + fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) { + if let VarDebugInfoContents::Place(place) = &var_debug_info.value + && let Some(local) = place.as_local() + && local == self.local + { + let const_op = *self + .operand + .as_ref() + .unwrap_or_else(|| { + bug!("the operand was already stolen"); + }) + .constant() + .unwrap(); + var_debug_info.value = VarDebugInfoContents::Const(const_op); + } + } +} diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs new file mode 100644 index 00000000000..7c6ccc89c4f --- /dev/null +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -0,0 +1,461 @@ +use rustc_abi::{FIRST_VARIANT, FieldIdx}; +use rustc_data_structures::flat_map_in_place::FlatMapInPlace; +use rustc_hir::LangItem; +use rustc_index::IndexVec; +use rustc_index::bit_set::{DenseBitSet, GrowableBitSet}; +use rustc_middle::bug; +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields}; +use tracing::{debug, instrument}; + +use crate::patch::MirPatch; + +pub(super) struct ScalarReplacementOfAggregates; + +impl<'tcx> crate::MirPass<'tcx> for ScalarReplacementOfAggregates { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 2 + } + + #[instrument(level = "debug", skip(self, tcx, body))] + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!(def_id = ?body.source.def_id()); + + // Avoid query cycles (coroutines require optimized MIR for layout). 
+ if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() { + return; + } + + let mut excluded = excluded_locals(body); + let typing_env = body.typing_env(tcx); + loop { + debug!(?excluded); + let escaping = escaping_locals(tcx, typing_env, &excluded, body); + debug!(?escaping); + let replacements = compute_flattening(tcx, typing_env, body, escaping); + debug!(?replacements); + let all_dead_locals = replace_flattened_locals(tcx, body, replacements); + if !all_dead_locals.is_empty() { + excluded.union(&all_dead_locals); + excluded = { + let mut growable = GrowableBitSet::from(excluded); + growable.ensure(body.local_decls.len()); + growable.into() + }; + } else { + break; + } + } + } + + fn is_required(&self) -> bool { + false + } +} + +/// Identify all locals that are not eligible for SROA. +/// +/// There are 3 cases: +/// - the aggregated local is used or passed to other code (function parameters and arguments); +/// - the locals is a union or an enum; +/// - the local's address is taken, and thus the relative addresses of the fields are observable to +/// client code. +fn escaping_locals<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + excluded: &DenseBitSet<Local>, + body: &Body<'tcx>, +) -> DenseBitSet<Local> { + let is_excluded_ty = |ty: Ty<'tcx>| { + if ty.is_union() || ty.is_enum() { + return true; + } + if let ty::Adt(def, _args) = ty.kind() { + if def.repr().simd() { + // Exclude #[repr(simd)] types so that they are not de-optimized into an array + return true; + } + if tcx.is_lang_item(def.did(), LangItem::DynMetadata) { + // codegen wants to see the `DynMetadata<T>`, + // not the inner reference-to-opaque-type. + return true; + } + // We already excluded unions and enums, so this ADT must have one variant + let variant = def.variant(FIRST_VARIANT); + if variant.fields.len() > 1 { + // If this has more than one field, it cannot be a wrapper that only provides a + // niche, so we do not want to automatically exclude it. + return false; + } + let Ok(layout) = tcx.layout_of(typing_env.as_query_input(ty)) else { + // We can't get the layout + return true; + }; + if layout.layout.largest_niche().is_some() { + // This type has a niche + return true; + } + } + // Default for non-ADTs + false + }; + + let mut set = DenseBitSet::new_empty(body.local_decls.len()); + set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count)); + for (local, decl) in body.local_decls().iter_enumerated() { + if excluded.contains(local) || is_excluded_ty(decl.ty) { + set.insert(local); + } + } + let mut visitor = EscapeVisitor { set }; + visitor.visit_body(body); + return visitor.set; + + struct EscapeVisitor { + set: DenseBitSet<Local>, + } + + impl<'tcx> Visitor<'tcx> for EscapeVisitor { + fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) { + self.set.insert(local); + } + + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) { + // Mirror the implementation in PreFlattenVisitor. + if let &[PlaceElem::Field(..), ..] = &place.projection[..] { + return; + } + self.super_place(place, context, location); + } + + fn visit_assign( + &mut self, + lvalue: &Place<'tcx>, + rvalue: &Rvalue<'tcx>, + location: Location, + ) { + if lvalue.as_local().is_some() { + match rvalue { + // Aggregate assignments are expanded in run_pass. + Rvalue::Aggregate(..) | Rvalue::Use(..) 
=> { + self.visit_rvalue(rvalue, location); + return; + } + _ => {} + } + } + self.super_assign(lvalue, rvalue, location) + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match statement.kind { + // Storage statements are expanded in run_pass. + StatementKind::StorageLive(..) + | StatementKind::StorageDead(..) + | StatementKind::Deinit(..) => return, + _ => self.super_statement(statement, location), + } + } + + // We ignore anything that happens in debuginfo, since we expand it using + // `VarDebugInfoFragment`. + fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {} + } +} + +#[derive(Default, Debug)] +struct ReplacementMap<'tcx> { + /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage + /// and deinit statement and debuginfo. + fragments: IndexVec<Local, Option<IndexVec<FieldIdx, Option<(Ty<'tcx>, Local)>>>>, +} + +impl<'tcx> ReplacementMap<'tcx> { + fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> { + let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { + return None; + }; + let fields = self.fragments[place.local].as_ref()?; + let (_, new_local) = fields[f]?; + Some(Place { local: new_local, projection: tcx.mk_place_elems(rest) }) + } + + fn place_fragments( + &self, + place: Place<'tcx>, + ) -> Option<impl Iterator<Item = (FieldIdx, Ty<'tcx>, Local)>> { + let local = place.as_local()?; + let fields = self.fragments[local].as_ref()?; + Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| { + let (ty, local) = opt_ty_local?; + Some((field, ty, local)) + })) + } +} + +/// Compute the replacement of flattened places into locals. +/// +/// For each eligible place, we assign a new local to each accessed field. +/// The replacement will be done later in `ReplacementVisitor`. +fn compute_flattening<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + body: &mut Body<'tcx>, + escaping: DenseBitSet<Local>, +) -> ReplacementMap<'tcx> { + let mut fragments = IndexVec::from_elem(None, &body.local_decls); + + for local in body.local_decls.indices() { + if escaping.contains(local) { + continue; + } + let decl = body.local_decls[local].clone(); + let ty = decl.ty; + iter_fields(ty, tcx, typing_env, |variant, field, field_ty| { + if variant.is_some() { + // Downcasts are currently not supported. + return; + }; + let new_local = + body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() }); + fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local)); + }); + } + ReplacementMap { fragments } +} + +/// Perform the replacement computed by `compute_flattening`. 
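+///
+/// Illustrative sketch (hypothetical MIR, not from the original docs): if `_1: (u8, u16)` is
+/// replaced by the fragments `_4: u8` and `_5: u16`, an assignment `_1 = (move _2, move _3)`
+/// is expanded into `_4 = move _2; _5 = move _3`, and later uses of `_1.0` / `_1.1` become
+/// uses of `_4` / `_5`.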
+fn replace_flattened_locals<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + replacements: ReplacementMap<'tcx>, +) -> DenseBitSet<Local> { + let mut all_dead_locals = DenseBitSet::new_empty(replacements.fragments.len()); + for (local, replacements) in replacements.fragments.iter_enumerated() { + if replacements.is_some() { + all_dead_locals.insert(local); + } + } + debug!(?all_dead_locals); + if all_dead_locals.is_empty() { + return all_dead_locals; + } + + let mut visitor = ReplacementVisitor { + tcx, + local_decls: &body.local_decls, + replacements: &replacements, + all_dead_locals, + patch: MirPatch::new(body), + }; + for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { + visitor.visit_basic_block_data(bb, data); + } + for scope in &mut body.source_scopes { + visitor.visit_source_scope_data(scope); + } + for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() { + visitor.visit_user_type_annotation(index, annotation); + } + visitor.expand_var_debug_info(&mut body.var_debug_info); + let ReplacementVisitor { patch, all_dead_locals, .. } = visitor; + patch.apply(body); + all_dead_locals +} + +struct ReplacementVisitor<'tcx, 'll> { + tcx: TyCtxt<'tcx>, + /// This is only used to compute the type for `VarDebugInfoFragment`. + local_decls: &'ll LocalDecls<'tcx>, + /// Work to do. + replacements: &'ll ReplacementMap<'tcx>, + /// This is used to check that we are not leaving references to replaced locals behind. + all_dead_locals: DenseBitSet<Local>, + patch: MirPatch<'tcx>, +} + +impl<'tcx> ReplacementVisitor<'tcx, '_> { + #[instrument(level = "trace", skip(self))] + fn expand_var_debug_info(&mut self, var_debug_info: &mut Vec<VarDebugInfo<'tcx>>) { + var_debug_info.flat_map_in_place(|mut var_debug_info| { + let place = match var_debug_info.value { + VarDebugInfoContents::Const(_) => return vec![var_debug_info], + VarDebugInfoContents::Place(ref mut place) => place, + }; + + if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) { + *place = repl; + return vec![var_debug_info]; + } + + let Some(parts) = self.replacements.place_fragments(*place) else { + return vec![var_debug_info]; + }; + + let ty = place.ty(self.local_decls, self.tcx).ty; + + parts + .map(|(field, field_ty, replacement_local)| { + let mut var_debug_info = var_debug_info.clone(); + let composite = var_debug_info.composite.get_or_insert_with(|| { + Box::new(VarDebugInfoFragment { ty, projection: Vec::new() }) + }); + composite.projection.push(PlaceElem::Field(field, field_ty)); + + var_debug_info.value = VarDebugInfoContents::Place(replacement_local.into()); + var_debug_info + }) + .collect() + }); + } +} + +impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { + if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) { + *place = repl + } else { + self.super_place(place, context, location) + } + } + + #[instrument(level = "trace", skip(self))] + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + match statement.kind { + // Duplicate storage and deinit statements, as they pretty much apply to all fields. 
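+            // For instance (illustrative): if `_1` was split into `_4` and `_5`, then
+            // `StorageLive(_1)` is replaced by `StorageLive(_4); StorageLive(_5)` and the
+            // original statement is turned into a nop.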
+ StatementKind::StorageLive(l) => { + if let Some(final_locals) = self.replacements.place_fragments(l.into()) { + for (_, _, fl) in final_locals { + self.patch.add_statement(location, StatementKind::StorageLive(fl)); + } + statement.make_nop(); + } + return; + } + StatementKind::StorageDead(l) => { + if let Some(final_locals) = self.replacements.place_fragments(l.into()) { + for (_, _, fl) in final_locals { + self.patch.add_statement(location, StatementKind::StorageDead(fl)); + } + statement.make_nop(); + } + return; + } + StatementKind::Deinit(box place) => { + if let Some(final_locals) = self.replacements.place_fragments(place) { + for (_, _, fl) in final_locals { + self.patch + .add_statement(location, StatementKind::Deinit(Box::new(fl.into()))); + } + statement.make_nop(); + return; + } + } + + // We have `a = Struct { 0: x, 1: y, .. }`. + // We replace it by + // ``` + // a_0 = x + // a_1 = y + // ... + // ``` + StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => { + if let Some(local) = place.as_local() + && let Some(final_locals) = &self.replacements.fragments[local] + { + // This is ok as we delete the statement later. + let operands = std::mem::take(operands); + for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) { + if let Some((_, new_local)) = opt_ty_local { + // Replace mentions of SROA'd locals that appear in the operand. + self.visit_operand(&mut operand, location); + + let rvalue = Rvalue::Use(operand); + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + } + statement.make_nop(); + return; + } + } + + // We have `a = some constant` + // We add the projections. + // ``` + // a_0 = a.0 + // a_1 = a.1 + // ... + // ``` + // ConstProp will pick up the pieces and replace them by actual constants. + StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => { + if let Some(final_locals) = self.replacements.place_fragments(place) { + // Put the deaggregated statements *after* the original one. + let location = location.successor_within_block(); + for (field, ty, new_local) in final_locals { + let rplace = self.tcx.mk_place_field(place, field, ty); + let rvalue = Rvalue::Use(Operand::Move(rplace)); + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + // We still need `place.local` to exist, so don't make it nop. + return; + } + } + + // We have `a = move? place` + // We replace it by + // ``` + // a_0 = move? place.0 + // a_1 = move? place.1 + // ... 
+ // ``` + StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => { + let (rplace, copy) = match *op { + Operand::Copy(rplace) => (rplace, true), + Operand::Move(rplace) => (rplace, false), + Operand::Constant(_) => bug!(), + }; + if let Some(final_locals) = self.replacements.place_fragments(lhs) { + for (field, ty, new_local) in final_locals { + let rplace = self.tcx.mk_place_field(rplace, field, ty); + debug!(?rplace); + let rplace = self + .replacements + .replace_place(self.tcx, rplace.as_ref()) + .unwrap_or(rplace); + debug!(?rplace); + let rvalue = if copy { + Rvalue::Use(Operand::Copy(rplace)) + } else { + Rvalue::Use(Operand::Move(rplace)) + }; + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + statement.make_nop(); + return; + } + } + + _ => {} + } + self.super_statement(statement, location) + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { + assert!(!self.all_dead_locals.contains(*local)); + } +} diff --git a/compiler/rustc_mir_transform/src/ssa.rs b/compiler/rustc_mir_transform/src/ssa.rs new file mode 100644 index 00000000000..03b6f9b7ff3 --- /dev/null +++ b/compiler/rustc_mir_transform/src/ssa.rs @@ -0,0 +1,407 @@ +//! We denote as "SSA" the set of locals that verify the following properties: +//! 1/ They are only assigned-to once, either as a function parameter, or in an assign statement; +//! 2/ This single assignment dominates all uses; +//! +//! As we do not track indirect assignments, a local that has its address taken (via a borrow or raw +//! borrow operator) is considered non-SSA. However, it is UB to modify through an immutable borrow +//! of a `Freeze` local. Those can still be considered to be SSA. + +use rustc_data_structures::graph::dominators::Dominators; +use rustc_index::bit_set::DenseBitSet; +use rustc_index::{IndexSlice, IndexVec}; +use rustc_middle::bug; +use rustc_middle::middle::resolve_bound_vars::Set1; +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, TyCtxt}; +use tracing::{debug, instrument, trace}; + +pub(super) struct SsaLocals { + /// Assignments to each local. This defines whether the local is SSA. + assignments: IndexVec<Local, Set1<DefLocation>>, + /// We visit the body in reverse postorder, to ensure each local is assigned before it is used. + /// We remember the order in which we saw the assignments to compute the SSA values in a single + /// pass. + assignment_order: Vec<Local>, + /// Copy equivalence classes between locals. See `copy_classes` for documentation. + copy_classes: IndexVec<Local, Local>, + /// Number of "direct" uses of each local, ie. uses that are not dereferences. + /// We ignore non-uses (Storage statements, debuginfo). + direct_uses: IndexVec<Local, u32>, + /// Set of SSA locals that are immutably borrowed. 
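+    /// Only `Freeze` locals may keep their SSA status once borrowed: `SsaLocals::new` later
+    /// demotes any borrowed local whose type is not `Freeze` (e.g. one containing an
+    /// `UnsafeCell`) to `Set1::Many`.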
+ borrowed_locals: DenseBitSet<Local>, +} + +impl SsaLocals { + pub(super) fn new<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + ) -> SsaLocals { + let assignment_order = Vec::with_capacity(body.local_decls.len()); + + let assignments = IndexVec::from_elem(Set1::Empty, &body.local_decls); + let dominators = body.basic_blocks.dominators(); + + let direct_uses = IndexVec::from_elem(0, &body.local_decls); + let borrowed_locals = DenseBitSet::new_empty(body.local_decls.len()); + let mut visitor = SsaVisitor { + body, + assignments, + assignment_order, + dominators, + direct_uses, + borrowed_locals, + }; + + for local in body.args_iter() { + visitor.assignments[local] = Set1::One(DefLocation::Argument); + visitor.assignment_order.push(local); + } + + // For SSA assignments, a RPO visit will see the assignment before it sees any use. + // We only visit reachable nodes: computing `dominates` on an unreachable node ICEs. + for (bb, data) in traversal::reverse_postorder(body) { + visitor.visit_basic_block_data(bb, data); + } + + for var_debug_info in &body.var_debug_info { + visitor.visit_var_debug_info(var_debug_info); + } + + // The immutability of shared borrows only works on `Freeze` locals. If the visitor found + // borrows, we need to check the types. For raw pointers and mutable borrows, the locals + // have already been marked as non-SSA. + debug!(?visitor.borrowed_locals); + for local in visitor.borrowed_locals.iter() { + if !body.local_decls[local].ty.is_freeze(tcx, typing_env) { + visitor.assignments[local] = Set1::Many; + } + } + + debug!(?visitor.assignments); + debug!(?visitor.direct_uses); + + visitor + .assignment_order + .retain(|&local| matches!(visitor.assignments[local], Set1::One(_))); + debug!(?visitor.assignment_order); + + let mut ssa = SsaLocals { + assignments: visitor.assignments, + assignment_order: visitor.assignment_order, + direct_uses: visitor.direct_uses, + borrowed_locals: visitor.borrowed_locals, + // This is filled by `compute_copy_classes`. + copy_classes: IndexVec::default(), + }; + compute_copy_classes(&mut ssa, body); + ssa + } + + pub(super) fn num_locals(&self) -> usize { + self.assignments.len() + } + + pub(super) fn locals(&self) -> impl Iterator<Item = Local> { + self.assignments.indices() + } + + pub(super) fn is_ssa(&self, local: Local) -> bool { + matches!(self.assignments[local], Set1::One(_)) + } + + /// Return the number of uses if a local that are not "Deref". + pub(super) fn num_direct_uses(&self, local: Local) -> u32 { + self.direct_uses[local] + } + + #[inline] + pub(super) fn assignment_dominates( + &self, + dominators: &Dominators<BasicBlock>, + local: Local, + location: Location, + ) -> bool { + match self.assignments[local] { + Set1::One(def) => def.dominates(location, dominators), + _ => false, + } + } + + pub(super) fn assignments<'a, 'tcx>( + &'a self, + body: &'a Body<'tcx>, + ) -> impl Iterator<Item = (Local, &'a Rvalue<'tcx>, Location)> { + self.assignment_order.iter().filter_map(|&local| { + if let Set1::One(DefLocation::Assignment(loc)) = self.assignments[local] { + let stmt = body.stmt_at(loc).left()?; + // `loc` must point to a direct assignment to `local`. + let Some((target, rvalue)) = stmt.kind.as_assign() else { bug!() }; + assert_eq!(target.as_local(), Some(local)); + Some((local, rvalue, loc)) + } else { + None + } + }) + } + + /// Compute the equivalence classes for locals, based on copy statements. + /// + /// The returned vector maps each local to the one it copies. 
In the following case: + /// _a = &mut _0 + /// _b = move? _a + /// _c = move? _a + /// _d = move? _c + /// We return the mapping + /// _a => _a // not a copy so, represented by itself + /// _b => _a + /// _c => _a + /// _d => _a // transitively through _c + /// + /// Exception: we do not see through the return place, as it cannot be instantiated. + pub(super) fn copy_classes(&self) -> &IndexSlice<Local, Local> { + &self.copy_classes + } + + /// Set of SSA locals that are immutably borrowed. + pub(super) fn borrowed_locals(&self) -> &DenseBitSet<Local> { + &self.borrowed_locals + } + + /// Make a property uniform on a copy equivalence class by removing elements. + pub(super) fn meet_copy_equivalence(&self, property: &mut DenseBitSet<Local>) { + // Consolidate to have a local iff all its copies are. + // + // `copy_classes` defines equivalence classes between locals. The `local`s that recursively + // move/copy the same local all have the same `head`. + for (local, &head) in self.copy_classes.iter_enumerated() { + // If any copy does not have `property`, then the head is not. + if !property.contains(local) { + property.remove(head); + } + } + for (local, &head) in self.copy_classes.iter_enumerated() { + // If any copy does not have `property`, then the head doesn't either, + // then no copy has `property`. + if !property.contains(head) { + property.remove(local); + } + } + + // Verify that we correctly computed equivalence classes. + #[cfg(debug_assertions)] + for (local, &head) in self.copy_classes.iter_enumerated() { + assert_eq!(property.contains(local), property.contains(head)); + } + } +} + +struct SsaVisitor<'a, 'tcx> { + body: &'a Body<'tcx>, + dominators: &'a Dominators<BasicBlock>, + assignments: IndexVec<Local, Set1<DefLocation>>, + assignment_order: Vec<Local>, + direct_uses: IndexVec<Local, u32>, + // Track locals that are immutably borrowed, so we can check their type is `Freeze` later. + borrowed_locals: DenseBitSet<Local>, +} + +impl SsaVisitor<'_, '_> { + fn check_dominates(&mut self, local: Local, loc: Location) { + let set = &mut self.assignments[local]; + let assign_dominates = match *set { + Set1::Empty | Set1::Many => false, + Set1::One(def) => def.dominates(loc, self.dominators), + }; + // We are visiting a use that is not dominated by an assignment. + // Either there is a cycle involved, or we are reading for uninitialized local. + // Bail out. + if !assign_dominates { + *set = Set1::Many; + } + } +} + +impl<'tcx> Visitor<'tcx> for SsaVisitor<'_, 'tcx> { + fn visit_local(&mut self, local: Local, ctxt: PlaceContext, loc: Location) { + match ctxt { + PlaceContext::MutatingUse(MutatingUseContext::Projection) + | PlaceContext::NonMutatingUse(NonMutatingUseContext::Projection) => bug!(), + // Anything can happen with raw pointers, so remove them. + PlaceContext::NonMutatingUse(NonMutatingUseContext::RawBorrow) + | PlaceContext::MutatingUse(_) => { + self.assignments[local] = Set1::Many; + } + // Immutable borrows are ok, but we need to delay a check that the type is `Freeze`. 
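+            // A shared borrow of a non-`Freeze` local can still be written through, e.g.
+            // ```rust
+            // let x = std::cell::Cell::new(0);
+            // let r = &x;
+            // r.set(1); // mutation through a shared reference
+            // ```
+            // which is why `SsaLocals::new` re-checks the borrowed locals against `Freeze`.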
+ PlaceContext::NonMutatingUse( + NonMutatingUseContext::SharedBorrow | NonMutatingUseContext::FakeBorrow, + ) => { + self.borrowed_locals.insert(local); + self.check_dominates(local, loc); + self.direct_uses[local] += 1; + } + PlaceContext::NonMutatingUse(_) => { + self.check_dominates(local, loc); + self.direct_uses[local] += 1; + } + PlaceContext::NonUse(_) => {} + } + } + + fn visit_place(&mut self, place: &Place<'tcx>, ctxt: PlaceContext, loc: Location) { + let location = match ctxt { + PlaceContext::MutatingUse(MutatingUseContext::Store) => { + Some(DefLocation::Assignment(loc)) + } + PlaceContext::MutatingUse(MutatingUseContext::Call) => { + let call = loc.block; + let TerminatorKind::Call { target, .. } = + self.body.basic_blocks[call].terminator().kind + else { + bug!() + }; + Some(DefLocation::CallReturn { call, target }) + } + _ => None, + }; + if let Some(location) = location + && let Some(local) = place.as_local() + { + self.assignments[local].insert(location); + if let Set1::One(_) = self.assignments[local] { + // Only record if SSA-like, to avoid growing the vector needlessly. + self.assignment_order.push(local); + } + } else if place.projection.first() == Some(&PlaceElem::Deref) { + // Do not do anything for debuginfo. + if ctxt.is_use() { + // Only change the context if it is a real use, not a "use" in debuginfo. + let new_ctxt = PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy); + + self.visit_projection(place.as_ref(), new_ctxt, loc); + self.check_dominates(place.local, loc); + } + } else { + self.visit_projection(place.as_ref(), ctxt, loc); + self.visit_local(place.local, ctxt, loc); + } + } +} + +#[instrument(level = "trace", skip(ssa, body))] +fn compute_copy_classes(ssa: &mut SsaLocals, body: &Body<'_>) { + let mut direct_uses = std::mem::take(&mut ssa.direct_uses); + let mut copies = IndexVec::from_fn_n(|l| l, body.local_decls.len()); + // We must not unify two locals that are borrowed. But this is fine if one is borrowed and + // the other is not. This bitset is keyed by *class head* and contains whether any member of + // the class is borrowed. + let mut borrowed_classes = ssa.borrowed_locals().clone(); + + for (local, rvalue, _) in ssa.assignments(body) { + let (Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) + | Rvalue::CopyForDeref(place)) = rvalue + else { + continue; + }; + + let Some(rhs) = place.as_local() else { continue }; + let local_ty = body.local_decls()[local].ty; + let rhs_ty = body.local_decls()[rhs].ty; + if local_ty != rhs_ty { + // FIXME(#112651): This can be removed afterwards. + trace!("skipped `{local:?} = {rhs:?}` due to subtyping: {local_ty} != {rhs_ty}"); + continue; + } + + if !ssa.is_ssa(rhs) { + continue; + } + + // We visit in `assignment_order`, ie. reverse post-order, so `rhs` has been + // visited before `local`, and we just have to copy the representing local. + let head = copies[rhs]; + + // Do not unify two borrowed locals. + if borrowed_classes.contains(local) && borrowed_classes.contains(head) { + continue; + } + + if local == RETURN_PLACE { + // `_0` is special, we cannot rename it. Instead, rename the class of `rhs` to + // `RETURN_PLACE`. This is only possible if the class head is a temporary, not an + // argument. 
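+            // For example (illustrative locals), given
+            // ```
+            // _3 = ...
+            // _0 = move _3
+            // ```
+            // the class headed by `_3` is renamed to `_0`, rather than mapping `_0` onto `_3`.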
+ if body.local_kind(head) != LocalKind::Temp { + continue; + } + for h in copies.iter_mut() { + if *h == head { + *h = RETURN_PLACE; + } + } + if borrowed_classes.contains(head) { + borrowed_classes.insert(RETURN_PLACE); + } + } else { + copies[local] = head; + if borrowed_classes.contains(local) { + borrowed_classes.insert(head); + } + } + direct_uses[rhs] -= 1; + } + + debug!(?copies); + debug!(?direct_uses); + debug!(?borrowed_classes); + + // Invariant: `copies` must point to the head of an equivalence class. + #[cfg(debug_assertions)] + for &head in copies.iter() { + assert_eq!(copies[head], head); + } + debug_assert_eq!(copies[RETURN_PLACE], RETURN_PLACE); + + // Invariant: `borrowed_classes` must be true if any member of the class is borrowed. + #[cfg(debug_assertions)] + for &head in copies.iter() { + let any_borrowed = ssa.borrowed_locals.iter().any(|l| copies[l] == head); + assert_eq!(borrowed_classes.contains(head), any_borrowed); + } + + ssa.direct_uses = direct_uses; + ssa.copy_classes = copies; +} + +#[derive(Debug)] +pub(crate) struct StorageLiveLocals { + /// Set of "StorageLive" statements for each local. + storage_live: IndexVec<Local, Set1<DefLocation>>, +} + +impl StorageLiveLocals { + pub(crate) fn new( + body: &Body<'_>, + always_storage_live_locals: &DenseBitSet<Local>, + ) -> StorageLiveLocals { + let mut storage_live = IndexVec::from_elem(Set1::Empty, &body.local_decls); + for local in always_storage_live_locals.iter() { + storage_live[local] = Set1::One(DefLocation::Argument); + } + for (block, bbdata) in body.basic_blocks.iter_enumerated() { + for (statement_index, statement) in bbdata.statements.iter().enumerate() { + if let StatementKind::StorageLive(local) = statement.kind { + storage_live[local] + .insert(DefLocation::Assignment(Location { block, statement_index })); + } + } + } + debug!(?storage_live); + StorageLiveLocals { storage_live } + } + + #[inline] + pub(crate) fn has_single_storage(&self, local: Local) -> bool { + matches!(self.storage_live[local], Set1::One(_)) + } +} diff --git a/compiler/rustc_mir_transform/src/strip_debuginfo.rs b/compiler/rustc_mir_transform/src/strip_debuginfo.rs new file mode 100644 index 00000000000..9ede8aa79c4 --- /dev/null +++ b/compiler/rustc_mir_transform/src/strip_debuginfo.rs @@ -0,0 +1,38 @@ +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_session::config::MirStripDebugInfo; + +/// Conditionally remove some of the VarDebugInfo in MIR. +/// +/// In particular, stripping non-parameter debug info for tiny, primitive-like +/// methods in core saves work later, and nobody ever wanted to use it anyway. +pub(super) struct StripDebugInfo; + +impl<'tcx> crate::MirPass<'tcx> for StripDebugInfo { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.opts.unstable_opts.mir_strip_debuginfo != MirStripDebugInfo::None + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + match tcx.sess.opts.unstable_opts.mir_strip_debuginfo { + MirStripDebugInfo::None => return, + MirStripDebugInfo::AllLocals => {} + MirStripDebugInfo::LocalsInTinyFunctions + if let TerminatorKind::Return { .. 
} = + body.basic_blocks[START_BLOCK].terminator().kind => {} + MirStripDebugInfo::LocalsInTinyFunctions => return, + } + + body.var_debug_info.retain(|vdi| { + matches!( + vdi.value, + VarDebugInfoContents::Place(place) + if place.local.as_usize() <= body.arg_count && place.local != RETURN_PLACE, + ) + }); + } + + fn is_required(&self) -> bool { + true + } +} diff --git a/compiler/rustc_mir_transform/src/unreachable_enum_branching.rs b/compiler/rustc_mir_transform/src/unreachable_enum_branching.rs new file mode 100644 index 00000000000..6ccec5b6f21 --- /dev/null +++ b/compiler/rustc_mir_transform/src/unreachable_enum_branching.rs @@ -0,0 +1,216 @@ +//! A pass that eliminates branches on uninhabited or unreachable enum variants. + +use rustc_abi::Variants; +use rustc_data_structures::fx::FxHashSet; +use rustc_middle::bug; +use rustc_middle::mir::{ + BasicBlock, BasicBlockData, BasicBlocks, Body, Local, Operand, Rvalue, StatementKind, + TerminatorKind, +}; +use rustc_middle::ty::layout::TyAndLayout; +use rustc_middle::ty::{Ty, TyCtxt}; +use tracing::trace; + +use crate::patch::MirPatch; + +pub(super) struct UnreachableEnumBranching; + +fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option<Local> { + if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator { + p.as_local() + } else { + None + } +} + +/// If the basic block terminates by switching on a discriminant, this returns the `Ty` the +/// discriminant is read from. Otherwise, returns None. +fn get_switched_on_type<'tcx>( + block_data: &BasicBlockData<'tcx>, + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, +) -> Option<Ty<'tcx>> { + let terminator = block_data.terminator(); + + // Only bother checking blocks which terminate by switching on a local. + let local = get_discriminant_local(&terminator.kind)?; + + let stmt_before_term = block_data.statements.last()?; + + if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind + && l.as_local() == Some(local) + { + let ty = place.ty(body, tcx).ty; + if ty.is_enum() { + return Some(ty); + } + } + + None +} + +fn variant_discriminants<'tcx>( + layout: &TyAndLayout<'tcx>, + ty: Ty<'tcx>, + tcx: TyCtxt<'tcx>, +) -> FxHashSet<u128> { + match &layout.variants { + Variants::Empty => { + // Uninhabited, no valid discriminant. + FxHashSet::default() + } + Variants::Single { index } => { + let mut res = FxHashSet::default(); + res.insert( + ty.discriminant_for_variant(tcx, *index) + .map_or(index.as_u32() as u128, |discr| discr.val), + ); + res + } + Variants::Multiple { variants, .. 
} => variants + .iter_enumerated() + .filter_map(|(idx, layout)| { + (!layout.is_uninhabited()) + .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val) + }) + .collect(), + } +} + +impl<'tcx> crate::MirPass<'tcx> for UnreachableEnumBranching { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("UnreachableEnumBranching starting for {:?}", body.source); + + let mut unreachable_targets = Vec::new(); + let mut patch = MirPatch::new(body); + + for (bb, bb_data) in body.basic_blocks.iter_enumerated() { + trace!("processing block {:?}", bb); + + if bb_data.is_cleanup { + continue; + } + + let Some(discriminant_ty) = get_switched_on_type(bb_data, tcx, body) else { continue }; + + let layout = tcx.layout_of(body.typing_env(tcx).as_query_input(discriminant_ty)); + + let mut allowed_variants = if let Ok(layout) = layout { + // Find allowed variants based on uninhabited. + variant_discriminants(&layout, discriminant_ty, tcx) + } else if let Some(variant_range) = discriminant_ty.variant_range(tcx) { + // If there are some generics, we can still get the allowed variants. + variant_range + .map(|variant| { + discriminant_ty.discriminant_for_variant(tcx, variant).unwrap().val + }) + .collect() + } else { + continue; + }; + + trace!("allowed_variants = {:?}", allowed_variants); + + unreachable_targets.clear(); + let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else { + bug!() + }; + + for (index, (val, _)) in targets.iter().enumerate() { + if !allowed_variants.remove(&val) { + unreachable_targets.push(index); + } + } + let otherwise_is_empty_unreachable = + body.basic_blocks[targets.otherwise()].is_empty_unreachable(); + fn check_successors(basic_blocks: &BasicBlocks<'_>, bb: BasicBlock) -> bool { + // After resolving https://github.com/llvm/llvm-project/issues/78578, + // We can remove this check. + // The main issue here is that `early-tailduplication` causes compile time overhead + // and potential performance problems. + // Simply put, when encounter a switch (indirect branch) statement, + // `early-tailduplication` tries to duplicate the switch branch statement with BB + // into (each) predecessors. This makes CFG very complex. + // We can understand it as it transforms the following code + // ```rust + // match a { ... many cases }; + // match b { ... many cases }; + // ``` + // into + // ```rust + // match a { ... many match b { goto BB cases } } + // ... BB cases + // ``` + // Abandon this transformation when it is possible (the best effort) + // to encounter the problem. + let mut successors = basic_blocks[bb].terminator().successors(); + let Some(first_successor) = successors.next() else { return true }; + if successors.next().is_some() { + return true; + } + if let TerminatorKind::SwitchInt { .. } = + &basic_blocks[first_successor].terminator().kind + { + return false; + }; + true + } + // If and only if there is a variant that does not have a branch set, change the + // current of otherwise as the variant branch and set otherwise to unreachable. 
It + // transforms following code + // ```rust + // match c { + // Ordering::Less => 1, + // Ordering::Equal => 2, + // _ => 3, + // } + // ``` + // to + // ```rust + // match c { + // Ordering::Less => 1, + // Ordering::Equal => 2, + // Ordering::Greater => 3, + // } + // ``` + let otherwise_is_last_variant = !otherwise_is_empty_unreachable + && allowed_variants.len() == 1 + // Despite the LLVM issue, we hope that small enum can still be transformed. + // This is valuable for both `a <= b` and `if let Some/Ok(v)`. + && (targets.all_targets().len() <= 3 + || check_successors(&body.basic_blocks, targets.otherwise())); + let replace_otherwise_to_unreachable = otherwise_is_last_variant + || (!otherwise_is_empty_unreachable && allowed_variants.is_empty()); + + if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable { + continue; + } + + let unreachable_block = patch.unreachable_no_cleanup_block(); + let mut targets = targets.clone(); + if replace_otherwise_to_unreachable { + if otherwise_is_last_variant { + // We have checked that `allowed_variants` has only one element. + #[allow(rustc::potential_query_instability)] + let last_variant = *allowed_variants.iter().next().unwrap(); + targets.add_target(last_variant, targets.otherwise()); + } + unreachable_targets.push(targets.iter().count()); + } + for index in unreachable_targets.iter() { + targets.all_targets_mut()[*index] = unreachable_block; + } + patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() }); + } + + patch.apply(body); + } + + fn is_required(&self) -> bool { + false + } +} diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs new file mode 100644 index 00000000000..13fb5b3e56f --- /dev/null +++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs @@ -0,0 +1,154 @@ +//! A pass that propagates the unreachable terminator of a block to its predecessors +//! when all of their successors are unreachable. This is achieved through a +//! post-order traversal of the blocks. + +use rustc_abi::Size; +use rustc_data_structures::fx::FxHashSet; +use rustc_middle::bug; +use rustc_middle::mir::interpret::Scalar; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, TyCtxt}; + +use crate::patch::MirPatch; + +pub(super) struct UnreachablePropagation; + +impl crate::MirPass<'_> for UnreachablePropagation { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + // Enable only under -Zmir-opt-level=2 as this can make programs less debuggable. + sess.mir_opt_level() >= 2 + } + + fn run_pass<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let mut patch = MirPatch::new(body); + let mut unreachable_blocks = FxHashSet::default(); + + for (bb, bb_data) in traversal::postorder(body) { + let terminator = bb_data.terminator(); + let is_unreachable = match &terminator.kind { + TerminatorKind::Unreachable => true, + // This will unconditionally run into an unreachable and is therefore unreachable + // as well. + TerminatorKind::Goto { target } if unreachable_blocks.contains(target) => { + patch.patch_terminator(bb, TerminatorKind::Unreachable); + true + } + // Try to remove unreachable targets from the switch. + TerminatorKind::SwitchInt { .. } => { + remove_successors_from_switch(tcx, bb, &unreachable_blocks, body, &mut patch) + } + _ => false, + }; + if is_unreachable { + unreachable_blocks.insert(bb); + } + } + + patch.apply(body); + + // We do want do keep some unreachable blocks, but make them empty. 
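+        // Executing such a block would inevitably hit its `Unreachable` terminator, so its
+        // statements can be dropped while the block itself stays a valid jump target.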
+ // The order in which we clear bb statements does not matter. + #[allow(rustc::potential_query_instability)] + for bb in unreachable_blocks { + body.basic_blocks_mut()[bb].statements.clear(); + } + } + + fn is_required(&self) -> bool { + false + } +} + +/// Return whether the current terminator is fully unreachable. +fn remove_successors_from_switch<'tcx>( + tcx: TyCtxt<'tcx>, + bb: BasicBlock, + unreachable_blocks: &FxHashSet<BasicBlock>, + body: &Body<'tcx>, + patch: &mut MirPatch<'tcx>, +) -> bool { + let terminator = body.basic_blocks[bb].terminator(); + let TerminatorKind::SwitchInt { discr, targets } = &terminator.kind else { bug!() }; + let source_info = terminator.source_info; + let location = body.terminator_loc(bb); + + let is_unreachable = |bb| unreachable_blocks.contains(&bb); + + // If there are multiple targets, we want to keep information about reachability for codegen. + // For example (see tests/codegen/match-optimizes-away.rs) + // + // pub enum Two { A, B } + // pub fn identity(x: Two) -> Two { + // match x { + // Two::A => Two::A, + // Two::B => Two::B, + // } + // } + // + // This generates a `switchInt() -> [0: 0, 1: 1, otherwise: unreachable]`, which allows us or + // LLVM to turn it into just `x` later. Without the unreachable, such a transformation would be + // illegal. + // + // In order to preserve this information, we record reachable and unreachable targets as + // `Assume` statements in MIR. + + let discr_ty = discr.ty(body, tcx); + let discr_size = Size::from_bits(match discr_ty.kind() { + ty::Uint(uint) => uint.normalize(tcx.sess.target.pointer_width).bit_width().unwrap(), + ty::Int(int) => int.normalize(tcx.sess.target.pointer_width).bit_width().unwrap(), + ty::Char => 32, + ty::Bool => 1, + other => bug!("unhandled type: {:?}", other), + }); + + let mut add_assumption = |binop, value| { + let local = patch.new_temp(tcx.types.bool, source_info.span); + let value = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::from_scalar(tcx, Scalar::from_uint(value, discr_size), discr_ty), + })); + let cmp = Rvalue::BinaryOp(binop, Box::new((discr.to_copy(), value))); + patch.add_assign(location, local.into(), cmp); + + let assume = NonDivergingIntrinsic::Assume(Operand::Move(local.into())); + patch.add_statement(location, StatementKind::Intrinsic(Box::new(assume))); + }; + + let otherwise = targets.otherwise(); + let otherwise_unreachable = is_unreachable(otherwise); + + let reachable_iter = targets.iter().filter(|&(value, bb)| { + let is_unreachable = is_unreachable(bb); + // We remove this target from the switch, so record the inequality using `Assume`. + if is_unreachable && !otherwise_unreachable { + add_assumption(BinOp::Ne, value); + } + !is_unreachable + }); + + let new_targets = SwitchTargets::new(reachable_iter, otherwise); + + let num_targets = new_targets.all_targets().len(); + let fully_unreachable = num_targets == 1 && otherwise_unreachable; + + let terminator = match (num_targets, otherwise_unreachable) { + // If all targets are unreachable, we can be unreachable as well. + (1, true) => TerminatorKind::Unreachable, + (1, false) => TerminatorKind::Goto { target: otherwise }, + (2, true) => { + // All targets are unreachable except one. Record the equality, and make it a goto. + let (value, target) = new_targets.iter().next().unwrap(); + add_assumption(BinOp::Eq, value); + TerminatorKind::Goto { target } + } + _ if num_targets == targets.all_targets().len() => { + // Nothing has changed. 
+ return false; + } + _ => TerminatorKind::SwitchInt { discr: discr.clone(), targets: new_targets }, + }; + + patch.patch_terminator(bb, terminator); + fully_unreachable +} diff --git a/compiler/rustc_mir_transform/src/validate.rs b/compiler/rustc_mir_transform/src/validate.rs new file mode 100644 index 00000000000..7dcdd7999f2 --- /dev/null +++ b/compiler/rustc_mir_transform/src/validate.rs @@ -0,0 +1,1606 @@ +//! Validates the MIR to ensure that invariants are upheld. + +use rustc_abi::{ExternAbi, FIRST_VARIANT, Size}; +use rustc_attr_data_structures::InlineAttr; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_hir::LangItem; +use rustc_index::IndexVec; +use rustc_index::bit_set::DenseBitSet; +use rustc_infer::infer::TyCtxtInferExt; +use rustc_infer::traits::{Obligation, ObligationCause}; +use rustc_middle::mir::coverage::CoverageKind; +use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::adjustment::PointerCoercion; +use rustc_middle::ty::print::with_no_trimmed_paths; +use rustc_middle::ty::{ + self, CoroutineArgsExt, InstanceKind, ScalarInt, Ty, TyCtxt, TypeVisitableExt, Upcast, Variance, +}; +use rustc_middle::{bug, span_bug}; +use rustc_trait_selection::traits::ObligationCtxt; + +use crate::util::{self, is_within_packed}; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum EdgeKind { + Unwind, + Normal, +} + +pub(super) struct Validator { + /// Describes at which point in the pipeline this validation is happening. + pub when: String, +} + +impl<'tcx> crate::MirPass<'tcx> for Validator { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // FIXME(JakobDegen): These bodies never instantiated in codegend anyway, so it's not + // terribly important that they pass the validator. However, I think other passes might + // still see them, in which case they might be surprised. It would probably be better if we + // didn't put this through the MIR pipeline at all. + if matches!(body.source.instance, InstanceKind::Intrinsic(..) | InstanceKind::Virtual(..)) { + return; + } + let def_id = body.source.def_id(); + let typing_env = body.typing_env(tcx); + let can_unwind = if body.phase <= MirPhase::Runtime(RuntimePhase::Initial) { + // In this case `AbortUnwindingCalls` haven't yet been executed. + true + } else if !tcx.def_kind(def_id).is_fn_like() { + true + } else { + let body_ty = tcx.type_of(def_id).skip_binder(); + let body_abi = match body_ty.kind() { + ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), + ty::Closure(..) => ExternAbi::RustCall, + ty::CoroutineClosure(..) => ExternAbi::RustCall, + ty::Coroutine(..) => ExternAbi::Rust, + // No need to do MIR validation on error bodies + ty::Error(_) => return, + _ => span_bug!(body.span, "unexpected body ty: {body_ty}"), + }; + + ty::layout::fn_can_unwind(tcx, Some(def_id), body_abi) + }; + + let mut cfg_checker = CfgChecker { + when: &self.when, + body, + tcx, + unwind_edge_count: 0, + reachable_blocks: traversal::reachable_as_bitset(body), + value_cache: FxHashSet::default(), + can_unwind, + }; + cfg_checker.visit_body(body); + cfg_checker.check_cleanup_control_flow(); + + // Also run the TypeChecker. 
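+        // `validate_types` only collects failures; reporting them through `cfg_checker.fail`
+        // gives them the same "broken MIR in ..." context as the CFG checks above.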
+ for (location, msg) in validate_types(tcx, typing_env, body, body) { + cfg_checker.fail(location, msg); + } + + if let MirPhase::Runtime(_) = body.phase { + if let ty::InstanceKind::Item(_) = body.source.instance { + if body.has_free_regions() { + cfg_checker.fail( + Location::START, + format!("Free regions in optimized {} MIR", body.phase.name()), + ); + } + } + } + } + + fn is_required(&self) -> bool { + true + } +} + +/// This checker covers basic properties of the control-flow graph, (dis)allowed statements and terminators. +/// Everything checked here must be stable under substitution of generic parameters. In other words, +/// this is about the *structure* of the MIR, not the *contents*. +/// +/// Everything that depends on types, or otherwise can be affected by generic parameters, +/// must be checked in `TypeChecker`. +struct CfgChecker<'a, 'tcx> { + when: &'a str, + body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, + unwind_edge_count: usize, + reachable_blocks: DenseBitSet<BasicBlock>, + value_cache: FxHashSet<u128>, + // If `false`, then the MIR must not contain `UnwindAction::Continue` or + // `TerminatorKind::Resume`. + can_unwind: bool, +} + +impl<'a, 'tcx> CfgChecker<'a, 'tcx> { + #[track_caller] + fn fail(&self, location: Location, msg: impl AsRef<str>) { + // We might see broken MIR when other errors have already occurred. + assert!( + self.tcx.dcx().has_errors().is_some(), + "broken MIR in {:?} ({}) at {:?}:\n{}", + self.body.source.instance, + self.when, + location, + msg.as_ref(), + ); + } + + fn check_edge(&mut self, location: Location, bb: BasicBlock, edge_kind: EdgeKind) { + if bb == START_BLOCK { + self.fail(location, "start block must not have predecessors") + } + if let Some(bb) = self.body.basic_blocks.get(bb) { + let src = self.body.basic_blocks.get(location.block).unwrap(); + match (src.is_cleanup, bb.is_cleanup, edge_kind) { + // Non-cleanup blocks can jump to non-cleanup blocks along non-unwind edges + (false, false, EdgeKind::Normal) + // Cleanup blocks can jump to cleanup blocks along non-unwind edges + | (true, true, EdgeKind::Normal) => {} + // Non-cleanup blocks can jump to cleanup blocks along unwind edges + (false, true, EdgeKind::Unwind) => { + self.unwind_edge_count += 1; + } + // All other jumps are invalid + _ => { + self.fail( + location, + format!( + "{:?} edge to {:?} violates unwind invariants (cleanup {:?} -> {:?})", + edge_kind, + bb, + src.is_cleanup, + bb.is_cleanup, + ) + ) + } + } + } else { + self.fail(location, format!("encountered jump to invalid basic block {bb:?}")) + } + } + + fn check_cleanup_control_flow(&self) { + if self.unwind_edge_count <= 1 { + return; + } + let doms = self.body.basic_blocks.dominators(); + let mut post_contract_node = FxHashMap::default(); + // Reusing the allocation across invocations of the closure + let mut dom_path = vec![]; + let mut get_post_contract_node = |mut bb| { + let root = loop { + if let Some(root) = post_contract_node.get(&bb) { + break *root; + } + let parent = doms.immediate_dominator(bb).unwrap(); + dom_path.push(bb); + if !self.body.basic_blocks[parent].is_cleanup { + break bb; + } + bb = parent; + }; + for bb in dom_path.drain(..) 
{ + post_contract_node.insert(bb, root); + } + root + }; + + let mut parent = IndexVec::from_elem(None, &self.body.basic_blocks); + for (bb, bb_data) in self.body.basic_blocks.iter_enumerated() { + if !bb_data.is_cleanup || !self.reachable_blocks.contains(bb) { + continue; + } + let bb = get_post_contract_node(bb); + for s in bb_data.terminator().successors() { + let s = get_post_contract_node(s); + if s == bb { + continue; + } + let parent = &mut parent[bb]; + match parent { + None => { + *parent = Some(s); + } + Some(e) if *e == s => (), + Some(e) => self.fail( + Location { block: bb, statement_index: 0 }, + format!( + "Cleanup control flow violation: The blocks dominated by {:?} have edges to both {:?} and {:?}", + bb, + s, + *e + ) + ), + } + } + } + + // Check for cycles + let mut stack = FxHashSet::default(); + for (mut bb, parent) in parent.iter_enumerated_mut() { + stack.clear(); + stack.insert(bb); + loop { + let Some(parent) = parent.take() else { break }; + let no_cycle = stack.insert(parent); + if !no_cycle { + self.fail( + Location { block: bb, statement_index: 0 }, + format!( + "Cleanup control flow violation: Cycle involving edge {bb:?} -> {parent:?}", + ), + ); + break; + } + bb = parent; + } + } + } + + fn check_unwind_edge(&mut self, location: Location, unwind: UnwindAction) { + let is_cleanup = self.body.basic_blocks[location.block].is_cleanup; + match unwind { + UnwindAction::Cleanup(unwind) => { + if is_cleanup { + self.fail(location, "`UnwindAction::Cleanup` in cleanup block"); + } + self.check_edge(location, unwind, EdgeKind::Unwind); + } + UnwindAction::Continue => { + if is_cleanup { + self.fail(location, "`UnwindAction::Continue` in cleanup block"); + } + + if !self.can_unwind { + self.fail(location, "`UnwindAction::Continue` in no-unwind function"); + } + } + UnwindAction::Terminate(UnwindTerminateReason::InCleanup) => { + if !is_cleanup { + self.fail( + location, + "`UnwindAction::Terminate(InCleanup)` in a non-cleanup block", + ); + } + } + // These are allowed everywhere. + UnwindAction::Unreachable | UnwindAction::Terminate(UnwindTerminateReason::Abi) => (), + } + } + + fn is_critical_call_edge(&self, target: Option<BasicBlock>, unwind: UnwindAction) -> bool { + let Some(target) = target else { return false }; + matches!(unwind, UnwindAction::Cleanup(_) | UnwindAction::Terminate(_)) + && self.body.basic_blocks.predecessors()[target].len() > 1 + } +} + +impl<'a, 'tcx> Visitor<'tcx> for CfgChecker<'a, 'tcx> { + fn visit_local(&mut self, local: Local, _context: PlaceContext, location: Location) { + if self.body.local_decls.get(local).is_none() { + self.fail( + location, + format!("local {local:?} has no corresponding declaration in `body.local_decls`"), + ); + } + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match &statement.kind { + StatementKind::AscribeUserType(..) => { + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`AscribeUserType` should have been removed after drop lowering phase", + ); + } + } + StatementKind::FakeRead(..) => { + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FakeRead` should have been removed after drop lowering phase", + ); + } + } + StatementKind::SetDiscriminant { .. } => { + if self.body.phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`SetDiscriminant`is not allowed until deaggregation"); + } + } + StatementKind::Deinit(..) 
=> { + if self.body.phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`Deinit`is not allowed until deaggregation"); + } + } + StatementKind::Retag(kind, _) => { + // FIXME(JakobDegen) The validator should check that `self.body.phase < + // DropsLowered`. However, this causes ICEs with generation of drop shims, which + // seem to fail to set their `MirPhase` correctly. + if matches!(kind, RetagKind::TwoPhase) { + self.fail(location, format!("explicit `{kind:?}` is forbidden")); + } + } + StatementKind::Coverage(kind) => { + if self.body.phase >= MirPhase::Analysis(AnalysisPhase::PostCleanup) + && let CoverageKind::BlockMarker { .. } | CoverageKind::SpanMarker { .. } = kind + { + self.fail( + location, + format!("{kind:?} should have been removed after analysis"), + ); + } + } + StatementKind::Assign(..) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter + | StatementKind::PlaceMention(..) + | StatementKind::BackwardIncompatibleDropHint { .. } + | StatementKind::Nop => {} + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + match &terminator.kind { + TerminatorKind::Goto { target } => { + self.check_edge(location, *target, EdgeKind::Normal); + } + TerminatorKind::SwitchInt { targets, discr: _ } => { + for (_, target) in targets.iter() { + self.check_edge(location, target, EdgeKind::Normal); + } + self.check_edge(location, targets.otherwise(), EdgeKind::Normal); + + self.value_cache.clear(); + self.value_cache.extend(targets.iter().map(|(value, _)| value)); + let has_duplicates = targets.iter().len() != self.value_cache.len(); + if has_duplicates { + self.fail( + location, + format!( + "duplicated values in `SwitchInt` terminator: {:?}", + terminator.kind, + ), + ); + } + } + TerminatorKind::Drop { target, unwind, drop, .. } => { + self.check_edge(location, *target, EdgeKind::Normal); + self.check_unwind_edge(location, *unwind); + if let Some(drop) = drop { + self.check_edge(location, *drop, EdgeKind::Normal); + } + } + TerminatorKind::Call { func, args, .. } + | TerminatorKind::TailCall { func, args, .. } => { + // FIXME(explicit_tail_calls): refactor this & add tail-call specific checks + if let TerminatorKind::Call { target, unwind, destination, .. } = terminator.kind { + if let Some(target) = target { + self.check_edge(location, target, EdgeKind::Normal); + } + self.check_unwind_edge(location, unwind); + + // The code generation assumes that there are no critical call edges. The + // assumption is used to simplify inserting code that should be executed along + // the return edge from the call. FIXME(tmiasko): Since this is a strictly code + // generation concern, the code generation should be responsible for handling + // it. + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Optimized) + && self.is_critical_call_edge(target, unwind) + { + self.fail( + location, + format!( + "encountered critical edge in `Call` terminator {:?}", + terminator.kind, + ), + ); + } + + // The call destination place and Operand::Move place used as an argument might + // be passed by a reference to the callee. Consequently they cannot be packed. + if is_within_packed(self.tcx, &self.body.local_decls, destination).is_some() { + // This is bad! The callee will expect the memory to be aligned. 
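+                        // E.g. (an illustrative type, not one from this crate):
+                        // ```rust
+                        // #[repr(packed)]
+                        // struct P { a: u8, b: u32 } // `b` may end up at offset 1
+                        // ```
+                        // returning a value directly into such a `b` would hand the callee an
+                        // unaligned place.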
+ self.fail( + location, + format!( + "encountered packed place in `Call` terminator destination: {:?}", + terminator.kind, + ), + ); + } + } + + for arg in args { + if let Operand::Move(place) = &arg.node { + if is_within_packed(self.tcx, &self.body.local_decls, *place).is_some() { + // This is bad! The callee will expect the memory to be aligned. + self.fail( + location, + format!( + "encountered `Move` of a packed place in `Call` terminator: {:?}", + terminator.kind, + ), + ); + } + } + } + + if let ty::FnDef(did, ..) = func.ty(&self.body.local_decls, self.tcx).kind() + && self.body.phase >= MirPhase::Runtime(RuntimePhase::Optimized) + && matches!(self.tcx.codegen_fn_attrs(did).inline, InlineAttr::Force { .. }) + { + self.fail(location, "`#[rustc_force_inline]`-annotated function not inlined"); + } + } + TerminatorKind::Assert { target, unwind, .. } => { + self.check_edge(location, *target, EdgeKind::Normal); + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::Yield { resume, drop, .. } => { + if self.body.coroutine.is_none() { + self.fail(location, "`Yield` cannot appear outside coroutine bodies"); + } + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`Yield` should have been replaced by coroutine lowering"); + } + self.check_edge(location, *resume, EdgeKind::Normal); + if let Some(drop) = drop { + self.check_edge(location, *drop, EdgeKind::Normal); + } + } + TerminatorKind::FalseEdge { real_target, imaginary_target } => { + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FalseEdge` should have been removed after drop elaboration", + ); + } + self.check_edge(location, *real_target, EdgeKind::Normal); + self.check_edge(location, *imaginary_target, EdgeKind::Normal); + } + TerminatorKind::FalseUnwind { real_target, unwind } => { + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FalseUnwind` should have been removed after drop elaboration", + ); + } + self.check_edge(location, *real_target, EdgeKind::Normal); + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::InlineAsm { targets, unwind, .. 
} => { + for &target in targets { + self.check_edge(location, target, EdgeKind::Normal); + } + self.check_unwind_edge(location, *unwind); + } + TerminatorKind::CoroutineDrop => { + if self.body.coroutine.is_none() { + self.fail(location, "`CoroutineDrop` cannot appear outside coroutine bodies"); + } + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`CoroutineDrop` should have been replaced by coroutine lowering", + ); + } + } + TerminatorKind::UnwindResume => { + let bb = location.block; + if !self.body.basic_blocks[bb].is_cleanup { + self.fail(location, "Cannot `UnwindResume` from non-cleanup basic block") + } + if !self.can_unwind { + self.fail(location, "Cannot `UnwindResume` in a function that cannot unwind") + } + } + TerminatorKind::UnwindTerminate(_) => { + let bb = location.block; + if !self.body.basic_blocks[bb].is_cleanup { + self.fail(location, "Cannot `UnwindTerminate` from non-cleanup basic block") + } + } + TerminatorKind::Return => { + let bb = location.block; + if self.body.basic_blocks[bb].is_cleanup { + self.fail(location, "Cannot `Return` from cleanup basic block") + } + } + TerminatorKind::Unreachable => {} + } + + self.super_terminator(terminator, location); + } + + fn visit_source_scope(&mut self, scope: SourceScope) { + if self.body.source_scopes.get(scope).is_none() { + self.tcx.dcx().span_bug( + self.body.span, + format!( + "broken MIR in {:?} ({}):\ninvalid source scope {:?}", + self.body.source.instance, self.when, scope, + ), + ); + } + } +} + +/// A faster version of the validation pass that only checks those things which may break when +/// instantiating any generic parameters. +/// +/// `caller_body` is used to detect cycles in MIR inlining and MIR validation before +/// `optimized_mir` is available. +pub(super) fn validate_types<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + body: &Body<'tcx>, + caller_body: &Body<'tcx>, +) -> Vec<(Location, String)> { + let mut type_checker = TypeChecker { body, caller_body, tcx, typing_env, failures: Vec::new() }; + // The type checker formats a bunch of strings with type names in it, but these strings + // are not always going to be encountered on the error path since the inliner also uses + // the validator, and there are certain kinds of inlining (even for valid code) that + // can cause validation errors (mostly around where clauses and rigid projections). + with_no_trimmed_paths!({ + type_checker.visit_body(body); + }); + type_checker.failures +} + +struct TypeChecker<'a, 'tcx> { + body: &'a Body<'tcx>, + caller_body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + failures: Vec<(Location, String)>, +} + +impl<'a, 'tcx> TypeChecker<'a, 'tcx> { + fn fail(&mut self, location: Location, msg: impl Into<String>) { + self.failures.push((location, msg.into())); + } + + /// Check if src can be assigned into dest. + /// This is not precise, it will accept some incorrect assignments. + fn mir_assign_valid_types(&self, src: Ty<'tcx>, dest: Ty<'tcx>) -> bool { + // Fast path before we normalize. + if src == dest { + // Equal types, all is good. + return true; + } + + // We sometimes have to use `defining_opaque_types` for subtyping + // to succeed here and figuring out how exactly that should work + // is annoying. It is harmless enough to just not validate anything + // in that case. We still check this after analysis as all opaque + // types have been revealed at this point. 
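+        // E.g. assigning a concrete `i32` into a slot typed as the opaque return type of some
+        // `fn f() -> impl Sized` would require knowing the hidden type, so we just accept it.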
+ if (src, dest).has_opaque_types() { + return true; + } + + // After borrowck subtyping should be fully explicit via + // `Subtype` projections. + let variance = if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) { + Variance::Invariant + } else { + Variance::Covariant + }; + + crate::util::relate_types(self.tcx, self.typing_env, variance, src, dest) + } + + /// Check that the given predicate definitely holds in the param-env of this MIR body. + fn predicate_must_hold_modulo_regions( + &self, + pred: impl Upcast<TyCtxt<'tcx>, ty::Predicate<'tcx>>, + ) -> bool { + let pred: ty::Predicate<'tcx> = pred.upcast(self.tcx); + + // We sometimes have to use `defining_opaque_types` for predicates + // to succeed here and figuring out how exactly that should work + // is annoying. It is harmless enough to just not validate anything + // in that case. We still check this after analysis as all opaque + // types have been revealed at this point. + if pred.has_opaque_types() { + return true; + } + + let (infcx, param_env) = self.tcx.infer_ctxt().build_with_typing_env(self.typing_env); + let ocx = ObligationCtxt::new(&infcx); + ocx.register_obligation(Obligation::new( + self.tcx, + ObligationCause::dummy(), + param_env, + pred, + )); + ocx.select_all_or_error().is_empty() + } +} + +impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> { + fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { + // This check is somewhat expensive, so only run it when -Zvalidate-mir is passed. + if self.tcx.sess.opts.unstable_opts.validate_mir + && self.body.phase < MirPhase::Runtime(RuntimePhase::Initial) + { + // `Operand::Copy` is only supposed to be used with `Copy` types. + if let Operand::Copy(place) = operand { + let ty = place.ty(&self.body.local_decls, self.tcx).ty; + + if !self.tcx.type_is_copy_modulo_regions(self.typing_env, ty) { + self.fail(location, format!("`Operand::Copy` with non-`Copy` type {ty}")); + } + } + } + + self.super_operand(operand, location); + } + + fn visit_projection_elem( + &mut self, + place_ref: PlaceRef<'tcx>, + elem: PlaceElem<'tcx>, + context: PlaceContext, + location: Location, + ) { + match elem { + ProjectionElem::OpaqueCast(ty) + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) => + { + self.fail( + location, + format!("explicit opaque type cast to `{ty}` after `PostAnalysisNormalize`"), + ) + } + ProjectionElem::Index(index) => { + let index_ty = self.body.local_decls[index].ty; + if index_ty != self.tcx.types.usize { + self.fail(location, format!("bad index ({index_ty} != usize)")) + } + } + ProjectionElem::Deref + if self.body.phase >= MirPhase::Runtime(RuntimePhase::PostCleanup) => + { + let base_ty = place_ref.ty(&self.body.local_decls, self.tcx).ty; + + if base_ty.is_box() { + self.fail(location, format!("{base_ty} dereferenced after ElaborateBoxDerefs")) + } + } + ProjectionElem::Field(f, ty) => { + let parent_ty = place_ref.ty(&self.body.local_decls, self.tcx); + let fail_out_of_bounds = |this: &mut Self, location| { + this.fail(location, format!("Out of bounds field {f:?} for {parent_ty:?}")); + }; + let check_equal = |this: &mut Self, location, f_ty| { + if !this.mir_assign_valid_types(ty, f_ty) { + this.fail( + location, + format!( + "Field projection `{place_ref:?}.{f:?}` specified type `{ty}`, but actual type is `{f_ty}`" + ) + ) + } + }; + + let kind = match parent_ty.ty.kind() { + &ty::Alias(ty::Opaque, ty::AliasTy { def_id, args, .. 
}) => { + self.tcx.type_of(def_id).instantiate(self.tcx, args).kind() + } + kind => kind, + }; + + match kind { + ty::Tuple(fields) => { + let Some(f_ty) = fields.get(f.as_usize()) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, *f_ty); + } + ty::Adt(adt_def, args) => { + // see <https://github.com/rust-lang/rust/blob/7601adcc764d42c9f2984082b49948af652df986/compiler/rustc_middle/src/ty/layout.rs#L861-L864> + if self.tcx.is_lang_item(adt_def.did(), LangItem::DynMetadata) { + self.fail( + location, + format!( + "You can't project to field {f:?} of `DynMetadata` because \ + layout is weird and thinks it doesn't have fields." + ), + ); + } + + let var = parent_ty.variant_index.unwrap_or(FIRST_VARIANT); + let Some(field) = adt_def.variant(var).fields.get(f) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, field.ty(self.tcx, args)); + } + ty::Closure(_, args) => { + let args = args.as_closure(); + let Some(&f_ty) = args.upvar_tys().get(f.as_usize()) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, f_ty); + } + ty::CoroutineClosure(_, args) => { + let args = args.as_coroutine_closure(); + let Some(&f_ty) = args.upvar_tys().get(f.as_usize()) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, f_ty); + } + &ty::Coroutine(def_id, args) => { + let f_ty = if let Some(var) = parent_ty.variant_index { + // If we're currently validating an inlined copy of this body, + // then it will no longer be parameterized over the original + // args of the coroutine. Otherwise, we prefer to use this body + // since we may be in the process of computing this MIR in the + // first place. + let layout = if def_id == self.caller_body.source.def_id() { + self.caller_body + .coroutine_layout_raw() + .or_else(|| self.tcx.coroutine_layout(def_id, args).ok()) + } else if self.tcx.needs_coroutine_by_move_body_def_id(def_id) + && let ty::ClosureKind::FnOnce = + args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() + && self.caller_body.source.def_id() + == self.tcx.coroutine_by_move_body_def_id(def_id) + { + // Same if this is the by-move body of a coroutine-closure. 
+ self.caller_body.coroutine_layout_raw() + } else { + self.tcx.coroutine_layout(def_id, args).ok() + }; + + let Some(layout) = layout else { + self.fail( + location, + format!("No coroutine layout for {parent_ty:?}"), + ); + return; + }; + + let Some(&local) = layout.variant_fields[var].get(f) else { + fail_out_of_bounds(self, location); + return; + }; + + let Some(f_ty) = layout.field_tys.get(local) else { + self.fail( + location, + format!("Out of bounds local {local:?} for {parent_ty:?}"), + ); + return; + }; + + ty::EarlyBinder::bind(f_ty.ty).instantiate(self.tcx, args) + } else { + let Some(&f_ty) = args.as_coroutine().prefix_tys().get(f.index()) + else { + fail_out_of_bounds(self, location); + return; + }; + + f_ty + }; + + check_equal(self, location, f_ty); + } + _ => { + self.fail(location, format!("{:?} does not have fields", parent_ty.ty)); + } + } + } + ProjectionElem::Subtype(ty) => { + if !util::sub_types( + self.tcx, + self.typing_env, + ty, + place_ref.ty(&self.body.local_decls, self.tcx).ty, + ) { + self.fail( + location, + format!( + "Failed subtyping {ty} and {}", + place_ref.ty(&self.body.local_decls, self.tcx).ty + ), + ) + } + } + ProjectionElem::UnwrapUnsafeBinder(unwrapped_ty) => { + let binder_ty = place_ref.ty(&self.body.local_decls, self.tcx); + let ty::UnsafeBinder(binder_ty) = *binder_ty.ty.kind() else { + self.fail( + location, + format!("WrapUnsafeBinder does not produce a ty::UnsafeBinder"), + ); + return; + }; + let binder_inner_ty = self.tcx.instantiate_bound_regions_with_erased(*binder_ty); + if !self.mir_assign_valid_types(unwrapped_ty, binder_inner_ty) { + self.fail( + location, + format!( + "Cannot unwrap unsafe binder {binder_ty:?} into type {unwrapped_ty}" + ), + ); + } + } + _ => {} + } + self.super_projection_elem(place_ref, elem, context, location); + } + + fn visit_var_debug_info(&mut self, debuginfo: &VarDebugInfo<'tcx>) { + if let Some(box VarDebugInfoFragment { ty, ref projection }) = debuginfo.composite { + if ty.is_union() || ty.is_enum() { + self.fail( + START_BLOCK.start_location(), + format!("invalid type {ty} in debuginfo for {:?}", debuginfo.name), + ); + } + if projection.is_empty() { + self.fail( + START_BLOCK.start_location(), + format!("invalid empty projection in debuginfo for {:?}", debuginfo.name), + ); + } + if projection.iter().any(|p| !matches!(p, PlaceElem::Field(..))) { + self.fail( + START_BLOCK.start_location(), + format!( + "illegal projection {:?} in debuginfo for {:?}", + projection, debuginfo.name + ), + ); + } + } + match debuginfo.value { + VarDebugInfoContents::Const(_) => {} + VarDebugInfoContents::Place(place) => { + if place.projection.iter().any(|p| !p.can_use_in_debuginfo()) { + self.fail( + START_BLOCK.start_location(), + format!("illegal place {:?} in debuginfo for {:?}", place, debuginfo.name), + ); + } + } + } + self.super_var_debug_info(debuginfo); + } + + fn visit_place(&mut self, place: &Place<'tcx>, cntxt: PlaceContext, location: Location) { + // Set off any `bug!`s in the type computation code + let _ = place.ty(&self.body.local_decls, self.tcx); + + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) + && place.projection.len() > 1 + && cntxt != PlaceContext::NonUse(NonUseContext::VarDebugInfo) + && place.projection[1..].contains(&ProjectionElem::Deref) + { + self.fail( + location, + format!("place {place:?} has deref as a later projection (it is only permitted as the first projection)"), + ); + } + + // Ensure all downcast projections are followed by field projections. 
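+        // I.e. `(_1 as Some).0` is accepted, while a bare `(_1 as Some)` with no subsequent
+        // field access is rejected below.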
+ let mut projections_iter = place.projection.iter(); + while let Some(proj) = projections_iter.next() { + if matches!(proj, ProjectionElem::Downcast(..)) { + if !matches!(projections_iter.next(), Some(ProjectionElem::Field(..))) { + self.fail( + location, + format!( + "place {place:?} has `Downcast` projection not followed by `Field`" + ), + ); + } + } + } + + self.super_place(place, cntxt, location); + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + macro_rules! check_kinds { + ($t:expr, $text:literal, $typat:pat) => { + if !matches!(($t).kind(), $typat) { + self.fail(location, format!($text, $t)); + } + }; + } + match rvalue { + Rvalue::Use(_) | Rvalue::CopyForDeref(_) => {} + Rvalue::Aggregate(kind, fields) => match **kind { + AggregateKind::Tuple => {} + AggregateKind::Array(dest) => { + for src in fields { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "array field has the wrong type"); + } + } + } + AggregateKind::Adt(def_id, idx, args, _, Some(field)) => { + let adt_def = self.tcx.adt_def(def_id); + assert!(adt_def.is_union()); + assert_eq!(idx, FIRST_VARIANT); + let dest_ty = self.tcx.normalize_erasing_regions( + self.typing_env, + adt_def.non_enum_variant().fields[field].ty(self.tcx, args), + ); + if let [field] = fields.raw.as_slice() { + let src_ty = field.ty(self.body, self.tcx); + if !self.mir_assign_valid_types(src_ty, dest_ty) { + self.fail(location, "union field has the wrong type"); + } + } else { + self.fail(location, "unions should have one initialized field"); + } + } + AggregateKind::Adt(def_id, idx, args, _, None) => { + let adt_def = self.tcx.adt_def(def_id); + assert!(!adt_def.is_union()); + let variant = &adt_def.variants()[idx]; + if variant.fields.len() != fields.len() { + self.fail(location, "adt has the wrong number of initialized fields"); + } + for (src, dest) in std::iter::zip(fields, &variant.fields) { + let dest_ty = self + .tcx + .normalize_erasing_regions(self.typing_env, dest.ty(self.tcx, args)); + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest_ty) { + self.fail(location, "adt field has the wrong type"); + } + } + } + AggregateKind::Closure(_, args) => { + let upvars = args.as_closure().upvar_tys(); + if upvars.len() != fields.len() { + self.fail(location, "closure has the wrong number of initialized fields"); + } + for (src, dest) in std::iter::zip(fields, upvars) { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "closure field has the wrong type"); + } + } + } + AggregateKind::Coroutine(_, args) => { + let upvars = args.as_coroutine().upvar_tys(); + if upvars.len() != fields.len() { + self.fail(location, "coroutine has the wrong number of initialized fields"); + } + for (src, dest) in std::iter::zip(fields, upvars) { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "coroutine field has the wrong type"); + } + } + } + AggregateKind::CoroutineClosure(_, args) => { + let upvars = args.as_coroutine_closure().upvar_tys(); + if upvars.len() != fields.len() { + self.fail( + location, + "coroutine-closure has the wrong number of initialized fields", + ); + } + for (src, dest) in std::iter::zip(fields, upvars) { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "coroutine-closure field has the wrong type"); + } + } + } + AggregateKind::RawPtr(pointee_ty, mutability) => { + if !matches!(self.body.phase, MirPhase::Runtime(_)) { + // It 
would probably be fine to support this in earlier phases, but at the + // time of writing it's only ever introduced from intrinsic lowering, so + // earlier things just `bug!` on it. + self.fail(location, "RawPtr should be in runtime MIR only"); + } + + if let [data_ptr, metadata] = fields.raw.as_slice() { + let data_ptr_ty = data_ptr.ty(self.body, self.tcx); + let metadata_ty = metadata.ty(self.body, self.tcx); + if let ty::RawPtr(in_pointee, in_mut) = data_ptr_ty.kind() { + if *in_mut != mutability { + self.fail(location, "input and output mutability must match"); + } + + // FIXME: check `Thin` instead of `Sized` + if !in_pointee.is_sized(self.tcx, self.typing_env) { + self.fail(location, "input pointer must be thin"); + } + } else { + self.fail( + location, + "first operand to raw pointer aggregate must be a raw pointer", + ); + } + + // FIXME: Check metadata more generally + if pointee_ty.is_slice() { + if !self.mir_assign_valid_types(metadata_ty, self.tcx.types.usize) { + self.fail(location, "slice metadata must be usize"); + } + } else if pointee_ty.is_sized(self.tcx, self.typing_env) { + if metadata_ty != self.tcx.types.unit { + self.fail(location, "metadata for pointer-to-thin must be unit"); + } + } + } else { + self.fail(location, "raw pointer aggregate must have 2 fields"); + } + } + }, + Rvalue::Ref(_, BorrowKind::Fake(_), _) => { + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`Assign` statement with a `Fake` borrow should have been removed in runtime MIR", + ); + } + } + Rvalue::Ref(..) => {} + Rvalue::Len(p) => { + let pty = p.ty(&self.body.local_decls, self.tcx).ty; + check_kinds!( + pty, + "Cannot compute length of non-array type {:?}", + ty::Array(..) | ty::Slice(..) + ); + } + Rvalue::BinaryOp(op, vals) => { + use BinOp::*; + let a = vals.0.ty(&self.body.local_decls, self.tcx); + let b = vals.1.ty(&self.body.local_decls, self.tcx); + if crate::util::binop_right_homogeneous(*op) { + if let Eq | Lt | Le | Ne | Ge | Gt = op { + // The function pointer types can have lifetimes + if !self.mir_assign_valid_types(a, b) { + self.fail( + location, + format!("Cannot {op:?} compare incompatible types {a} and {b}"), + ); + } + } else if a != b { + self.fail( + location, + format!("Cannot perform binary op {op:?} on unequal types {a} and {b}"), + ); + } + } + + match op { + Offset => { + check_kinds!(a, "Cannot offset non-pointer type {:?}", ty::RawPtr(..)); + if b != self.tcx.types.isize && b != self.tcx.types.usize { + self.fail(location, format!("Cannot offset by non-isize type {b}")); + } + } + Eq | Lt | Le | Ne | Ge | Gt => { + for x in [a, b] { + check_kinds!( + x, + "Cannot {op:?} compare type {:?}", + ty::Bool + | ty::Char + | ty::Int(..) + | ty::Uint(..) + | ty::Float(..) + | ty::RawPtr(..) + | ty::FnPtr(..) + ) + } + } + Cmp => { + for x in [a, b] { + check_kinds!( + x, + "Cannot three-way compare non-integer type {:?}", + ty::Char | ty::Uint(..) | ty::Int(..) + ) + } + } + AddUnchecked | AddWithOverflow | SubUnchecked | SubWithOverflow + | MulUnchecked | MulWithOverflow | Shl | ShlUnchecked | Shr | ShrUnchecked => { + for x in [a, b] { + check_kinds!( + x, + "Cannot {op:?} non-integer type {:?}", + ty::Uint(..) | ty::Int(..) + ) + } + } + BitAnd | BitOr | BitXor => { + for x in [a, b] { + check_kinds!( + x, + "Cannot perform bitwise op {op:?} on type {:?}", + ty::Uint(..) | ty::Int(..) 
| ty::Bool + ) + } + } + Add | Sub | Mul | Div | Rem => { + for x in [a, b] { + check_kinds!( + x, + "Cannot perform arithmetic {op:?} on type {:?}", + ty::Uint(..) | ty::Int(..) | ty::Float(..) + ) + } + } + } + } + Rvalue::UnaryOp(op, operand) => { + let a = operand.ty(&self.body.local_decls, self.tcx); + match op { + UnOp::Neg => { + check_kinds!(a, "Cannot negate type {:?}", ty::Int(..) | ty::Float(..)) + } + UnOp::Not => { + check_kinds!( + a, + "Cannot binary not type {:?}", + ty::Int(..) | ty::Uint(..) | ty::Bool + ); + } + UnOp::PtrMetadata => { + check_kinds!( + a, + "Cannot PtrMetadata non-pointer non-reference type {:?}", + ty::RawPtr(..) | ty::Ref(..) + ); + } + } + } + Rvalue::ShallowInitBox(operand, _) => { + let a = operand.ty(&self.body.local_decls, self.tcx); + check_kinds!(a, "Cannot shallow init type {:?}", ty::RawPtr(..)); + } + Rvalue::Cast(kind, operand, target_type) => { + let op_ty = operand.ty(self.body, self.tcx); + match kind { + // FIXME: Add Checks for these + CastKind::PointerWithExposedProvenance | CastKind::PointerExposeProvenance => {} + CastKind::PointerCoercion(PointerCoercion::ReifyFnPointer, _) => { + // FIXME: check signature compatibility. + check_kinds!( + op_ty, + "CastKind::{kind:?} input must be a fn item, not {:?}", + ty::FnDef(..) + ); + check_kinds!( + target_type, + "CastKind::{kind:?} output must be a fn pointer, not {:?}", + ty::FnPtr(..) + ); + } + CastKind::PointerCoercion(PointerCoercion::UnsafeFnPointer, _) => { + // FIXME: check safety and signature compatibility. + check_kinds!( + op_ty, + "CastKind::{kind:?} input must be a fn pointer, not {:?}", + ty::FnPtr(..) + ); + check_kinds!( + target_type, + "CastKind::{kind:?} output must be a fn pointer, not {:?}", + ty::FnPtr(..) + ); + } + CastKind::PointerCoercion(PointerCoercion::ClosureFnPointer(..), _) => { + // FIXME: check safety, captures, and signature compatibility. + check_kinds!( + op_ty, + "CastKind::{kind:?} input must be a closure, not {:?}", + ty::Closure(..) + ); + check_kinds!( + target_type, + "CastKind::{kind:?} output must be a fn pointer, not {:?}", + ty::FnPtr(..) + ); + } + CastKind::PointerCoercion(PointerCoercion::MutToConstPointer, _) => { + // FIXME: check same pointee? + check_kinds!( + op_ty, + "CastKind::{kind:?} input must be a raw mut pointer, not {:?}", + ty::RawPtr(_, Mutability::Mut) + ); + check_kinds!( + target_type, + "CastKind::{kind:?} output must be a raw const pointer, not {:?}", + ty::RawPtr(_, Mutability::Not) + ); + if self.body.phase >= MirPhase::Analysis(AnalysisPhase::PostCleanup) { + self.fail(location, format!("After borrowck, MIR disallows {kind:?}")); + } + } + CastKind::PointerCoercion(PointerCoercion::ArrayToPointer, _) => { + // FIXME: Check pointee types + check_kinds!( + op_ty, + "CastKind::{kind:?} input must be a raw pointer, not {:?}", + ty::RawPtr(..) + ); + check_kinds!( + target_type, + "CastKind::{kind:?} output must be a raw pointer, not {:?}", + ty::RawPtr(..) + ); + if self.body.phase >= MirPhase::Analysis(AnalysisPhase::PostCleanup) { + self.fail(location, format!("After borrowck, MIR disallows {kind:?}")); + } + } + CastKind::PointerCoercion(PointerCoercion::Unsize, _) => { + // Pointers being unsize coerced should at least implement + // `CoerceUnsized`. 
+ if !self.predicate_must_hold_modulo_regions(ty::TraitRef::new( + self.tcx, + self.tcx.require_lang_item( + LangItem::CoerceUnsized, + self.body.source_info(location).span, + ), + [op_ty, *target_type], + )) { + self.fail(location, format!("Unsize coercion, but `{op_ty}` isn't coercible to `{target_type}`")); + } + } + CastKind::PointerCoercion(PointerCoercion::DynStar, _) => { + // FIXME(dyn-star): make sure nothing needs to be done here. + } + CastKind::IntToInt | CastKind::IntToFloat => { + let input_valid = op_ty.is_integral() || op_ty.is_char() || op_ty.is_bool(); + let target_valid = target_type.is_numeric() || target_type.is_char(); + if !input_valid || !target_valid { + self.fail( + location, + format!("Wrong cast kind {kind:?} for the type {op_ty}"), + ); + } + } + CastKind::FnPtrToPtr => { + check_kinds!( + op_ty, + "CastKind::{kind:?} input must be a fn pointer, not {:?}", + ty::FnPtr(..) + ); + check_kinds!( + target_type, + "CastKind::{kind:?} output must be a raw pointer, not {:?}", + ty::RawPtr(..) + ); + } + CastKind::PtrToPtr => { + check_kinds!( + op_ty, + "CastKind::{kind:?} input must be a raw pointer, not {:?}", + ty::RawPtr(..) + ); + check_kinds!( + target_type, + "CastKind::{kind:?} output must be a raw pointer, not {:?}", + ty::RawPtr(..) + ); + } + CastKind::FloatToFloat | CastKind::FloatToInt => { + if !op_ty.is_floating_point() || !target_type.is_numeric() { + self.fail( + location, + format!( + "Trying to cast non 'Float' as {kind:?} into {target_type:?}" + ), + ); + } + } + CastKind::Transmute => { + // Unlike `mem::transmute`, a MIR `Transmute` is well-formed + // for any two `Sized` types, just potentially UB to run. + + if !self + .tcx + .normalize_erasing_regions(self.typing_env, op_ty) + .is_sized(self.tcx, self.typing_env) + { + self.fail( + location, + format!("Cannot transmute from non-`Sized` type {op_ty}"), + ); + } + if !self + .tcx + .normalize_erasing_regions(self.typing_env, *target_type) + .is_sized(self.tcx, self.typing_env) + { + self.fail( + location, + format!("Cannot transmute to non-`Sized` type {target_type:?}"), + ); + } + } + } + } + Rvalue::NullaryOp(NullOp::OffsetOf(indices), container) => { + let fail_out_of_bounds = |this: &mut Self, location, field, ty| { + this.fail(location, format!("Out of bounds field {field:?} for {ty}")); + }; + + let mut current_ty = *container; + + for (variant, field) in indices.iter() { + match current_ty.kind() { + ty::Tuple(fields) => { + if variant != FIRST_VARIANT { + self.fail( + location, + format!("tried to get variant {variant:?} of tuple"), + ); + return; + } + let Some(&f_ty) = fields.get(field.as_usize()) else { + fail_out_of_bounds(self, location, field, current_ty); + return; + }; + + current_ty = self.tcx.normalize_erasing_regions(self.typing_env, f_ty); + } + ty::Adt(adt_def, args) => { + let Some(field) = adt_def.variant(variant).fields.get(field) else { + fail_out_of_bounds(self, location, field, current_ty); + return; + }; + + let f_ty = field.ty(self.tcx, args); + current_ty = self.tcx.normalize_erasing_regions(self.typing_env, f_ty); + } + _ => { + self.fail( + location, + format!("Cannot get offset ({variant:?}, {field:?}) from type {current_ty}"), + ); + return; + } + } + } + } + Rvalue::Repeat(_, _) + | Rvalue::ThreadLocalRef(_) + | Rvalue::RawPtr(_, _) + | Rvalue::NullaryOp( + NullOp::SizeOf | NullOp::AlignOf | NullOp::UbChecks | NullOp::ContractChecks, + _, + ) + | Rvalue::Discriminant(_) => {} + + Rvalue::WrapUnsafeBinder(op, ty) => { + let unwrapped_ty = op.ty(self.body, 
self.tcx); + let ty::UnsafeBinder(binder_ty) = *ty.kind() else { + self.fail( + location, + format!("WrapUnsafeBinder does not produce a ty::UnsafeBinder"), + ); + return; + }; + let binder_inner_ty = self.tcx.instantiate_bound_regions_with_erased(*binder_ty); + if !self.mir_assign_valid_types(unwrapped_ty, binder_inner_ty) { + self.fail( + location, + format!("Cannot wrap {unwrapped_ty} into unsafe binder {binder_ty:?}"), + ); + } + } + } + self.super_rvalue(rvalue, location); + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match &statement.kind { + StatementKind::Assign(box (dest, rvalue)) => { + // LHS and RHS of the assignment must have the same type. + let left_ty = dest.ty(&self.body.local_decls, self.tcx).ty; + let right_ty = rvalue.ty(&self.body.local_decls, self.tcx); + + if !self.mir_assign_valid_types(right_ty, left_ty) { + self.fail( + location, + format!( + "encountered `{:?}` with incompatible types:\n\ + left-hand side has type: {}\n\ + right-hand side has type: {}", + statement.kind, left_ty, right_ty, + ), + ); + } + if let Rvalue::CopyForDeref(place) = rvalue { + if place.ty(&self.body.local_decls, self.tcx).ty.builtin_deref(true).is_none() { + self.fail( + location, + "`CopyForDeref` should only be used for dereferenceable types", + ) + } + } + } + StatementKind::AscribeUserType(..) => { + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`AscribeUserType` should have been removed after drop lowering phase", + ); + } + } + StatementKind::FakeRead(..) => { + if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) { + self.fail( + location, + "`FakeRead` should have been removed after drop lowering phase", + ); + } + } + StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(op)) => { + let ty = op.ty(&self.body.local_decls, self.tcx); + if !ty.is_bool() { + self.fail( + location, + format!("`assume` argument must be `bool`, but got: `{ty}`"), + ); + } + } + StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping( + CopyNonOverlapping { src, dst, count }, + )) => { + let src_ty = src.ty(&self.body.local_decls, self.tcx); + let op_src_ty = if let Some(src_deref) = src_ty.builtin_deref(true) { + src_deref + } else { + self.fail( + location, + format!("Expected src to be ptr in copy_nonoverlapping, got: {src_ty}"), + ); + return; + }; + let dst_ty = dst.ty(&self.body.local_decls, self.tcx); + let op_dst_ty = if let Some(dst_deref) = dst_ty.builtin_deref(true) { + dst_deref + } else { + self.fail( + location, + format!("Expected dst to be ptr in copy_nonoverlapping, got: {dst_ty}"), + ); + return; + }; + // since CopyNonOverlapping is parametrized by 1 type, + // we only need to check that they are equal and not keep an extra parameter. + if !self.mir_assign_valid_types(op_src_ty, op_dst_ty) { + self.fail(location, format!("bad arg ({op_src_ty} != {op_dst_ty})")); + } + + let op_cnt_ty = count.ty(&self.body.local_decls, self.tcx); + if op_cnt_ty != self.tcx.types.usize { + self.fail(location, format!("bad arg ({op_cnt_ty} != usize)")) + } + } + StatementKind::SetDiscriminant { place, .. } => { + if self.body.phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`SetDiscriminant`is not allowed until deaggregation"); + } + let pty = place.ty(&self.body.local_decls, self.tcx).ty; + if !matches!( + pty.kind(), + ty::Adt(..) | ty::Coroutine(..) | ty::Alias(ty::Opaque, ..) 
+ ) { + self.fail( + location, + format!( + "`SetDiscriminant` is only allowed on ADTs and coroutines, not {pty}" + ), + ); + } + } + StatementKind::Deinit(..) => { + if self.body.phase < MirPhase::Runtime(RuntimePhase::Initial) { + self.fail(location, "`Deinit`is not allowed until deaggregation"); + } + } + StatementKind::Retag(kind, _) => { + // FIXME(JakobDegen) The validator should check that `self.body.phase < + // DropsLowered`. However, this causes ICEs with generation of drop shims, which + // seem to fail to set their `MirPhase` correctly. + if matches!(kind, RetagKind::TwoPhase) { + self.fail(location, format!("explicit `{kind:?}` is forbidden")); + } + } + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Coverage(_) + | StatementKind::ConstEvalCounter + | StatementKind::PlaceMention(..) + | StatementKind::BackwardIncompatibleDropHint { .. } + | StatementKind::Nop => {} + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + match &terminator.kind { + TerminatorKind::SwitchInt { targets, discr } => { + let switch_ty = discr.ty(&self.body.local_decls, self.tcx); + + let target_width = self.tcx.sess.target.pointer_width; + + let size = Size::from_bits(match switch_ty.kind() { + ty::Uint(uint) => uint.normalize(target_width).bit_width().unwrap(), + ty::Int(int) => int.normalize(target_width).bit_width().unwrap(), + ty::Char => 32, + ty::Bool => 1, + other => bug!("unhandled type: {:?}", other), + }); + + for (value, _) in targets.iter() { + if ScalarInt::try_from_uint(value, size).is_none() { + self.fail( + location, + format!("the value {value:#x} is not a proper {switch_ty}"), + ) + } + } + } + TerminatorKind::Call { func, .. } | TerminatorKind::TailCall { func, .. } => { + let func_ty = func.ty(&self.body.local_decls, self.tcx); + match func_ty.kind() { + ty::FnPtr(..) | ty::FnDef(..) => {} + _ => self.fail( + location, + format!( + "encountered non-callable type {func_ty} in `{}` terminator", + terminator.kind.name() + ), + ), + } + + if let TerminatorKind::TailCall { .. } = terminator.kind { + // FIXME(explicit_tail_calls): implement tail-call specific checks here (such + // as signature matching, forbidding closures, etc) + } + } + TerminatorKind::Assert { cond, .. } => { + let cond_ty = cond.ty(&self.body.local_decls, self.tcx); + if cond_ty != self.tcx.types.bool { + self.fail( + location, + format!( + "encountered non-boolean condition of type {cond_ty} in `Assert` terminator" + ), + ); + } + } + TerminatorKind::Goto { .. } + | TerminatorKind::Drop { .. } + | TerminatorKind::Yield { .. } + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::InlineAsm { .. } + | TerminatorKind::CoroutineDrop + | TerminatorKind::UnwindResume + | TerminatorKind::UnwindTerminate(_) + | TerminatorKind::Return + | TerminatorKind::Unreachable => {} + } + + self.super_terminator(terminator, location); + } +} |
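To make the `SwitchInt` check in `visit_terminator` above concrete: the validator derives the bit width of the switch discriminant's type (uint/int widths normalized to the target pointer width, `char` as 32 bits, `bool` as 1 bit) and rejects any target value that cannot be represented in that width. The following is a minimal, standalone sketch of that rule, not rustc's actual code; `SwitchTy`, `switch_bit_width`, and `value_fits` are illustrative stand-ins for `ty::Ty`, `Size`, and `ScalarInt::try_from_uint`.

    // Illustrative stand-in for the handful of types a `SwitchInt` may switch on.
    enum SwitchTy {
        Bool,
        Char,
        Uint(u32), // bit width, e.g. 8, 16, 32, 64, 128
        Int(u32),  // bit width; signed values are stored as raw bit patterns
    }

    // Mirrors the width computation in `visit_terminator`.
    fn switch_bit_width(ty: SwitchTy) -> u32 {
        match ty {
            SwitchTy::Bool => 1,
            SwitchTy::Char => 32,
            SwitchTy::Uint(bits) | SwitchTy::Int(bits) => bits,
        }
    }

    // A target value is "proper" for the switch type if it fits in that many
    // bits, which is the property `ScalarInt::try_from_uint` enforces.
    fn value_fits(value: u128, bits: u32) -> bool {
        if bits >= 128 { true } else { value < (1u128 << bits) }
    }

    fn main() {
        assert!(value_fits(255, switch_bit_width(SwitchTy::Uint(8))));
        assert!(!value_fits(300, switch_bit_width(SwitchTy::Uint(8)))); // rejected
        assert!(!value_fits(2, switch_bit_width(SwitchTy::Bool)));      // rejected
        assert!(value_fits(0xFF, switch_bit_width(SwitchTy::Int(8))));  // -1i8 as raw bits
        assert!(value_fits(0x10FFFF, switch_bit_width(SwitchTy::Char)));
        println!("switch target checks passed");
    }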

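The place-projection rule enforced in `visit_place` (every `Downcast` projection must be immediately followed by a `Field` projection) reduces to a two-step walk over the projection list. Below is a standalone sketch of that walk, assuming a simplified `Proj` enum as a stand-in for MIR's `ProjectionElem`; it is illustrative only, not the validator's actual code.

    // Simplified stand-in for the projection elements checked in `visit_place`.
    enum Proj {
        Deref,
        Field(u32),
        Downcast(u32), // variant index
    }

    // Walk the projections and require that each `Downcast` is immediately
    // followed by a `Field`, as the validator loop does.
    fn downcasts_followed_by_field(projection: &[Proj]) -> bool {
        let mut iter = projection.iter();
        while let Some(proj) = iter.next() {
            if matches!(proj, Proj::Downcast(..))
                && !matches!(iter.next(), Some(Proj::Field(..)))
            {
                return false;
            }
        }
        true
    }

    fn main() {
        // A `(*p as Variant1).0`-style place: accepted.
        assert!(downcasts_followed_by_field(&[Proj::Deref, Proj::Downcast(1), Proj::Field(0)]));
        // A `Downcast` with no following field access: rejected.
        assert!(!downcasts_followed_by_field(&[Proj::Downcast(1)]));
        println!("projection checks passed");
    }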