about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/check_enums.rs501
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs5
-rw-r--r--compiler/rustc_mir_transform/src/inline/cycle.rs273
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs4
4 files changed, 655 insertions, 128 deletions
diff --git a/compiler/rustc_mir_transform/src/check_enums.rs b/compiler/rustc_mir_transform/src/check_enums.rs
new file mode 100644
index 00000000000..e06e0c6122e
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/check_enums.rs
@@ -0,0 +1,501 @@
+use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
+use rustc_hir::LangItem;
+use rustc_index::IndexVec;
+use rustc_middle::bug;
+use rustc_middle::mir::visit::Visitor;
+use rustc_middle::mir::*;
+use rustc_middle::ty::layout::PrimitiveExt;
+use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
+use rustc_session::Session;
+use tracing::debug;
+
+/// This pass inserts checks for a valid enum discriminant where they are most
+/// likely to find UB, because checking everywhere like Miri would generate too
+/// much MIR.
+pub(super) struct CheckEnums;
+
+impl<'tcx> crate::MirPass<'tcx> for CheckEnums {
+    fn is_enabled(&self, sess: &Session) -> bool {
+        sess.ub_checks()
+    }
+
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        // This pass emits new panics. If for whatever reason we do not have a panic
+        // implementation, running this pass may cause otherwise-valid code to not compile.
+        if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
+            return;
+        }
+
+        let typing_env = body.typing_env(tcx);
+        let basic_blocks = body.basic_blocks.as_mut();
+        let local_decls = &mut body.local_decls;
+
+        // This operation inserts new blocks. Each insertion changes the Location for all
+        // statements/blocks after. Iterating or visiting the MIR in order would require updating
+        // our current location after every insertion. By iterating backwards, we dodge this issue:
+        // The only Locations that an insertion changes have already been handled.
+        for block in basic_blocks.indices().rev() {
+            for statement_index in (0..basic_blocks[block].statements.len()).rev() {
+                let location = Location { block, statement_index };
+                let statement = &basic_blocks[block].statements[statement_index];
+                let source_info = statement.source_info;
+
+                let mut finder = EnumFinder::new(tcx, local_decls, typing_env);
+                finder.visit_statement(statement, location);
+
+                for check in finder.into_found_enums() {
+                    debug!("Inserting enum check");
+                    let new_block = split_block(basic_blocks, location);
+
+                    match check {
+                        EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => {
+                            insert_direct_enum_check(
+                                tcx,
+                                local_decls,
+                                basic_blocks,
+                                block,
+                                source_op,
+                                discr,
+                                op_size,
+                                valid_discrs,
+                                source_info,
+                                new_block,
+                            )
+                        }
+                        EnumCheckType::Uninhabited => insert_uninhabited_enum_check(
+                            tcx,
+                            local_decls,
+                            &mut basic_blocks[block],
+                            source_info,
+                            new_block,
+                        ),
+                        EnumCheckType::WithNiche {
+                            source_op,
+                            discr,
+                            op_size,
+                            offset,
+                            valid_range,
+                        } => insert_niche_check(
+                            tcx,
+                            local_decls,
+                            &mut basic_blocks[block],
+                            source_op,
+                            valid_range,
+                            discr,
+                            op_size,
+                            offset,
+                            source_info,
+                            new_block,
+                        ),
+                    }
+                }
+            }
+        }
+    }
+
+    fn is_required(&self) -> bool {
+        true
+    }
+}
+
+/// Represent the different kind of enum checks we can insert.
+enum EnumCheckType<'tcx> {
+    /// We know we try to create an uninhabited enum from an inhabited variant.
+    Uninhabited,
+    /// We know the enum does no niche optimizations and can thus easily compute
+    /// the valid discriminants.
+    Direct {
+        source_op: Operand<'tcx>,
+        discr: TyAndSize<'tcx>,
+        op_size: Size,
+        valid_discrs: Vec<u128>,
+    },
+    /// We try to construct an enum that has a niche.
+    WithNiche {
+        source_op: Operand<'tcx>,
+        discr: TyAndSize<'tcx>,
+        op_size: Size,
+        offset: Size,
+        valid_range: WrappingRange,
+    },
+}
+
+struct TyAndSize<'tcx> {
+    pub ty: Ty<'tcx>,
+    pub size: Size,
+}
+
+/// A [Visitor] that finds the construction of enums and evaluates which checks
+/// we should apply.
+struct EnumFinder<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    local_decls: &'a mut LocalDecls<'tcx>,
+    typing_env: TypingEnv<'tcx>,
+    enums: Vec<EnumCheckType<'tcx>>,
+}
+
+impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
+    fn new(
+        tcx: TyCtxt<'tcx>,
+        local_decls: &'a mut LocalDecls<'tcx>,
+        typing_env: TypingEnv<'tcx>,
+    ) -> Self {
+        EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() }
+    }
+
+    /// Returns the found enum creations and which checks should be inserted.
+    fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
+        self.enums
+    }
+}
+
+impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
+    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
+        if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
+            let ty::Adt(adt_def, _) = ty.kind() else {
+                return;
+            };
+            if !adt_def.is_enum() {
+                return;
+            }
+
+            let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
+                return;
+            };
+            let Ok(op_layout) = self
+                .tcx
+                .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
+            else {
+                return;
+            };
+
+            match enum_layout.variants {
+                Variants::Empty if op_layout.is_uninhabited() => return,
+                // An empty enum that tries to be constructed from an inhabited value, this
+                // is never correct.
+                Variants::Empty => {
+                    // The enum layout is uninhabited but we construct it from sth inhabited.
+                    // This is always UB.
+                    self.enums.push(EnumCheckType::Uninhabited);
+                }
+                // Construction of Single value enums is always fine.
+                Variants::Single { .. } => {}
+                // Construction of an enum with multiple variants but no niche optimizations.
+                Variants::Multiple {
+                    tag_encoding: TagEncoding::Direct,
+                    tag: Scalar::Initialized { value, .. },
+                    ..
+                } => {
+                    let valid_discrs =
+                        adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
+
+                    let discr =
+                        TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
+                    self.enums.push(EnumCheckType::Direct {
+                        source_op: op.to_copy(),
+                        discr,
+                        op_size: op_layout.size,
+                        valid_discrs,
+                    });
+                }
+                // Construction of an enum with multiple variants and niche optimizations.
+                Variants::Multiple {
+                    tag_encoding: TagEncoding::Niche { .. },
+                    tag: Scalar::Initialized { value, valid_range, .. },
+                    tag_field,
+                    ..
+                } => {
+                    let discr =
+                        TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
+                    self.enums.push(EnumCheckType::WithNiche {
+                        source_op: op.to_copy(),
+                        discr,
+                        op_size: op_layout.size,
+                        offset: enum_layout.fields.offset(tag_field.as_usize()),
+                        valid_range,
+                    });
+                }
+                _ => return,
+            }
+
+            self.super_rvalue(rvalue, location);
+        }
+    }
+}
+
+fn split_block(
+    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
+    location: Location,
+) -> BasicBlock {
+    let block_data = &mut basic_blocks[location.block];
+
+    // Drain every statement after this one and move the current terminator to a new basic block.
+    let new_block = BasicBlockData {
+        statements: block_data.statements.split_off(location.statement_index),
+        terminator: block_data.terminator.take(),
+        is_cleanup: block_data.is_cleanup,
+    };
+
+    basic_blocks.push(new_block)
+}
+
+/// Inserts the cast of an operand (any type) to a u128 value that holds the discriminant value.
+fn insert_discr_cast_to_u128<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
+    block_data: &mut BasicBlockData<'tcx>,
+    source_op: Operand<'tcx>,
+    discr: TyAndSize<'tcx>,
+    op_size: Size,
+    offset: Option<Size>,
+    source_info: SourceInfo,
+) -> Place<'tcx> {
+    let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> {
+        match size.bytes() {
+            1 => tcx.types.u8,
+            2 => tcx.types.u16,
+            4 => tcx.types.u32,
+            8 => tcx.types.u64,
+            16 => tcx.types.u128,
+            invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid),
+        }
+    };
+
+    let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
+        // The discriminant is less wide than the operand, cast the operand into
+        // [MaybeUninit; N] and then index into it.
+        let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8);
+        let array_len = op_size.bytes();
+        let mu_array_ty = Ty::new_array(tcx, mu, array_len);
+        let mu_array =
+            local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into();
+        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty);
+        block_data.statements.push(Statement {
+            source_info,
+            kind: StatementKind::Assign(Box::new((mu_array, rvalue))),
+        });
+
+        // Index into the array of MaybeUninit to get something that is actually
+        // as wide as the discriminant.
+        let offset = offset.unwrap_or(Size::ZERO);
+        let smaller_mu_array = mu_array.project_deeper(
+            &[ProjectionElem::Subslice {
+                from: offset.bytes(),
+                to: offset.bytes() + discr.size.bytes(),
+                from_end: false,
+            }],
+            tcx,
+        );
+
+        (CastKind::Transmute, Operand::Copy(smaller_mu_array))
+    } else {
+        let operand_int_ty = get_ty_for_size(tcx, op_size);
+
+        let op_as_int =
+            local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into();
+        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty);
+        block_data.statements.push(Statement {
+            source_info,
+            kind: StatementKind::Assign(Box::new((op_as_int, rvalue))),
+        });
+
+        (CastKind::IntToInt, Operand::Copy(op_as_int))
+    };
+
+    // Cast the resulting value to the actual discriminant integer type.
+    let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty);
+    let discr_in_discr_ty =
+        local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into();
+    block_data.statements.push(Statement {
+        source_info,
+        kind: StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))),
+    });
+
+    // Cast the discriminant to a u128 (base for comparisions of enum discriminants).
+    let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128);
+    let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128);
+    let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into();
+    block_data
+        .statements
+        .push(Statement { source_info, kind: StatementKind::Assign(Box::new((discr, rvalue))) });
+
+    discr
+}
+
+fn insert_direct_enum_check<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
+    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+    current_block: BasicBlock,
+    source_op: Operand<'tcx>,
+    discr: TyAndSize<'tcx>,
+    op_size: Size,
+    discriminants: Vec<u128>,
+    source_info: SourceInfo,
+    new_block: BasicBlock,
+) {
+    // Insert a new target block that is branched to in case of an invalid discriminant.
+    let invalid_discr_block_data = BasicBlockData::new(None, false);
+    let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
+    let block_data = &mut basic_blocks[current_block];
+    let discr = insert_discr_cast_to_u128(
+        tcx,
+        local_decls,
+        block_data,
+        source_op,
+        discr,
+        op_size,
+        None,
+        source_info,
+    );
+
+    // Branch based on the discriminant value.
+    block_data.terminator = Some(Terminator {
+        source_info,
+        kind: TerminatorKind::SwitchInt {
+            discr: Operand::Copy(discr),
+            targets: SwitchTargets::new(
+                discriminants.into_iter().map(|discr| (discr, new_block)),
+                invalid_discr_block,
+            ),
+        },
+    });
+
+    // Abort in case of an invalid enum discriminant.
+    basic_blocks[invalid_discr_block].terminator = Some(Terminator {
+        source_info,
+        kind: TerminatorKind::Assert {
+            cond: Operand::Constant(Box::new(ConstOperand {
+                span: source_info.span,
+                user_ty: None,
+                const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
+            })),
+            expected: true,
+            target: new_block,
+            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
+            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
+            // We never want to insert an unwind into unsafe code, because unwinding could
+            // make a failing UB check turn into much worse UB when we start unwinding.
+            unwind: UnwindAction::Unreachable,
+        },
+    });
+}
+
+fn insert_uninhabited_enum_check<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
+    block_data: &mut BasicBlockData<'tcx>,
+    source_info: SourceInfo,
+    new_block: BasicBlock,
+) {
+    let is_ok: Place<'_> =
+        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
+    block_data.statements.push(Statement {
+        source_info,
+        kind: StatementKind::Assign(Box::new((
+            is_ok,
+            Rvalue::Use(Operand::Constant(Box::new(ConstOperand {
+                span: source_info.span,
+                user_ty: None,
+                const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
+            }))),
+        ))),
+    });
+
+    block_data.terminator = Some(Terminator {
+        source_info,
+        kind: TerminatorKind::Assert {
+            cond: Operand::Copy(is_ok),
+            expected: true,
+            target: new_block,
+            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new(
+                ConstOperand {
+                    span: source_info.span,
+                    user_ty: None,
+                    const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128),
+                },
+            )))),
+            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
+            // We never want to insert an unwind into unsafe code, because unwinding could
+            // make a failing UB check turn into much worse UB when we start unwinding.
+            unwind: UnwindAction::Unreachable,
+        },
+    });
+}
+
+fn insert_niche_check<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
+    block_data: &mut BasicBlockData<'tcx>,
+    source_op: Operand<'tcx>,
+    valid_range: WrappingRange,
+    discr: TyAndSize<'tcx>,
+    op_size: Size,
+    offset: Size,
+    source_info: SourceInfo,
+    new_block: BasicBlock,
+) {
+    let discr = insert_discr_cast_to_u128(
+        tcx,
+        local_decls,
+        block_data,
+        source_op,
+        discr,
+        op_size,
+        Some(offset),
+        source_info,
+    );
+
+    // Compare the discriminant agains the valid_range.
+    let start_const = Operand::Constant(Box::new(ConstOperand {
+        span: source_info.span,
+        user_ty: None,
+        const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128),
+    }));
+    let end_start_diff_const = Operand::Constant(Box::new(ConstOperand {
+        span: source_info.span,
+        user_ty: None,
+        const_: Const::Val(
+            ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)),
+            tcx.types.u128,
+        ),
+    }));
+
+    let discr_diff: Place<'_> =
+        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
+    block_data.statements.push(Statement {
+        source_info,
+        kind: StatementKind::Assign(Box::new((
+            discr_diff,
+            Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))),
+        ))),
+    });
+
+    let is_ok: Place<'_> =
+        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
+    block_data.statements.push(Statement {
+        source_info,
+        kind: StatementKind::Assign(Box::new((
+            is_ok,
+            Rvalue::BinaryOp(
+                // This is a `WrappingRange`, so make sure to get the wrapping right.
+                BinOp::Le,
+                Box::new((Operand::Copy(discr_diff), end_start_diff_const)),
+            ),
+        ))),
+    });
+
+    block_data.terminator = Some(Terminator {
+        source_info,
+        kind: TerminatorKind::Assert {
+            cond: Operand::Copy(is_ok),
+            expected: true,
+            target: new_block,
+            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
+            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
+            // We never want to insert an unwind into unsafe code, because unwinding could
+            // make a failing UB check turn into much worse UB when we start unwinding.
+            unwind: UnwindAction::Unreachable,
+        },
+    });
+}
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index f48dba9663a..c27087fea11 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -770,14 +770,15 @@ fn check_mir_is_available<'tcx, I: Inliner<'tcx>>(
         return Ok(());
     }
 
