about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-03-08 07:18:17 +0000
committerbors <bors@rust-lang.org>2024-03-08 07:18:17 +0000
commit14fbc3c00525b41a3a3ee2c90e9ab6fd3b05274f (patch)
tree81cec2387610314e63c2468b5c3a2d1b797b6529 /compiler/rustc_mir_transform/src
parent9fb91aa2e70bfcc1c0adaad79711f0321ea81ece (diff)
parent2884230df2bdf6ff23d51808f5e2f270dfbe0d3a (diff)
downloadrust-14fbc3c00525b41a3a3ee2c90e9ab6fd3b05274f.tar.gz
rust-14fbc3c00525b41a3a3ee2c90e9ab6fd3b05274f.zip
Auto merge of #120268 - DianQK:otherwise_is_last_variant_switchs, r=oli-obk
Replace the default branch with an unreachable branch If it is the last variant

Fixes #119520. Fixes #110097.

LLVM currently has limited ability to eliminate dead branches in switches, even with the patch of https://github.com/llvm/llvm-project/issues/73446.

The main reasons are as follows:

- Additional costs are required to calculate the range of values, and there exist many scenarios that cannot be analyzed accurately.
- Matching values by bitwise calculation cannot handle odd branches, nor can it handle values like `-1, 0, 1`. See [SimplifyCFG.cpp#L5424](https://github.com/llvm/llvm-project/blob/llvmorg-17.0.6/llvm/lib/Transforms/Utils/SimplifyCFG.cpp#L5424) and https://llvm.godbolt.org/z/qYMqhvMa8
- The current range information is continuous, even if the metadata for the range is submitted. See [ConstantRange.cpp#L1869-L1870](https://github.com/llvm/llvm-project/blob/llvmorg-17.0.6/llvm/lib/IR/ConstantRange.cpp#L1869-L1870).
- The metadata of the range may be lost in passes such as SROA. See https://rust.godbolt.org/z/e7f87vKMK.

Although we can make improvements, I think it would be more appropriate to put this issue to rustc first. After all, we can easily know the possible values.

Note that we've currently found a slow compilation problem in the presence of unreachable branches. See
https://github.com/llvm/llvm-project/issues/78578.

r? compiler
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs84
1 files changed, 57 insertions, 27 deletions
diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
index e68d37f4c70..57fe46ad75a 100644
--- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
+++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
@@ -2,8 +2,10 @@
 
 use crate::MirPass;
 use rustc_data_structures::fx::FxHashSet;
+use rustc_middle::mir::patch::MirPatch;
 use rustc_middle::mir::{
-    BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
+    BasicBlock, BasicBlockData, BasicBlocks, Body, Local, Operand, Rvalue, StatementKind,
+    TerminatorKind,
 };
 use rustc_middle::ty::layout::TyAndLayout;
 use rustc_middle::ty::{Ty, TyCtxt};
@@ -77,7 +79,8 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         trace!("UninhabitedEnumBranching starting for {:?}", body.source);
 
-        let mut removable_switchs = Vec::new();
+        let mut unreachable_targets = Vec::new();
+        let mut patch = MirPatch::new(body);
 
         for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
             trace!("processing block {:?}", bb);
@@ -92,46 +95,73 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
                 tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty),
             );
 
-            let allowed_variants = if let Ok(layout) = layout {
+            let mut allowed_variants = if let Ok(layout) = layout {
                 variant_discriminants(&layout, discriminant_ty, tcx)
+            } else if let Some(variant_range) = discriminant_ty.variant_range(tcx) {
+                variant_range
+                    .map(|variant| {
+                        discriminant_ty.discriminant_for_variant(tcx, variant).unwrap().val
+                    })
+                    .collect()
             } else {
                 continue;
             };
 
             trace!("allowed_variants = {:?}", allowed_variants);
 
-            let terminator = bb_data.terminator();
-            let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };
+            unreachable_targets.clear();
+            let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else {
+                bug!()
+            };
 
-            let mut reachable_count = 0;
             for (index, (val, _)) in targets.iter().enumerate() {
-                if allowed_variants.contains(&val) {
-                    reachable_count += 1;
-                } else {
-                    removable_switchs.push((bb, index));
+                if !allowed_variants.remove(&val) {
+                    unreachable_targets.push(index);
+                }
+            }
+            let otherwise_is_empty_unreachable =
+                body.basic_blocks[targets.otherwise()].is_empty_unreachable();
+            // After resolving https://github.com/llvm/llvm-project/issues/78578,
+            // we can remove the limit on the number of successors.
+            fn check_successors(basic_blocks: &BasicBlocks<'_>, bb: BasicBlock) -> bool {
+                let mut successors = basic_blocks[bb].terminator().successors();
+                let Some(first_successor) = successors.next() else { return true };
+                if successors.next().is_some() {
+                    return true;
                 }
+                if let TerminatorKind::SwitchInt { .. } =
+                    &basic_blocks[first_successor].terminator().kind
+                {
+                    return false;
+                };
+                true
             }
+            let otherwise_is_last_variant = !otherwise_is_empty_unreachable
+                && allowed_variants.len() == 1
+                && check_successors(&body.basic_blocks, targets.otherwise());
+            let replace_otherwise_to_unreachable = otherwise_is_last_variant
+                || !otherwise_is_empty_unreachable && allowed_variants.is_empty();
 
-            if reachable_count == allowed_variants.len() {
-                removable_switchs.push((bb, targets.iter().count()));
+            if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable {
+                continue;
             }
-        }
 
-        if removable_switchs.is_empty() {
-            return;
+            let unreachable_block = patch.unreachable_no_cleanup_block();
+            let mut targets = targets.clone();
+            if replace_otherwise_to_unreachable {
+                if otherwise_is_last_variant {
+                    #[allow(rustc::potential_query_instability)]
+                    let last_variant = *allowed_variants.iter().next().unwrap();
+                    targets.add_target(last_variant, targets.otherwise());
+                }
+                unreachable_targets.push(targets.iter().count());
+            }
+            for index in unreachable_targets.iter() {
+                targets.all_targets_mut()[*index] = unreachable_block;
+            }
+            patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() });
         }
 
-        let new_block = BasicBlockData::new(Some(Terminator {
-            source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
-            kind: TerminatorKind::Unreachable,
-        }));
-        let unreachable_block = body.basic_blocks.as_mut().push(new_block);
-
-        for (bb, index) in removable_switchs {
-            let bb = &mut body.basic_blocks.as_mut()[bb];
-            let terminator = bb.terminator_mut();
-            let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
-            targets.all_targets_mut()[index] = unreachable_block;
-        }
+        patch.apply(body);
     }
 }