about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs104
-rw-r--r--tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir8
-rw-r--r--tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff9
-rw-r--r--tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir12
-rw-r--r--tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff8
-rw-r--r--tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff6
6 files changed, 84 insertions, 63 deletions
diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
index cb028a92d49..98f67e18a8d 100644
--- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
+++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
@@ -3,8 +3,7 @@
 use crate::MirPass;
 use rustc_data_structures::fx::FxHashSet;
 use rustc_middle::mir::{
-    BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator,
-    TerminatorKind,
+    BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
 };
 use rustc_middle::ty::layout::TyAndLayout;
 use rustc_middle::ty::{Ty, TyCtxt};
@@ -30,17 +29,20 @@ fn get_switched_on_type<'tcx>(
     let terminator = block_data.terminator();
 
     // Only bother checking blocks which terminate by switching on a local.
-    if let Some(local) = get_discriminant_local(&terminator.kind)
-        && let [.., stmt_before_term] = &block_data.statements[..]
-        && let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
+    let local = get_discriminant_local(&terminator.kind)?;
+
+    let stmt_before_term = block_data.statements.last()?;
+
+    if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
         && l.as_local() == Some(local)
-        && let ty = place.ty(body, tcx).ty
-        && ty.is_enum()
     {
-        Some(ty)
-    } else {
-        None
+        let ty = place.ty(body, tcx).ty;
+        if ty.is_enum() {
+            return Some(ty);
+        }
     }
+
+    None
 }
 
 fn variant_discriminants<'tcx>(
@@ -67,28 +69,6 @@ fn variant_discriminants<'tcx>(
     }
 }
 
-/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new
-/// bb to use as the new target if not.
-fn ensure_otherwise_unreachable<'tcx>(
-    body: &Body<'tcx>,
-    targets: &SwitchTargets,
-) -> Option<BasicBlockData<'tcx>> {
-    let otherwise = targets.otherwise();
-    let bb = &body.basic_blocks[otherwise];
-    if bb.terminator().kind == TerminatorKind::Unreachable
-        && bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_)))
-    {
-        return None;
-    }
-
-    let mut new_block = BasicBlockData::new(Some(Terminator {
-        source_info: bb.terminator().source_info,
-        kind: TerminatorKind::Unreachable,
-    }));
-    new_block.is_cleanup = bb.is_cleanup;
-    Some(new_block)
-}
-
 impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
         sess.mir_opt_level() > 0
@@ -97,13 +77,16 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         trace!("UninhabitedEnumBranching starting for {:?}", body.source);
 
-        for bb in body.basic_blocks.indices() {
+        let mut removable_switchs = Vec::new();
+
+        for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
             trace!("processing block {:?}", bb);
 
-            let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks[bb], tcx, body)
-            else {
+            if bb_data.is_cleanup {
                 continue;
-            };
+            }
+
+            let Some(discriminant_ty) = get_switched_on_type(&bb_data, tcx, body) else { continue };
 
             let layout = tcx.layout_of(
                 tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty),
@@ -117,31 +100,38 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
 
             trace!("allowed_variants = {:?}", allowed_variants);
 
-            if let TerminatorKind::SwitchInt { targets, .. } =
-                &mut body.basic_blocks_mut()[bb].terminator_mut().kind
-            {
-                let mut new_targets = SwitchTargets::new(
-                    targets.iter().filter(|(val, _)| allowed_variants.contains(val)),
-                    targets.otherwise(),
-                );
-
-                if new_targets.iter().count() == allowed_variants.len() {
-                    if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) {
-                        let new_otherwise = body.basic_blocks_mut().push(updated);
-                        *new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise;
-                    }
-                }
+            let terminator = bb_data.terminator();
+            let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };
 
-                if let TerminatorKind::SwitchInt { targets, .. } =
-                    &mut body.basic_blocks_mut()[bb].terminator_mut().kind
-                {
-                    *targets = new_targets;
+            let mut reachable_count = 0;
+            for (index, (val, _)) in targets.iter().enumerate() {
+                if allowed_variants.contains(&val) {
+                    reachable_count += 1;
                 } else {
-                    unreachable!()
+                    removable_switchs.push((bb, index));
                 }
-            } else {
-                unreachable!()
             }
+
+            if reachable_count == allowed_variants.len() {
+                removable_switchs.push((bb, targets.iter().count()));
+            }
+        }
+
+        if removable_switchs.is_empty() {
+            return;
+        }
+
+        let new_block = BasicBlockData::new(Some(Terminator {
+            source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
+            kind: TerminatorKind::Unreachable,
+        }));
+        let unreachable_block = body.basic_blocks.as_mut().push(new_block);
+
+        for (bb, index) in removable_switchs {
+            let bb = &mut body.basic_blocks.as_mut()[bb];
+            let terminator = bb.terminator_mut();
+            let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
+            targets.all_targets_mut()[index] = unreachable_block;
         }
     }
 }
