diff options
| -rw-r--r-- | compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs | 42 |
1 files changed, 39 insertions, 3 deletions
diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs index cda9ba9dcc8..7133724d07d 100644 --- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs +++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs @@ -3,7 +3,8 @@ use crate::MirPass; use rustc_data_structures::stable_set::FxHashSet; use rustc_middle::mir::{ - BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, TerminatorKind, + BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator, + TerminatorKind, }; use rustc_middle::ty::layout::TyAndLayout; use rustc_middle::ty::{Ty, TyCtxt}; @@ -71,6 +72,28 @@ fn variant_discriminants<'tcx>( } } +/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new +/// bb to use as the new target if not. +fn ensure_otherwise_unreachable<'tcx>( + body: &Body<'tcx>, + targets: &SwitchTargets, +) -> Option<BasicBlockData<'tcx>> { + let otherwise = targets.otherwise(); + let bb = &body.basic_blocks()[otherwise]; + if bb.terminator().kind == TerminatorKind::Unreachable + && bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_))) + { + return None; + } + + let mut new_block = BasicBlockData::new(Some(Terminator { + source_info: bb.terminator().source_info, + kind: TerminatorKind::Unreachable, + })); + new_block.is_cleanup = bb.is_cleanup; + Some(new_block) +} + impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { sess.mir_opt_level() > 0 @@ -99,12 +122,25 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { if let TerminatorKind::SwitchInt { targets, .. } = &mut body.basic_blocks_mut()[bb].terminator_mut().kind { - let new_targets = SwitchTargets::new( + let mut new_targets = SwitchTargets::new( targets.iter().filter(|(val, _)| allowed_variants.contains(val)), targets.otherwise(), ); - *targets = new_targets; + if new_targets.iter().count() == allowed_variants.len() { + if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) { + let new_otherwise = body.basic_blocks_mut().push(updated); + *new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise; + } + } + + if let TerminatorKind::SwitchInt { targets, .. } = + &mut body.basic_blocks_mut()[bb].terminator_mut().kind + { + *targets = new_targets; + } else { + unreachable!() + } } else { unreachable!() } |