-    if callee_def_id.is_local()
+    if let Some(callee_def_id) = callee_def_id.as_local()
         && !inliner
             .tcx()
             .is_lang_item(inliner.tcx().parent(caller_def_id), rustc_hir::LangItem::FnOnce)
     {
         // If we know for sure that the function we're calling will itself try to
         // call us, then we avoid inlining that function.
-        if inliner.tcx().mir_callgraph_reachable((callee, caller_def_id.expect_local())) {
+        if inliner.tcx().mir_callgraph_cyclic(caller_def_id.expect_local()).contains(&callee_def_id)
+        {
             debug!("query cycle avoidance");
             return Err("caller might be reachable from callee");
         }
diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs
index a944960ce4a..08f3ce5fd67 100644
--- a/compiler/rustc_mir_transform/src/inline/cycle.rs
+++ b/compiler/rustc_mir_transform/src/inline/cycle.rs
@@ -1,5 +1,6 @@
 use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
 use rustc_data_structures::stack::ensure_sufficient_stack;
+use rustc_data_structures::unord::UnordSet;
 use rustc_hir::def_id::{DefId, LocalDefId};
 use rustc_middle::mir::TerminatorKind;
 use rustc_middle::ty::{self, GenericArgsRef, InstanceKind, TyCtxt, TypeVisitableExt};
@@ -7,137 +8,143 @@ use rustc_session::Limit;
 use rustc_span::sym;
 use tracing::{instrument, trace};
 
-// FIXME: check whether it is cheaper to precompute the entire call graph instead of invoking
-// this query ridiculously often.
-#[instrument(level = "debug", skip(tcx, root, target))]
-pub(crate) fn mir_callgraph_reachable<'tcx>(
+#[instrument(level = "debug", skip(tcx), ret)]
+fn should_recurse<'tcx>(tcx: TyCtxt<'tcx>, callee: ty::Instance<'tcx>) -> bool {
+    match callee.def {
+        // If there is no MIR available (either because it was not in metadata or
+        // because it has no MIR because it's an extern function), then the inliner
+        // won't cause cycles on this.
+        InstanceKind::Item(_) => {
+            if !tcx.is_mir_available(callee.def_id()) {
+                return false;
+            }
+        }
+
+        // These have no own callable MIR.
+        InstanceKind::Intrinsic(_) | InstanceKind::Virtual(..) => return false,
+
+        // These have MIR and if that MIR is inlined, instantiated and then inlining is run
+        // again, a function item can end up getting inlined. Thus we'll be able to cause
+        // a cycle that way
+        InstanceKind::VTableShim(_)
+        | InstanceKind::ReifyShim(..)
+        | InstanceKind::FnPtrShim(..)
+        | InstanceKind::ClosureOnceShim { .. }
+        | InstanceKind::ConstructCoroutineInClosureShim { .. }
+        | InstanceKind::ThreadLocalShim { .. }
+        | InstanceKind::CloneShim(..) => {}
+
+        // This shim does not call any other functions, thus there can be no recursion.
+        InstanceKind::FnPtrAddrShim(..) => return false,
+
+        // FIXME: A not fully instantiated drop shim can cause ICEs if one attempts to
+        // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this
+        // needs some more analysis.
+        InstanceKind::DropGlue(..)
+        | InstanceKind::FutureDropPollShim(..)
+        | InstanceKind::AsyncDropGlue(..)
+        | InstanceKind::AsyncDropGlueCtorShim(..) => {
+            if callee.has_param() {
+                return false;
+            }
+        }
+    }
+
+    crate::pm::should_run_pass(tcx, &crate::inline::Inline, crate::pm::Optimizations::Allowed)
+        || crate::inline::ForceInline::should_run_pass_for_callee(tcx, callee.def.def_id())
+}
+
+#[instrument(
+    level = "debug",
+    skip(tcx, typing_env, seen, involved, recursion_limiter, recursion_limit),
+    ret
+)]
+fn process<'tcx>(
     tcx: TyCtxt<'tcx>,