diff --git a/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir b/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir
index 1ee44e48c8e..1df3b74b3e7 100644
--- a/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir
+++ b/tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir
@@ -12,14 +12,20 @@ fn main() -> () {
     let mut _8: isize;
     let _9: &str;
     let mut _10: bool;
+    let mut _11: bool;
+    let mut _12: bool;
 
     bb0: {
         StorageLive(_1);
         StorageLive(_2);
         _2 = Test1::C;
         _3 = discriminant(_2);
-        _10 = Eq(_3, const 2_isize);
+        _10 = Ne(_3, const 0_isize);
         assume(move _10);
+        _11 = Ne(_3, const 1_isize);
+        assume(move _11);
+        _12 = Eq(_3, const 2_isize);
+        assume(move _12);
         StorageLive(_5);
         _5 = const "C";
         _1 = &(*_5);
diff --git a/tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff b/tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff
index 9db95abec34..5b107f80ce8 100644
--- a/tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff
+++ b/tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff
@@ -19,7 +19,7 @@
           _2 = Test1::C;
           _3 = discriminant(_2);
 -         switchInt(move _3) -> [0: bb3, 1: bb4, 2: bb1, otherwise: bb2];
-+         switchInt(move _3) -> [2: bb1, otherwise: bb2];
++         switchInt(move _3) -> [0: bb9, 1: bb9, 2: bb1, otherwise: bb9];
       }
   
       bb1: {
@@ -54,7 +54,8 @@
           StorageLive(_7);
           _7 = Test2::D;
           _8 = discriminant(_7);
-          switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb2];
+-         switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb2];
++         switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb9];
       }
   
       bb6: {
@@ -75,6 +76,10 @@
           StorageDead(_6);
           _0 = const ();
           return;
++     }
++ 
++     bb9: {
++         unreachable;
       }
   }
   
diff --git a/tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir b/tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir
index 9c0c5d18917..06375b3ffae 100644
--- a/tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir
+++ b/tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir
@@ -15,6 +15,10 @@ fn main() -> () {
     let _11: &str;
     let _12: &str;
     let _13: &str;
+    let mut _14: bool;
+    let mut _15: bool;
+    let mut _16: bool;
+    let mut _17: bool;
     scope 1 {
         debug plop => _1;
     }
@@ -29,6 +33,10 @@ fn main() -> () {
         StorageLive(_4);
         _4 = &(_1.1: Test1);
         _5 = discriminant((*_4));
+        _16 = Ne(_5, const 0_isize);
+        assume(move _16);
+        _17 = Ne(_5, const 1_isize);
+        assume(move _17);
         switchInt(move _5) -> [2: bb3, 3: bb1, otherwise: bb2];
     }
 
@@ -57,6 +65,10 @@ fn main() -> () {
         StorageDead(_3);
         StorageLive(_9);
         _10 = discriminant((_1.1: Test1));
+        _14 = Ne(_10, const 0_isize);
+        assume(move _14);
+        _15 = Ne(_10, const 1_isize);
+        assume(move _15);
         switchInt(move _10) -> [2: bb6, 3: bb5, otherwise: bb2];
     }
 
diff --git a/tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff b/tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff
index 12ce6505af9..165421acd69 100644
--- a/tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff
+++ b/tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff
@@ -31,7 +31,7 @@
           _4 = &(_1.1: Test1);
           _5 = discriminant((*_4));
 -         switchInt(move _5) -> [0: bb3, 1: bb4, 2: bb5, 3: bb1, otherwise: bb2];
-+         switchInt(move _5) -> [2: bb5, 3: bb1, otherwise: bb2];
++         switchInt(move _5) -> [0: bb12, 1: bb12, 2: bb5, 3: bb1, otherwise: bb12];
       }
   
       bb1: {
@@ -73,7 +73,7 @@
           StorageLive(_9);
           _10 = discriminant((_1.1: Test1));
 -         switchInt(move _10) -> [0: bb8, 1: bb9, 2: bb10, 3: bb7, otherwise: bb2];
-+         switchInt(move _10) -> [2: bb10, 3: bb7, otherwise: bb2];
++         switchInt(move _10) -> [0: bb12, 1: bb12, 2: bb10, 3: bb7, otherwise: bb12];
       }
   
       bb7: {
@@ -110,6 +110,10 @@
           _0 = const ();
           StorageDead(_1);
           return;
++     }
++ 
++     bb12: {
++         unreachable;
       }
   }
   
diff --git a/tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff b/tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff
index 498e1e20f8a..79948139f88 100644
--- a/tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff
+++ b/tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff
@@ -9,7 +9,7 @@
       bb0: {
           _2 = discriminant(_1);
 -         switchInt(move _2) -> [0: bb2, 1: bb3, otherwise: bb1];
-+         switchInt(move _2) -> [1: bb3, otherwise: bb1];
++         switchInt(move _2) -> [0: bb5, 1: bb3, otherwise: bb1];
       }
   
       bb1: {
@@ -29,6 +29,10 @@
   
       bb4: {
           return;
++     }
++ 
++     bb5: {
++         unreachable;
       }
   }