diff options
Diffstat (limited to 'compiler/rustc_mir_transform/src')
| -rw-r--r-- | compiler/rustc_mir_transform/src/check_enums.rs | 501 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/inline.rs | 5 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/inline/cycle.rs | 273 | ||||
| -rw-r--r-- | compiler/rustc_mir_transform/src/lib.rs | 4 |
4 files changed, 655 insertions, 128 deletions
diff --git a/compiler/rustc_mir_transform/src/check_enums.rs b/compiler/rustc_mir_transform/src/check_enums.rs new file mode 100644 index 00000000000..e06e0c6122e --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_enums.rs @@ -0,0 +1,501 @@ +use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange}; +use rustc_hir::LangItem; +use rustc_index::IndexVec; +use rustc_middle::bug; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; +use rustc_middle::ty::layout::PrimitiveExt; +use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv}; +use rustc_session::Session; +use tracing::debug; + +/// This pass inserts checks for a valid enum discriminant where they are most +/// likely to find UB, because checking everywhere like Miri would generate too +/// much MIR. +pub(super) struct CheckEnums; + +impl<'tcx> crate::MirPass<'tcx> for CheckEnums { + fn is_enabled(&self, sess: &Session) -> bool { + sess.ub_checks() + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // This pass emits new panics. If for whatever reason we do not have a panic + // implementation, running this pass may cause otherwise-valid code to not compile. + if tcx.lang_items().get(LangItem::PanicImpl).is_none() { + return; + } + + let typing_env = body.typing_env(tcx); + let basic_blocks = body.basic_blocks.as_mut(); + let local_decls = &mut body.local_decls; + + // This operation inserts new blocks. Each insertion changes the Location for all + // statements/blocks after. Iterating or visiting the MIR in order would require updating + // our current location after every insertion. By iterating backwards, we dodge this issue: + // The only Locations that an insertion changes have already been handled. 
+ for block in basic_blocks.indices().rev() { + for statement_index in (0..basic_blocks[block].statements.len()).rev() { + let location = Location { block, statement_index }; + let statement = &basic_blocks[block].statements[statement_index]; + let source_info = statement.source_info; + + let mut finder = EnumFinder::new(tcx, local_decls, typing_env); + finder.visit_statement(statement, location); + + for check in finder.into_found_enums() { + debug!("Inserting enum check"); + let new_block = split_block(basic_blocks, location); + + match check { + EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => { + insert_direct_enum_check( + tcx, + local_decls, + basic_blocks, + block, + source_op, + discr, + op_size, + valid_discrs, + source_info, + new_block, + ) + } + EnumCheckType::Uninhabited => insert_uninhabited_enum_check( + tcx, + local_decls, + &mut basic_blocks[block], + source_info, + new_block, + ), + EnumCheckType::WithNiche { + source_op, + discr, + op_size, + offset, + valid_range, + } => insert_niche_check( + tcx, + local_decls, + &mut basic_blocks[block], + source_op, + valid_range, + discr, + op_size, + offset, + source_info, + new_block, + ), + } + } + } + } + } + + fn is_required(&self) -> bool { + true + } +} + +/// Represent the different kind of enum checks we can insert. +enum EnumCheckType<'tcx> { + /// We know we try to create an uninhabited enum from an inhabited variant. + Uninhabited, + /// We know the enum does no niche optimizations and can thus easily compute + /// the valid discriminants. + Direct { + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + valid_discrs: Vec<u128>, + }, + /// We try to construct an enum that has a niche. 
+ WithNiche { + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + offset: Size, + valid_range: WrappingRange, + }, +} + +struct TyAndSize<'tcx> { + pub ty: Ty<'tcx>, + pub size: Size, +} + +/// A [Visitor] that finds the construction of enums and evaluates which checks +/// we should apply. +struct EnumFinder<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + local_decls: &'a mut LocalDecls<'tcx>, + typing_env: TypingEnv<'tcx>, + enums: Vec<EnumCheckType<'tcx>>, +} + +impl<'a, 'tcx> EnumFinder<'a, 'tcx> { + fn new( + tcx: TyCtxt<'tcx>, + local_decls: &'a mut LocalDecls<'tcx>, + typing_env: TypingEnv<'tcx>, + ) -> Self { + EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() } + } + + /// Returns the found enum creations and which checks should be inserted. + fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> { + self.enums + } +} + +impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> { + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue { + let ty::Adt(adt_def, _) = ty.kind() else { + return; + }; + if !adt_def.is_enum() { + return; + } + + let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else { + return; + }; + let Ok(op_layout) = self + .tcx + .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx))) + else { + return; + }; + + match enum_layout.variants { + Variants::Empty if op_layout.is_uninhabited() => return, + // An empty enum that tries to be constructed from an inhabited value, this + // is never correct. + Variants::Empty => { + // The enum layout is uninhabited but we construct it from something inhabited. + // This is always UB. + self.enums.push(EnumCheckType::Uninhabited); + } + // Construction of Single value enums is always fine. + Variants::Single { .. } => {} + // Construction of an enum with multiple variants but no niche optimizations. 
+ Variants::Multiple { + tag_encoding: TagEncoding::Direct, + tag: Scalar::Initialized { value, .. }, + .. + } => { + let valid_discrs = + adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect(); + + let discr = + TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) }; + self.enums.push(EnumCheckType::Direct { + source_op: op.to_copy(), + discr, + op_size: op_layout.size, + valid_discrs, + }); + } + // Construction of an enum with multiple variants and niche optimizations. + Variants::Multiple { + tag_encoding: TagEncoding::Niche { .. }, + tag: Scalar::Initialized { value, valid_range, .. }, + tag_field, + .. + } => { + let discr = + TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) }; + self.enums.push(EnumCheckType::WithNiche { + source_op: op.to_copy(), + discr, + op_size: op_layout.size, + offset: enum_layout.fields.offset(tag_field.as_usize()), + valid_range, + }); + } + _ => return, + } + + self.super_rvalue(rvalue, location); + } + } +} + +fn split_block( + basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>, + location: Location, +) -> BasicBlock { + let block_data = &mut basic_blocks[location.block]; + + // Drain every statement after this one and move the current terminator to a new basic block. + let new_block = BasicBlockData { + statements: block_data.statements.split_off(location.statement_index), + terminator: block_data.terminator.take(), + is_cleanup: block_data.is_cleanup, + }; + + basic_blocks.push(new_block) +} + +/// Inserts the cast of an operand (any type) to a u128 value that holds the discriminant value. 
+fn insert_discr_cast_to_u128<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + block_data: &mut BasicBlockData<'tcx>, + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + offset: Option<Size>, + source_info: SourceInfo, +) -> Place<'tcx> { + let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> { + match size.bytes() { + 1 => tcx.types.u8, + 2 => tcx.types.u16, + 4 => tcx.types.u32, + 8 => tcx.types.u64, + 16 => tcx.types.u128, + invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid), + } + }; + + let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() { + // The discriminant is less wide than the operand, cast the operand into + // [MaybeUninit; N] and then index into it. + let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8); + let array_len = op_size.bytes(); + let mu_array_ty = Ty::new_array(tcx, mu, array_len); + let mu_array = + local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into(); + let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((mu_array, rvalue))), + }); + + // Index into the array of MaybeUninit to get something that is actually + // as wide as the discriminant. 
+ let offset = offset.unwrap_or(Size::ZERO); + let smaller_mu_array = mu_array.project_deeper( + &[ProjectionElem::Subslice { + from: offset.bytes(), + to: offset.bytes() + discr.size.bytes(), + from_end: false, + }], + tcx, + ); + + (CastKind::Transmute, Operand::Copy(smaller_mu_array)) + } else { + let operand_int_ty = get_ty_for_size(tcx, op_size); + + let op_as_int = + local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into(); + let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((op_as_int, rvalue))), + }); + + (CastKind::IntToInt, Operand::Copy(op_as_int)) + }; + + // Cast the resulting value to the actual discriminant integer type. + let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty); + let discr_in_discr_ty = + local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))), + }); + + // Cast the discriminant to a u128 (base for comparisons of enum discriminants). 
+ let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128); + let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128); + let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into(); + block_data + .statements + .push(Statement { source_info, kind: StatementKind::Assign(Box::new((discr, rvalue))) }); + + discr +} + +fn insert_direct_enum_check<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>, + current_block: BasicBlock, + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + discriminants: Vec<u128>, + source_info: SourceInfo, + new_block: BasicBlock, +) { + // Insert a new target block that is branched to in case of an invalid discriminant. + let invalid_discr_block_data = BasicBlockData::new(None, false); + let invalid_discr_block = basic_blocks.push(invalid_discr_block_data); + let block_data = &mut basic_blocks[current_block]; + let discr = insert_discr_cast_to_u128( + tcx, + local_decls, + block_data, + source_op, + discr, + op_size, + None, + source_info, + ); + + // Branch based on the discriminant value. + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::SwitchInt { + discr: Operand::Copy(discr), + targets: SwitchTargets::new( + discriminants.into_iter().map(|discr| (discr, new_block)), + invalid_discr_block, + ), + }, + }); + + // Abort in case of an invalid enum discriminant. + basic_blocks[invalid_discr_block].terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool), + })), + expected: true, + target: new_block, + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))), + // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. 
+ // We never want to insert an unwind into unsafe code, because unwinding could + // make a failing UB check turn into much worse UB when we start unwinding. + unwind: UnwindAction::Unreachable, + }, + }); +} + +fn insert_uninhabited_enum_check<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + block_data: &mut BasicBlockData<'tcx>, + source_info: SourceInfo, + new_block: BasicBlock, +) { + let is_ok: Place<'_> = + local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_ok, + Rvalue::Use(Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool), + }))), + ))), + }); + + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: Operand::Copy(is_ok), + expected: true, + target: new_block, + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new( + ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128), + }, + )))), + // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. + // We never want to insert an unwind into unsafe code, because unwinding could + // make a failing UB check turn into much worse UB when we start unwinding. 
+ unwind: UnwindAction::Unreachable, + }, + }); +} + +fn insert_niche_check<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + block_data: &mut BasicBlockData<'tcx>, + source_op: Operand<'tcx>, + valid_range: WrappingRange, + discr: TyAndSize<'tcx>, + op_size: Size, + offset: Size, + source_info: SourceInfo, + new_block: BasicBlock, +) { + let discr = insert_discr_cast_to_u128( + tcx, + local_decls, + block_data, + source_op, + discr, + op_size, + Some(offset), + source_info, + ); + + // Compare the discriminant against the valid_range. + let start_const = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128), + })); + let end_start_diff_const = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val( + ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)), + tcx.types.u128, + ), + })); + + let discr_diff: Place<'_> = + local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + discr_diff, + Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))), + ))), + }); + + let is_ok: Place<'_> = + local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_ok, + Rvalue::BinaryOp( + // This is a `WrappingRange`, so make sure to get the wrapping right. 
+ BinOp::Le, + Box::new((Operand::Copy(discr_diff), end_start_diff_const)), + ), + ))), + }); + + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: Operand::Copy(is_ok), + expected: true, + target: new_block, + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))), + // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. + // We never want to insert an unwind into unsafe code, because unwinding could + // make a failing UB check turn into much worse UB when we start unwinding. + unwind: UnwindAction::Unreachable, + }, + }); +} diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index f48dba9663a..c27087fea11 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -770,14 +770,15 @@ fn check_mir_is_available<'tcx, I: Inliner<'tcx>>( return Ok(()); } - if callee_def_id.is_local() + if let Some(callee_def_id) = callee_def_id.as_local() && !inliner .tcx() .is_lang_item(inliner.tcx().parent(caller_def_id), rustc_hir::LangItem::FnOnce) { // If we know for sure that the function we're calling will itself try to // call us, then we avoid inlining that function. 
- if inliner.tcx().mir_callgraph_reachable((callee, caller_def_id.expect_local())) { + if inliner.tcx().mir_callgraph_cyclic(caller_def_id.expect_local()).contains(&callee_def_id) + { debug!("query cycle avoidance"); return Err("caller might be reachable from callee"); } diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index a944960ce4a..08f3ce5fd67 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -1,5 +1,6 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet}; use rustc_data_structures::stack::ensure_sufficient_stack; +use rustc_data_structures::unord::UnordSet; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_middle::mir::TerminatorKind; use rustc_middle::ty::{self, GenericArgsRef, InstanceKind, TyCtxt, TypeVisitableExt}; @@ -7,137 +8,143 @@ use rustc_session::Limit; use rustc_span::sym; use tracing::{instrument, trace}; -// FIXME: check whether it is cheaper to precompute the entire call graph instead of invoking -// this query ridiculously often. -#[instrument(level = "debug", skip(tcx, root, target))] -pub(crate) fn mir_callgraph_reachable<'tcx>( +#[instrument(level = "debug", skip(tcx), ret)] +fn should_recurse<'tcx>(tcx: TyCtxt<'tcx>, callee: ty::Instance<'tcx>) -> bool { + match callee.def { + // If there is no MIR available (either because it was not in metadata or + // because it has no MIR because it's an extern function), then the inliner + // won't cause cycles on this. + InstanceKind::Item(_) => { + if !tcx.is_mir_available(callee.def_id()) { + return false; + } + } + + // These have no own callable MIR. + InstanceKind::Intrinsic(_) | InstanceKind::Virtual(..) => return false, + + // These have MIR and if that MIR is inlined, instantiated and then inlining is run + // again, a function item can end up getting inlined. 
Thus we'll be able to cause + // a cycle that way + InstanceKind::VTableShim(_) + | InstanceKind::ReifyShim(..) + | InstanceKind::FnPtrShim(..) + | InstanceKind::ClosureOnceShim { .. } + | InstanceKind::ConstructCoroutineInClosureShim { .. } + | InstanceKind::ThreadLocalShim { .. } + | InstanceKind::CloneShim(..) => {} + + // This shim does not call any other functions, thus there can be no recursion. + InstanceKind::FnPtrAddrShim(..) => return false, + + // FIXME: A not fully instantiated drop shim can cause ICEs if one attempts to + // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this + // needs some more analysis. + InstanceKind::DropGlue(..) + | InstanceKind::FutureDropPollShim(..) + | InstanceKind::AsyncDropGlue(..) + | InstanceKind::AsyncDropGlueCtorShim(..) => { + if callee.has_param() { + return false; + } + } + } + + crate::pm::should_run_pass(tcx, &crate::inline::Inline, crate::pm::Optimizations::Allowed) + || crate::inline::ForceInline::should_run_pass_for_callee(tcx, callee.def.def_id()) +} + +#[instrument( + level = "debug", + skip(tcx, typing_env, seen, involved, recursion_limiter, recursion_limit), + ret +)] +fn process<'tcx>( tcx: TyCtxt<'tcx>, - (root, target): (ty::Instance<'tcx>, LocalDefId), + typing_env: ty::TypingEnv<'tcx>, + caller: ty::Instance<'tcx>, + target: LocalDefId, + seen: &mut FxHashSet<ty::Instance<'tcx>>, + involved: &mut FxHashSet<LocalDefId>, + recursion_limiter: &mut FxHashMap<DefId, usize>, + recursion_limit: Limit, ) -> bool { - trace!(%root, target = %tcx.def_path_str(target)); - assert_ne!( - root.def_id().expect_local(), - target, - "you should not call `mir_callgraph_reachable` on immediate self recursion" - ); - assert!( - matches!(root.def, InstanceKind::Item(_)), - "you should not call `mir_callgraph_reachable` on shims" - ); - assert!( - !tcx.is_constructor(root.def_id()), - "you should not call `mir_callgraph_reachable` on enum/struct constructor functions" - ); - #[instrument( - level = 
"debug", - skip(tcx, typing_env, target, stack, seen, recursion_limiter, caller, recursion_limit) - )] - fn process<'tcx>( - tcx: TyCtxt<'tcx>, - typing_env: ty::TypingEnv<'tcx>, - caller: ty::Instance<'tcx>, - target: LocalDefId, - stack: &mut Vec<ty::Instance<'tcx>>, - seen: &mut FxHashSet<ty::Instance<'tcx>>, - recursion_limiter: &mut FxHashMap<DefId, usize>, - recursion_limit: Limit, - ) -> bool { - trace!(%caller); - for &(callee, args) in tcx.mir_inliner_callees(caller.def) { - let Ok(args) = caller.try_instantiate_mir_and_normalize_erasing_regions( - tcx, - typing_env, - ty::EarlyBinder::bind(args), - ) else { - trace!(?caller, ?typing_env, ?args, "cannot normalize, skipping"); - continue; - }; - let Ok(Some(callee)) = ty::Instance::try_resolve(tcx, typing_env, callee, args) else { - trace!(?callee, "cannot resolve, skipping"); - continue; - }; + trace!(%caller); + let mut cycle_found = false; - // Found a path. - if callee.def_id() == target.to_def_id() { - return true; - } + for &(callee, args) in tcx.mir_inliner_callees(caller.def) { + let Ok(args) = caller.try_instantiate_mir_and_normalize_erasing_regions( + tcx, + typing_env, + ty::EarlyBinder::bind(args), + ) else { + trace!(?caller, ?typing_env, ?args, "cannot normalize, skipping"); + continue; + }; + let Ok(Some(callee)) = ty::Instance::try_resolve(tcx, typing_env, callee, args) else { + trace!(?callee, "cannot resolve, skipping"); + continue; + }; - if tcx.is_constructor(callee.def_id()) { - trace!("constructors always have MIR"); - // Constructor functions cannot cause a query cycle. - continue; - } + // Found a path. + if callee.def_id() == target.to_def_id() { + cycle_found = true; + } - match callee.def { - InstanceKind::Item(_) => { - // If there is no MIR available (either because it was not in metadata or - // because it has no MIR because it's an extern function), then the inliner - // won't cause cycles on this. 
- if !tcx.is_mir_available(callee.def_id()) { - trace!(?callee, "no mir available, skipping"); - continue; - } - } - // These have no own callable MIR. - InstanceKind::Intrinsic(_) | InstanceKind::Virtual(..) => continue, - // These have MIR and if that MIR is inlined, instantiated and then inlining is run - // again, a function item can end up getting inlined. Thus we'll be able to cause - // a cycle that way - InstanceKind::VTableShim(_) - | InstanceKind::ReifyShim(..) - | InstanceKind::FnPtrShim(..) - | InstanceKind::ClosureOnceShim { .. } - | InstanceKind::ConstructCoroutineInClosureShim { .. } - | InstanceKind::ThreadLocalShim { .. } - | InstanceKind::CloneShim(..) => {} - - // This shim does not call any other functions, thus there can be no recursion. - InstanceKind::FnPtrAddrShim(..) => { - continue; - } - InstanceKind::DropGlue(..) - | InstanceKind::FutureDropPollShim(..) - | InstanceKind::AsyncDropGlue(..) - | InstanceKind::AsyncDropGlueCtorShim(..) => { - // FIXME: A not fully instantiated drop shim can cause ICEs if one attempts to - // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this - // needs some more analysis. - if callee.has_param() { - continue; - } - } - } + if tcx.is_constructor(callee.def_id()) { + trace!("constructors always have MIR"); + // Constructor functions cannot cause a query cycle. + continue; + } + + if !should_recurse(tcx, callee) { + continue; + } - if seen.insert(callee) { - let recursion = recursion_limiter.entry(callee.def_id()).or_default(); - trace!(?callee, recursion = *recursion); - if recursion_limit.value_within_limit(*recursion) { - *recursion += 1; - stack.push(callee); - let found_recursion = ensure_sufficient_stack(|| { - process( - tcx, - typing_env, - callee, - target, - stack, - seen, - recursion_limiter, - recursion_limit, - ) - }); - if found_recursion { - return true; - } - stack.pop(); - } else { - // Pessimistically assume that there could be recursion. 
- return true; + if seen.insert(callee) { + let recursion = recursion_limiter.entry(callee.def_id()).or_default(); + trace!(?callee, recursion = *recursion); + let found_recursion = if recursion_limit.value_within_limit(*recursion) { + *recursion += 1; + ensure_sufficient_stack(|| { + process( + tcx, + typing_env, + callee, + target, + seen, + involved, + recursion_limiter, + recursion_limit, + ) + }) + } else { + // Pessimistically assume that there could be recursion. + true + }; + if found_recursion { + if let Some(callee) = callee.def_id().as_local() { + // Calling `optimized_mir` of a non-local definition cannot cycle. + involved.insert(callee); } + cycle_found = true; } } - false } + + cycle_found +} + +#[instrument(level = "debug", skip(tcx), ret)] +pub(crate) fn mir_callgraph_cyclic<'tcx>( + tcx: TyCtxt<'tcx>, + root: LocalDefId, +) -> UnordSet<LocalDefId> { + assert!( + !tcx.is_constructor(root.to_def_id()), + "you should not call `mir_callgraph_reachable` on enum/struct constructor functions" + ); + // FIXME(-Znext-solver=no): Remove this hack when trait solver overflow can return an error. // In code like that pointed out in #128887, the type complexity we ask the solver to deal with // grows as we recurse into the call graph. If we use the same recursion limit here and in the @@ -146,16 +153,32 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( // the default recursion limits are quite generous for us. If we need to recurse 64 times // into the call graph, we're probably not going to find any useful MIR inlining. 
let recursion_limit = tcx.recursion_limit() / 2; + let mut involved = FxHashSet::default(); + let typing_env = ty::TypingEnv::post_analysis(tcx, root); + let Ok(Some(root_instance)) = ty::Instance::try_resolve( + tcx, + typing_env, + root.to_def_id(), + ty::GenericArgs::identity_for_item(tcx, root.to_def_id()), + ) else { + trace!("cannot resolve, skipping"); + return involved.into(); + }; + if !should_recurse(tcx, root_instance) { + trace!("cannot walk, skipping"); + return involved.into(); + } process( tcx, - ty::TypingEnv::post_analysis(tcx, target), + typing_env, + root_instance, root, - target, - &mut Vec::new(), &mut FxHashSet::default(), + &mut involved, &mut FxHashMap::default(), recursion_limit, - ) + ); + involved.into() } pub(crate) fn mir_inliner_callees<'tcx>( diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 572ad585c8c..c4415294264 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -117,6 +117,7 @@ declare_passes! { mod check_inline : CheckForceInline; mod check_call_recursion : CheckCallRecursion, CheckDropRecursion; mod check_alignment : CheckAlignment; + mod check_enums : CheckEnums; mod check_const_item_mutation : CheckConstItemMutation; mod check_null : CheckNull; mod check_packed_ref : CheckPackedRef; @@ -215,7 +216,7 @@ pub fn provide(providers: &mut Providers) { optimized_mir, is_mir_available, is_ctfe_mir_available: is_mir_available, - mir_callgraph_reachable: inline::cycle::mir_callgraph_reachable, + mir_callgraph_cyclic: inline::cycle::mir_callgraph_cyclic, mir_inliner_callees: inline::cycle::mir_inliner_callees, promoted_mir, deduced_param_attrs: deduce_param_attrs::deduced_param_attrs, @@ -666,6 +667,7 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<' // Add some UB checks before any UB gets optimized away. 
&check_alignment::CheckAlignment, &check_null::CheckNull, + &check_enums::CheckEnums, // Before inlining: trim down MIR with passes to reduce inlining work. // Has to be done before inlining, otherwise actual call will be almost always inlined. |
