about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNadrieril <nadrieril+git@gmail.com>2024-02-29 00:52:03 +0100
committerNadrieril <nadrieril+git@gmail.com>2024-03-02 18:33:17 +0100
commit3d3b321c60f6ce1ac59edf0706c083aa7fbd1e83 (patch)
treed2ed663aa97faef599a00c828f521c7b74acc118
parent832b23ffcfae702c333915bf2c25493f9f62ebb5 (diff)
downloadrust-3d3b321c60f6ce1ac59edf0706c083aa7fbd1e83.tar.gz
rust-3d3b321c60f6ce1ac59edf0706c083aa7fbd1e83.zip
Use an enum instead of manually tracking indices for `target_blocks`
-rw-r--r--compiler/rustc_mir_build/src/build/matches/mod.rs34
-rw-r--r--compiler/rustc_mir_build/src/build/matches/test.rs117
-rw-r--r--tests/mir-opt/building/issue_49232.main.built.after.mir8
-rw-r--r--tests/mir-opt/match_arm_scopes.complicated_match.panic-abort.SimplifyCfg-initial.after-ElaborateDrops.after.diff15
-rw-r--r--tests/mir-opt/match_arm_scopes.complicated_match.panic-unwind.SimplifyCfg-initial.after-ElaborateDrops.after.diff15
5 files changed, 112 insertions, 77 deletions
diff --git a/compiler/rustc_mir_build/src/build/matches/mod.rs b/compiler/rustc_mir_build/src/build/matches/mod.rs
index daa0349789e..aea52fc497f 100644
--- a/compiler/rustc_mir_build/src/build/matches/mod.rs
+++ b/compiler/rustc_mir_build/src/build/matches/mod.rs
@@ -1160,6 +1160,19 @@ pub(crate) struct Test<'tcx> {
     kind: TestKind<'tcx>,
 }
 
+/// The branch to be taken after a test.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+enum TestBranch<'tcx> {
+    /// Success branch, used for tests with two possible outcomes.
+    Success,
+    /// Branch corresponding to this constant.
+    Constant(Const<'tcx>, u128),
+    /// Branch corresponding to this variant.
+    Variant(VariantIdx),
+    /// Failure branch for tests with two possible outcomes, and "otherwise" branch for other tests.
+    Failure,
+}
+
 /// `ArmHasGuard` is a wrapper around a boolean flag. It indicates whether
 /// a match arm has a guard expression attached to it.
 #[derive(Copy, Clone, Debug)]
@@ -1636,11 +1649,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         match_place: &PlaceBuilder<'tcx>,
         test: &Test<'tcx>,
         mut candidates: &'b mut [&'c mut Candidate<'pat, 'tcx>],