-    (root, target): (ty::Instance<'tcx>, LocalDefId),
+    typing_env: ty::TypingEnv<'tcx>,
+    caller: ty::Instance<'tcx>,
+    target: LocalDefId,
+    seen: &mut FxHashSet<ty::Instance<'tcx>>,
+    involved: &mut FxHashSet<LocalDefId>,
+    recursion_limiter: &mut FxHashMap<DefId, usize>,
+    recursion_limit: Limit,
 ) -> bool {
-    trace!(%root, target = %tcx.def_path_str(target));
-    assert_ne!(
-        root.def_id().expect_local(),
-        target,
-        "you should not call `mir_callgraph_reachable` on immediate self recursion"
-    );
-    assert!(
-        matches!(root.def, InstanceKind::Item(_)),
-        "you should not call `mir_callgraph_reachable` on shims"
-    );
-    assert!(
-        !tcx.is_constructor(root.def_id()),
-        "you should not call `mir_callgraph_reachable` on enum/struct constructor functions"
-    );
-    #[instrument(
-        level = "debug",
-        skip(tcx, typing_env, target, stack, seen, recursion_limiter, caller, recursion_limit)
-    )]
-    fn process<'tcx>(
-        tcx: TyCtxt<'tcx>,
-        typing_env: ty::TypingEnv<'tcx>,
-        caller: ty::Instance<'tcx>,
-        target: LocalDefId,
-        stack: &mut Vec<ty::Instance<'tcx>>,
-        seen: &mut FxHashSet<ty::Instance<'tcx>>,
-        recursion_limiter: &mut FxHashMap<DefId, usize>,
-        recursion_limit: Limit,
-    ) -> bool {
-        trace!(%caller);
-        for &(callee, args) in tcx.mir_inliner_callees(caller.def) {
-            let Ok(args) = caller.try_instantiate_mir_and_normalize_erasing_regions(
-                tcx,
-                typing_env,
-                ty::EarlyBinder::bind(args),
-            ) else {
-                trace!(?caller, ?typing_env, ?args, "cannot normalize, skipping");
-                continue;
-            };
-            let Ok(Some(callee)) = ty::Instance::try_resolve(tcx, typing_env, callee, args) else {
-                trace!(?callee, "cannot resolve, skipping");
-                continue;
-            };
+    trace!(%caller);
+    let mut cycle_found = false;
 
-            // Found a path.
-            if callee.def_id() == target.to_def_id() {
-                return true;
-            }
+    for &(callee, args) in tcx.mir_inliner_callees(caller.def) {
+        let Ok(args) = caller.try_instantiate_mir_and_normalize_erasing_regions(
+            tcx,
+            typing_env,
+            ty::EarlyBinder::bind(args),
+        ) else {
+            trace!(?caller, ?typing_env, ?args, "cannot normalize, skipping");
+            continue;
+        };
+        let Ok(Some(callee)) = ty::Instance::try_resolve(tcx, typing_env, callee, args) else {
+            trace!(?callee, "cannot resolve, skipping");
+            continue;
+        };
 
-            if tcx.is_constructor(callee.def_id()) {
-                trace!("constructors always have MIR");
-                // Constructor functions cannot cause a query cycle.
-                continue;
-            }
+        // Found a path.
+        if callee.def_id() == target.to_def_id() {
+            cycle_found = true;
+        }
 
-            match callee.def {
-                InstanceKind::Item(_) => {
-                    // If there is no MIR available (either because it was not in metadata or
-                    // because it has no MIR because it's an extern function), then the inliner
-                    // won't cause cycles on this.
-                    if !tcx.is_mir_available(callee.def_id()) {
-                        trace!(?callee, "no mir available, skipping");
-                        continue;
-                    }
-                }
-                // These have no own callable MIR.
-                InstanceKind::Intrinsic(_) | InstanceKind::Virtual(..) => continue,
-                // These have MIR and if that MIR is inlined, instantiated and then inlining is run
-                // again, a function item can end up getting inlined. Thus we'll be able to cause
-                // a cycle that way
-                InstanceKind::VTableShim(_)
-                | InstanceKind::ReifyShim(..)
-                | InstanceKind::FnPtrShim(..)
-                | InstanceKind::ClosureOnceShim { .. }
-                | InstanceKind::ConstructCoroutineInClosureShim { .. }
-                | InstanceKind::ThreadLocalShim { .. }
-                | InstanceKind::CloneShim(..) => {}
-
-                // This shim does not call any other functions, thus there can be no recursion.
-                InstanceKind::FnPtrAddrShim(..) => {
-                    continue;
-                }
-                InstanceKind::DropGlue(..)
-                | InstanceKind::FutureDropPollShim(..)
-                | InstanceKind::AsyncDropGlue(..)
-                | InstanceKind::AsyncDropGlueCtorShim(..) => {
-                    // FIXME: A not fully instantiated drop shim can cause ICEs if one attempts to
-                    // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this
-                    // needs some more analysis.
-                    if callee.has_param() {
-                        continue;
-                    }
-                }
-            }
+        if tcx.is_constructor(callee.def_id()) {
+            trace!("constructors always have MIR");
+            // Constructor functions cannot cause a query cycle.
+            continue;
+        }
+
+        if !should_recurse(tcx, callee) {
+            continue;
+        }
 
-            if seen.insert(callee) {
-                let recursion = recursion_limiter.entry(callee.def_id()).or_default();
-                trace!(?callee, recursion = *recursion);
-                if recursion_limit.value_within_limit(*recursion) {
-                    *recursion += 1;
-                    stack.push(callee);
-                    let found_recursion = ensure_sufficient_stack(|| {
-                        process(
-                            tcx,
-                            typing_env,
-                            callee,
-                            target,
-                            stack,
-                            seen,
-                            recursion_limiter,
-                            recursion_limit,
-                        )
-                    });
-                    if found_recursion {
-                        return true;
-                    }
-                    stack.pop();
-                } else {
-                    // Pessimistically assume that there could be recursion.
-                    return true;
+        if seen.insert(callee) {
+            let recursion = recursion_limiter.entry(callee.def_id()).or_default();
+            trace!(?callee, recursion = *recursion);
+            let found_recursion = if recursion_limit.value_within_limit(*recursion) {
+                *recursion += 1;
+                ensure_sufficient_stack(|| {
+                    process(
+                        tcx,
+                        typing_env,
+                        callee,
+                        target,
+                        seen,
+                        involved,
+                        recursion_limiter,
+                        recursion_limit,
+                    )
+                })
+            } else {
+                // Pessimistically assume that there could be recursion.
+                true
+            };
+            if found_recursion {
+                if let Some(callee) = callee.def_id().as_local() {
+                    // Calling `optimized_mir` of a non-local definition cannot cycle.
+                    involved.insert(callee);
                 }
+                cycle_found = true;
             }
         }
-        false
     }
+
+    cycle_found
+}
+
+#[instrument(level = "debug", skip(tcx), ret)]
+pub(crate) fn mir_callgraph_cyclic<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    root: LocalDefId,
+) -> UnordSet<LocalDefId> {
+    assert!(
+        !tcx.is_constructor(root.to_def_id()),
+        "you should not call `mir_callgraph_reachable` on enum/struct constructor functions"
+    );
+
     // FIXME(-Znext-solver=no): Remove this hack when trait solver overflow can return an error.
     // In code like that pointed out in #128887, the type complexity we ask the solver to deal with
     // grows as we recurse into the call graph. If we use the same recursion limit here and in the
@@ -146,16 +153,32 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
     // the default recursion limits are quite generous for us. If we need to recurse 64 times
     // into the call graph, we're probably not going to find any useful MIR inlining.
     let recursion_limit = tcx.recursion_limit() / 2;
+    let mut involved = FxHashSet::default();
+    let typing_env = ty::TypingEnv::post_analysis(tcx, root);
+    let Ok(Some(root_instance)) = ty::Instance::try_resolve(
+        tcx,
+        typing_env,
+        root.to_def_id(),
+        ty::GenericArgs::identity_for_item(tcx, root.to_def_id()),
+    ) else {
+        trace!("cannot resolve, skipping");
+        return involved.into();
+    };
+    if !should_recurse(tcx, root_instance) {
+        trace!("cannot walk, skipping");
+        return involved.into();
+    }
     process(
         tcx,
-        ty::TypingEnv::post_analysis(tcx, target),
+        typing_env,
+        root_instance,
         root,
-        target,
-        &mut Vec::new(),
         &mut FxHashSet::default(),
+        &mut involved,
         &mut FxHashMap::default(),
         recursion_limit,
-    )
+    );
+    involved.into()
 }
 
 pub(crate) fn mir_inliner_callees<'tcx>(
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 572ad585c8c..c4415294264 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -117,6 +117,7 @@ declare_passes! {
     mod check_inline : CheckForceInline;
     mod check_call_recursion : CheckCallRecursion, CheckDropRecursion;
     mod check_alignment : CheckAlignment;
+    mod check_enums : CheckEnums;
     mod check_const_item_mutation : CheckConstItemMutation;
     mod check_null : CheckNull;
     mod check_packed_ref : CheckPackedRef;
@@ -215,7 +216,7 @@ pub fn provide(providers: &mut Providers) {
         optimized_mir,
         is_mir_available,
         is_ctfe_mir_available: is_mir_available,
-        mir_callgraph_reachable: inline::cycle::mir_callgraph_reachable,
+        mir_callgraph_cyclic: inline::cycle::mir_callgraph_cyclic,
         mir_inliner_callees: inline::cycle::mir_inliner_callees,
         promoted_mir,
         deduced_param_attrs: deduce_param_attrs::deduced_param_attrs,
@@ -666,6 +667,7 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
             // Add some UB checks before any UB gets optimized away.
             &check_alignment::CheckAlignment,
             &check_null::CheckNull,
+            &check_enums::CheckEnums,
             // Before inlining: trim down MIR with passes to reduce inlining work.
 
             // Has to be done before inlining, otherwise actual call will be almost always inlined.