about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs84
1 files changed, 57 insertions, 27 deletions
diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
index e68d37f4c70..57fe46ad75a 100644
--- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
+++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
@@ -2,8 +2,10 @@
 
 use crate::MirPass;
 use rustc_data_structures::fx::FxHashSet;
+use rustc_middle::mir::patch::MirPatch;
 use rustc_middle::mir::{
-    BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
+    BasicBlock, BasicBlockData, BasicBlocks, Body, Local, Operand, Rvalue, StatementKind,
+    TerminatorKind,
 };
 use rustc_middle::ty::layout::TyAndLayout;
 use rustc_middle::ty::{Ty, TyCtxt};
@@ -77,7 +79,8 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         trace!("UninhabitedEnumBranching starting for {:?}", body.source);
 
-        let mut removable_switchs = Vec::new();
+        let mut unreachable_targets = Vec::new();
+        let mut patch = MirPatch::new(body);
 
         for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
             trace!("processing block {:?}", bb);
@@ -92,46 +95,73 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
                 tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty),
             );
 
-            let allowed_variants = if let Ok(layout) = layout {
+            let mut allowed_variants = if let Ok(layout) = layout {
                 variant_discriminants(&layout, discriminant_ty, tcx)
+            } else if let Some(variant_range) = discriminant_ty.variant_range(tcx) {
+                variant_range
+                    .map(|variant| {
+                        discriminant_ty.discriminant_for_variant(tcx, variant).unwrap().val
+                    })
+                    .collect()
             } else {
                 continue;
             };
 
             trace!("allowed_variants = {:?}", allowed_variants);
 
-            let terminator = bb_data.terminator();
-            let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };
+            unreachable_targets.clear();
+            let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else {
+                bug!()
+            };
 
-            let mut reachable_count = 0;
             for (index, (val, _)) in targets.iter().enumerate() {
-                if allowed_variants.contains(&val) {
-                    reachable_count += 1;
-                } else {
-                    removable_switchs.push((bb, index));
+                if !allowed_variants.remove(&val) {
+                    unreachable_targets.push(index);
+                }
+            }
+            let otherwise_is_empty_unreachable =
+                body.basic_blocks[targets.otherwise()].is_empty_unreachable();
+            // After resolving https://github.com/llvm/llvm-project/issues/78578,
+            // we can remove the limit on the number of successors.
+            fn check_successors(basic_blocks: &BasicBlocks<'_>, bb: BasicBlock) -> bool {
+                let mut successors = basic_blocks[bb].terminator().successors();
+                let Some(first_successor) = successors.next() else { return true };
+                if successors.next().is_some() {
+                    return true;
                 }
+                if let TerminatorKind::SwitchInt { .. } =
+                    &basic_blocks[first_successor].terminator().kind
+                {
+                    return false;
+                };
+                true
             }
+            let otherwise_is_last_variant = !otherwise_is_empty_unreachable
+                && allowed_variants.len() == 1
+                && check_successors(&body.basic_blocks, targets.otherwise());
+            let replace_otherwise_to_unreachable = otherwise_is_last_variant
+                || !otherwise_is_empty_unreachable && allowed_variants.is_empty();
 
-            if reachable_count == allowed_variants.len() {
-                removable_switchs.push((bb, targets.iter().count()));
+            if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable {
+                continue;
             }
-        }
 
-        if removable_switchs.is_empty() {
-            return;
+            let unreachable_block = patch.unreachable_no_cleanup_block();
+            let mut targets = targets.clone();
+            if replace_otherwise_to_unreachable {
+                if otherwise_is_last_variant {
+                    #[allow(rustc::potential_query_instability)]
+                    let last_variant = *allowed_variants.iter().next().unwrap();
+                    targets.add_target(last_variant, targets.otherwise());
+                }
+                unreachable_targets.push(targets.iter().count());
+            }
+            for index in unreachable_targets.iter() {
+                targets.all_targets_mut()[*index] = unreachable_block;
+            }
+            patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() });
         }
 
-        let new_block = BasicBlockData::new(Some(Terminator {
-            source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
-            kind: TerminatorKind::Unreachable,
-        }));
-        let unreachable_block = body.basic_blocks.as_mut().push(new_block);
-
-        for (bb, index) in removable_switchs {
-            let bb = &mut body.basic_blocks.as_mut()[bb];
-            let terminator = bb.terminator_mut();
-            let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
-            targets.all_targets_mut()[index] = unreachable_block;
-        }
+        patch.apply(body);
     }
 }