-    ) -> (&'b mut [&'c mut Candidate<'pat, 'tcx>], Vec<Vec<&'b mut Candidate<'pat, 'tcx>>>) {
+    ) -> (
+        &'b mut [&'c mut Candidate<'pat, 'tcx>],
+        FxIndexMap<TestBranch<'tcx>, Vec<&'b mut Candidate<'pat, 'tcx>>>,
+    ) {
         // For each of the N possible outcomes, create a (initially empty) vector of candidates.
         // Those are the candidates that apply if the test has that particular outcome.
-        let mut target_candidates: Vec<Vec<&mut Candidate<'pat, 'tcx>>> = vec![];
-        target_candidates.resize_with(test.targets(), Default::default);
+        let mut target_candidates: FxIndexMap<_, Vec<&mut Candidate<'pat, 'tcx>>> =
+            test.targets().into_iter().map(|branch| (branch, Vec::new())).collect();
 
         let total_candidate_count = candidates.len();
 
@@ -1648,11 +1664,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         // point we may encounter a candidate where the test is not relevant; at that point, we stop
         // sorting.
         while let Some(candidate) = candidates.first_mut() {
-            let Some(idx) = self.sort_candidate(&match_place, &test, candidate) else {
+            let Some(branch) = self.sort_candidate(&match_place, &test, candidate) else {
                 break;
             };
             let (candidate, rest) = candidates.split_first_mut().unwrap();
-            target_candidates[idx].push(candidate);
+            target_candidates[&branch].push(candidate);
             candidates = rest;
         }
 
@@ -1797,9 +1813,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         // apply. Collect a list of blocks where control flow will
         // branch if one of the `target_candidate` sets is not
         // exhaustive.
-        let target_blocks: Vec<_> = target_candidates
+        let target_blocks: FxIndexMap<_, _> = target_candidates
             .into_iter()
-            .map(|mut candidates| {
+            .map(|(branch, mut candidates)| {
                 if !candidates.is_empty() {
                     let candidate_start = self.cfg.start_new_block();
                     self.match_candidates(
@@ -1809,9 +1825,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                         remainder_start,
                         &mut *candidates,
                     );
-                    candidate_start
+                    (branch, candidate_start)
                 } else {
-                    remainder_start
+                    (branch, remainder_start)
                 }
             })
             .collect();
diff --git a/compiler/rustc_mir_build/src/build/matches/test.rs b/compiler/rustc_mir_build/src/build/matches/test.rs
index d811141f50f..d003ae8d803 100644
--- a/compiler/rustc_mir_build/src/build/matches/test.rs
+++ b/compiler/rustc_mir_build/src/build/matches/test.rs
@@ -6,7 +6,7 @@
 // the candidates based on the result.
 
 use crate::build::expr::as_place::PlaceBuilder;
-use crate::build::matches::{Candidate, MatchPair, Test, TestCase, TestKind};
+use crate::build::matches::{Candidate, MatchPair, Test, TestBranch, TestCase, TestKind};
 use crate::build::Builder;
 use rustc_data_structures::fx::FxIndexMap;
 use rustc_hir::{LangItem, RangeEnd};
@@ -129,11 +129,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         block: BasicBlock,
         place_builder: &PlaceBuilder<'tcx>,
         test: &Test<'tcx>,
-        target_blocks: Vec<BasicBlock>,
+        target_blocks: FxIndexMap<TestBranch<'tcx>, BasicBlock>,
     ) {
         let place = place_builder.to_place(self);
         let place_ty = place.ty(&self.local_decls, self.tcx);
-        debug!(?place, ?place_ty,);
+        debug!(?place, ?place_ty);
+        let target_block = |branch| target_blocks[&branch];
 
         let source_info = self.source_info(test.span);
         match test.kind {
@@ -141,20 +142,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                 // Variants is a BitVec of indexes into adt_def.variants.
                 let num_enum_variants = adt_def.variants().len();
                 debug_assert_eq!(target_blocks.len(), num_enum_variants + 1);
-                let otherwise_block = *target_blocks.last().unwrap();
+                let otherwise_block = target_block(TestBranch::Failure);
                 let tcx = self.tcx;
                 let switch_targets = SwitchTargets::new(
                     adt_def.discriminants(tcx).filter_map(|(idx, discr)| {
                         if variants.contains(idx) {
                             debug_assert_ne!(
-                                target_blocks[idx.index()],
+                                target_block(TestBranch::Variant(idx)),
                                 otherwise_block,
                                 "no candidates for tested discriminant: {discr:?}",
                             );
-                            Some((discr.val, target_blocks[idx.index()]))
+                            Some((discr.val, target_block(TestBranch::Variant(idx))))
                         } else {
                             debug_assert_eq!(
-                                target_blocks[idx.index()],
+                                target_block(TestBranch::Variant(idx)),
                                 otherwise_block,
                                 "found candidates for untested discriminant: {discr:?}",
                             );
@@ -185,9 +186,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
             TestKind::SwitchInt { ref options } => {
                 // The switch may be inexhaustive so we have a catch-all block
                 debug_assert_eq!(options.len() + 1, target_blocks.len());
-                let otherwise_block = *target_blocks.last().unwrap();
+                let otherwise_block = target_block(TestBranch::Failure);
                 let switch_targets = SwitchTargets::new(
-                    options.values().copied().zip(target_blocks),
+                    options
+                        .iter()
+                        .map(|(&val, &bits)| (bits, target_block(TestBranch::Constant(val, bits)))),
                     otherwise_block,
                 );
                 let terminator = TerminatorKind::SwitchInt {
@@ -198,18 +201,19 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
             }
 
             TestKind::If => {
-                let [false_bb, true_bb] = *target_blocks else {
-                    bug!("`TestKind::If` should have two targets")
-                };
-                let terminator = TerminatorKind::if_(Operand::Copy(place), true_bb, false_bb);
+                debug_assert_eq!(target_blocks.len(), 2);
+                let success_block = target_block(TestBranch::Success);
+                let fail_block = target_block(TestBranch::Failure);
+                let terminator =
+                    TerminatorKind::if_(Operand::Copy(place), success_block, fail_block);
                 self.cfg.terminate(block, self.source_info(match_start_span), terminator);
             }
 
             TestKind::Eq { value, ty } => {
                 let tcx = self.tcx;
-                let [success_block, fail_block] = *target_blocks else {
-                    bug!("`TestKind::Eq` should have two target blocks")
-                };
+                debug_assert_eq!(target_blocks.len(), 2);
+                let success_block = target_block(TestBranch::Success);
+                let fail_block = target_block(TestBranch::Failure);
                 if let ty::Adt(def, _) = ty.kind()
                     && Some(def.did()) == tcx.lang_items().string()
                 {
@@ -286,9 +290,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
             }
 
             TestKind::Range(ref range) => {
-                let [success, fail] = *target_blocks else {
-                    bug!("`TestKind::Range` should have two target blocks");
-                };
+                debug_assert_eq!(target_blocks.len(), 2);
+                let success = target_block(TestBranch::Success);
+                let fail = target_block(TestBranch::Failure);
                 // Test `val` by computing `lo <= val && val <= hi`, using primitive comparisons.
                 let val = Operand::Copy(place);
 
@@ -333,15 +337,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                 // expected = <N>
                 let expected = self.push_usize(block, source_info, len);
 
-                let [true_bb, false_bb] = *target_blocks else {
-                    bug!("`TestKind::Len` should have two target blocks");
-                };
+                debug_assert_eq!(target_blocks.len(), 2);
+                let success_block = target_block(TestBranch::Success);
+                let fail_block = target_block(TestBranch::Failure);
                 // result = actual == expected OR result = actual < expected
                 // branch based on result
                 self.compare(
                     block,
-                    true_bb,
-                    false_bb,
+                    success_block,
+                    fail_block,
                     source_info,
                     op,
                     Operand::Move(actual),
@@ -526,10 +530,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
 
     /// Given that we are performing `test` against `test_place`, this job
     /// sorts out what the status of `candidate` will be after the test. See
-    /// `test_candidates` for the usage of this function. The returned index is
-    /// the index that this candidate should be placed in the
-    /// `target_candidates` vec. The candidate may be modified to update its
-    /// `match_pairs`.
+    /// `test_candidates` for the usage of this function. The candidate may
+    /// be modified to update its `match_pairs`.
     ///
     /// So, for example, if this candidate is `x @ Some(P0)` and the `Test` is
     /// a variant test, then we would modify the candidate to be `(x as
@@ -556,7 +558,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         test_place: &PlaceBuilder<'tcx>,
         test: &Test<'tcx>,
         candidate: &mut Candidate<'pat, 'tcx>,
-    ) -> Option<usize> {
+    ) -> Option<TestBranch<'tcx>> {
         // Find the match_pair for this place (if any). At present,
         // afaik, there can be at most one. (In the future, if we
         // adopted a more general `@` operator, there might be more
@@ -576,7 +578,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
             ) => {
                 assert_eq!(adt_def, tested_adt_def);
                 fully_matched = true;
-                Some(variant_index.as_usize())
+                Some(TestBranch::Variant(variant_index))
             }
 
             // If we are performing a switch over integers, then this informs integer
@@ -584,12 +586,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
             //
             // FIXME(#29623) we could use PatKind::Range to rule
             // things out here, in some cases.
-            (TestKind::SwitchInt { options }, TestCase::Constant { value })
+            (TestKind::SwitchInt { options }, &TestCase::Constant { value })
                 if is_switch_ty(match_pair.pattern.ty) =>
             {
                 fully_matched = true;
-                let index = options.get_index_of(value).unwrap();
-                Some(index)
+                let bits = options.get(&value).unwrap();
+                Some(TestBranch::Constant(value, *bits))
             }
             (TestKind::SwitchInt { options }, TestCase::Range(range)) => {
                 fully_matched = false;
@@ -599,7 +601,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                 not_contained.then(|| {
                     // No switch values are contained in the pattern range,
                     // so the pattern can be matched only if this test fails.
-                    options.len()
+                    TestBranch::Failure
                 })
             }
 
@@ -608,7 +610,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                 let value = value.try_eval_bool(self.tcx, self.param_env).unwrap_or_else(|| {
                     span_bug!(test.span, "expected boolean value but got {value:?}")
                 });
-                Some(value as usize)
+                Some(if value { TestBranch::Success } else { TestBranch::Failure })
             }
 
             (
@@ -620,14 +622,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                         // on true, min_len = len = $actual_length,
                         // on false, len != $actual_length
                         fully_matched = true;
-                        Some(0)
+                        Some(TestBranch::Success)
                     }
                     (Ordering::Less, _) => {
                         // test_len < pat_len. If $actual_len = test_len,
                         // then $actual_len < pat_len and we don't have
                         // enough elements.
                         fully_matched = false;
-                        Some(1)
+                        Some(TestBranch::Failure)
                     }
                     (Ordering::Equal | Ordering::Greater, true) => {
                         // This can match both if $actual_len = test_len >= pat_len,
@@ -639,7 +641,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                         // test_len != pat_len, so if $actual_len = test_len, then
                         // $actual_len != pat_len.
                         fully_matched = false;
-                        Some(1)
+                        Some(TestBranch::Failure)
                     }
                 }
             }
@@ -653,20 +655,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                         // $actual_len >= test_len = pat_len,
                         // so we can match.
                         fully_matched = true;
-                        Some(0)
+                        Some(TestBranch::Success)
                     }
                     (Ordering::Less, _) | (Ordering::Equal, false) => {
                         // test_len <= pat_len. If $actual_len < test_len,
                         // then it is also < pat_len, so the test passing is
                         // necessary (but insufficient).
                         fully_matched = false;
-                        Some(0)
+                        Some(TestBranch::Success)
                     }
                     (Ordering::Greater, false) => {
                         // test_len > pat_len. If $actual_len >= test_len > pat_len,
                         // then we know we won't have a match.
                         fully_matched = false;
-                        Some(1)
+                        Some(TestBranch::Failure)
                     }
                     (Ordering::Greater, true) => {
                         // test_len < pat_len, and is therefore less
@@ -680,12 +682,16 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
             (TestKind::Range(test), &TestCase::Range(pat)) => {
                 if test.as_ref() == pat {
                     fully_matched = true;
-                    Some(0)
+                    Some(TestBranch::Success)
                 } else {
                     fully_matched = false;
                     // If the testing range does not overlap with pattern range,
                     // the pattern can be matched only if this test fails.
-                    if !test.overlaps(pat, self.tcx, self.param_env)? { Some(1) } else { None }
+                    if !test.overlaps(pat, self.tcx, self.param_env)? {
+                        Some(TestBranch::Failure)
+                    } else {
+                        None
+                    }
                 }
             }
             (TestKind::Range(range), &TestCase::Constant { value }) => {
@@ -693,7 +699,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                 if !range.contains(value, self.tcx, self.param_env)? {
                     // `value` is not contained in the testing range,
                     // so `value` can be matched only if this test fails.
-                    Some(1)
+                    Some(TestBranch::Failure)
                 } else {
                     None
                 }
@@ -704,7 +710,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                 if test_val == case_val =>
             {
                 fully_matched = true;
-                Some(0)
+                Some(TestBranch::Success)
             }
 
             (
@@ -747,18 +753,29 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
     }
 }
 
-impl Test<'_> {
-    pub(super) fn targets(&self) -> usize {
+impl<'tcx> Test<'tcx> {
+    pub(super) fn targets(&self) -> Vec<TestBranch<'tcx>> {
         match self.kind {
-            TestKind::Eq { .. } | TestKind::Range(_) | TestKind::Len { .. } | TestKind::If => 2,
+            TestKind::Eq { .. } | TestKind::Range(_) | TestKind::Len { .. } | TestKind::If => {
+                vec![TestBranch::Success, TestBranch::Failure]
+            }
             TestKind::Switch { adt_def, .. } => {
                 // While the switch that we generate doesn't test for all
                 // variants, we have a target for each variant and the
                 // otherwise case, and we make sure that all of the cases not
                 // specified have the same block.
-                adt_def.variants().len() + 1
+                adt_def
+                    .variants()
+                    .indices()
+                    .map(|idx| TestBranch::Variant(idx))
+                    .chain([TestBranch::Failure])
+                    .collect()
             }
-            TestKind::SwitchInt { ref options } => options.len() + 1,
+            TestKind::SwitchInt { ref options } => options
+                .iter()
+                .map(|(val, bits)| TestBranch::Constant(*val, *bits))
+                .chain([TestBranch::Failure])
+                .collect(),
         }
     }
 }
diff --git a/tests/mir-opt/building/issue_49232.main.built.after.mir b/tests/mir-opt/building/issue_49232.main.built.after.mir
index d09a1748a8b..166e28ce51d 100644
--- a/tests/mir-opt/building/issue_49232.main.built.after.mir
+++ b/tests/mir-opt/building/issue_49232.main.built.after.mir
@@ -25,7 +25,7 @@ fn main() -> () {
         StorageLive(_3);
         _3 = const true;
         PlaceMention(_3);
-        switchInt(_3) -> [0: bb4, otherwise: bb6];
+        switchInt(_3) -> [0: bb6, otherwise: bb4];
     }
 
     bb3: {
@@ -34,7 +34,8 @@ fn main() -> () {
     }
 
     bb4: {
-        falseEdge -> [real: bb8, imaginary: bb6];
+        _0 = const ();
+        goto -> bb13;
     }
 
     bb5: {
@@ -42,8 +43,7 @@ fn main() -> () {
     }
 
     bb6: {
-        _0 = const ();
-        goto -> bb13;
+        falseEdge -> [real: bb8, imaginary: bb4];
     }
 
     bb7: {
diff --git a/tests/mir-opt/match_arm_scopes.complicated_match.panic-abort.SimplifyCfg-initial.after-ElaborateDrops.after.diff b/tests/mir-opt/match_arm_scopes.complicated_match.panic-abort.SimplifyCfg-initial.after-ElaborateDrops.after.diff
index 619fda339a6..307f7105dd2 100644
--- a/tests/mir-opt/match_arm_scopes.complicated_match.panic-abort.SimplifyCfg-initial.after-ElaborateDrops.after.diff
+++ b/tests/mir-opt/match_arm_scopes.complicated_match.panic-abort.SimplifyCfg-initial.after-ElaborateDrops.after.diff
@@ -42,11 +42,15 @@
       }
   
       bb2: {
--         switchInt((_2.0: bool)) -> [0: bb3, otherwise: bb4];
+-         switchInt((_2.0: bool)) -> [0: bb4, otherwise: bb3];
 +         switchInt((_2.0: bool)) -> [0: bb3, otherwise: bb17];
       }
   
       bb3: {
+-         falseEdge -> [real: bb20, imaginary: bb4];
+-     }
+- 
+-     bb4: {
           StorageLive(_15);
           _15 = (_2.1: bool);
           StorageLive(_16);
@@ -55,12 +59,8 @@
 +         goto -> bb16;
       }
   
-      bb4: {
--         falseEdge -> [real: bb20, imaginary: bb3];
--     }
-- 
 -     bb5: {
--         falseEdge -> [real: bb13, imaginary: bb4];
+-         falseEdge -> [real: bb13, imaginary: bb3];
 -     }
 - 
 -     bb6: {
@@ -68,6 +68,7 @@
 -     }
 - 
 -     bb7: {
++     bb4: {
           _0 = const 1_i32;
 -         drop(_7) -> [return: bb18, unwind: bb25];
 +         drop(_7) -> [return: bb15, unwind: bb22];
@@ -183,7 +184,7 @@
           StorageDead(_12);
           StorageDead(_8);
           StorageDead(_6);
--         falseEdge -> [real: bb2, imaginary: bb4];
+-         falseEdge -> [real: bb2, imaginary: bb3];
 +         goto -> bb2;
       }
   
diff --git a/tests/mir-opt/match_arm_scopes.complicated_match.panic-unwind.SimplifyCfg-initial.after-ElaborateDrops.after.diff b/tests/mir-opt/match_arm_scopes.complicated_match.panic-unwind.SimplifyCfg-initial.after-ElaborateDrops.after.diff
index 619fda339a6..307f7105dd2 100644
--- a/tests/mir-opt/match_arm_scopes.complicated_match.panic-unwind.SimplifyCfg-initial.after-ElaborateDrops.after.diff
+++ b/tests/mir-opt/match_arm_scopes.complicated_match.panic-unwind.SimplifyCfg-initial.after-ElaborateDrops.after.diff
@@ -42,11 +42,15 @@
       }
   
       bb2: {
--         switchInt((_2.0: bool)) -> [0: bb3, otherwise: bb4];
+-         switchInt((_2.0: bool)) -> [0: bb4, otherwise: bb3];
 +         switchInt((_2.0: bool)) -> [0: bb3, otherwise: bb17];
       }
   
       bb3: {
+-         falseEdge -> [real: bb20, imaginary: bb4];
+-     }
+- 
+-     bb4: {
           StorageLive(_15);
           _15 = (_2.1: bool);
           StorageLive(_16);
@@ -55,12 +59,8 @@
 +         goto -> bb16;
       }
   
-      bb4: {
--         falseEdge -> [real: bb20, imaginary: bb3];
--     }
-- 
 -     bb5: {
--         falseEdge -> [real: bb13, imaginary: bb4];
+-         falseEdge -> [real: bb13, imaginary: bb3];
 -     }
 - 
 -     bb6: {
@@ -68,6 +68,7 @@
 -     }
 - 
 -     bb7: {
++     bb4: {
           _0 = const 1_i32;
 -         drop(_7) -> [return: bb18, unwind: bb25];
 +         drop(_7) -> [return: bb15, unwind: bb22];
@@ -183,7 +184,7 @@
           StorageDead(_12);
           StorageDead(_8);
           StorageDead(_6);
--         falseEdge -> [real: bb2, imaginary: bb4];
+-         falseEdge -> [real: bb2, imaginary: bb3];
 +         goto -> bb2;
       }