about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNadrieril <nadrieril+git@gmail.com>2024-01-21 08:33:02 +0100
committerNadrieril <nadrieril+git@gmail.com>2024-02-19 21:28:26 +0100
commitd936ab63d44f8d0f093e3a09709881673b1a7ecd (patch)
tree5f5a0dab2cb2aabc517c7dafa1cc77fd5e30cfa9
parent308b4824aabf92e1301b67137059e35b0b388768 (diff)
downloadrust-d936ab63d44f8d0f093e3a09709881673b1a7ecd.tar.gz
rust-d936ab63d44f8d0f093e3a09709881673b1a7ecd.zip
Eagerly simplify match pairs
-rw-r--r--compiler/rustc_mir_build/src/build/matches/mod.rs35
-rw-r--r--compiler/rustc_mir_build/src/build/matches/simplify.rs20
-rw-r--r--compiler/rustc_mir_build/src/build/matches/test.rs106
-rw-r--r--compiler/rustc_mir_build/src/build/matches/util.rs2
4 files changed, 61 insertions, 102 deletions
diff --git a/compiler/rustc_mir_build/src/build/matches/mod.rs b/compiler/rustc_mir_build/src/build/matches/mod.rs
index bb307089f1c..309fd3bf09e 100644
--- a/compiler/rustc_mir_build/src/build/matches/mod.rs
+++ b/compiler/rustc_mir_build/src/build/matches/mod.rs
@@ -947,12 +947,16 @@ struct Candidate<'pat, 'tcx> {
     has_guard: bool,
 
     /// All of these must be satisfied...
+    // Invariant: all the `MatchPair`s are recursively simplified.
+    // Invariant: or-patterns must be sorted at the end.
     match_pairs: Vec<MatchPair<'pat, 'tcx>>,
 
     /// ...these bindings established...
+    // Invariant: not mutated outside `Candidate::new()`.
     bindings: Vec<Binding<'tcx>>,
 
     /// ...and these types asserted...
+    // Invariant: not mutated outside `Candidate::new()`.
     ascriptions: Vec<Ascription<'tcx>>,
 
     /// ...and if this is non-empty, one of these subcandidates also has to match...
@@ -972,9 +976,9 @@ impl<'tcx, 'pat> Candidate<'pat, 'tcx> {
         place: PlaceBuilder<'tcx>,
         pattern: &'pat Pat<'tcx>,
         has_guard: bool,
-        cx: &Builder<'_, 'tcx>,
+        cx: &mut Builder<'_, 'tcx>,
     ) -> Self {
-        Candidate {
+        let mut candidate = Candidate {
             span: pattern.span,
             has_guard,
             match_pairs: vec![MatchPair::new(place, pattern, cx)],
@@ -984,7 +988,15 @@ impl<'tcx, 'pat> Candidate<'pat, 'tcx> {
             otherwise_block: None,
             pre_binding_block: None,
             next_candidate_pre_binding_block: None,
-        }
+        };
+
+        cx.simplify_match_pairs(
+            &mut candidate.match_pairs,
+            &mut candidate.bindings,
+            &mut candidate.ascriptions,
+        );
+
+        candidate
     }
 
     /// Visit the leaf candidates (those with no subcandidates) contained in
@@ -1040,13 +1052,18 @@ struct Ascription<'tcx> {
     variance: ty::Variance,
 }
 
-#[derive(Clone, Debug)]
+#[derive(Debug)]
 pub(crate) struct MatchPair<'pat, 'tcx> {
-    // this place...
+    // This place...
     place: PlaceBuilder<'tcx>,
 
     // ... must match this pattern.
+    // Invariant: after creation and simplification in `Candidate::new()`, all match pairs must be
+    // simplified, i.e. require a test.
     pattern: &'pat Pat<'tcx>,
+
+    /// Precomputed sub-match pairs of `pattern`.
+    subpairs: Vec<Self>,
 }
 
 /// See [`Test`] for more.
@@ -1163,16 +1180,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         candidates: &mut [&mut Candidate<'pat, 'tcx>],
         fake_borrows: &mut Option<FxIndexSet<Place<'tcx>>>,
     ) {
-        // Start by simplifying candidates. Once this process is complete, all
-        // the match pairs which remain require some form of test, whether it
-        // be a switch or pattern comparison.
         let mut split_or_candidate = false;
         for candidate in &mut *candidates {
-            self.simplify_match_pairs(
-                &mut candidate.match_pairs,
-                &mut candidate.bindings,
-                &mut candidate.ascriptions,
-            );
             if let [MatchPair { pattern: Pat { kind: PatKind::Or { pats }, .. }, place, .. }] =
                 &*candidate.match_pairs
             {
diff --git a/compiler/rustc_mir_build/src/build/matches/simplify.rs b/compiler/rustc_mir_build/src/build/matches/simplify.rs
index 8ff07c590bb..441e55fda89 100644
--- a/compiler/rustc_mir_build/src/build/matches/simplify.rs
+++ b/compiler/rustc_mir_build/src/build/matches/simplify.rs
@@ -107,12 +107,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         pats.iter()
             .map(|box pat| {
                 let mut candidate = Candidate::new(place.clone(), pat, has_guard, self);
-                self.simplify_match_pairs(
-                    &mut candidate.match_pairs,
-                    &mut candidate.bindings,
-                    &mut candidate.ascriptions,
-                );
-
                 if let [MatchPair { pattern: Pat { kind: PatKind::Or { pats }, .. }, place, .. }] =
                     &*candidate.match_pairs
                 {
@@ -132,11 +126,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
     /// candidate.
     fn simplify_match_pair<'pat>(
         &mut self,
-        match_pair: MatchPair<'pat, 'tcx>,
+        mut match_pair: MatchPair<'pat, 'tcx>,
         bindings: &mut Vec<Binding<'tcx>>,
         ascriptions: &mut Vec<Ascription<'tcx>>,
         match_pairs: &mut Vec<MatchPair<'pat, 'tcx>>,
     ) -> Result<(), MatchPair<'pat, 'tcx>> {
+        assert!(match_pair.subpairs.is_empty(), "mustn't simplify a match pair twice");
         match match_pair.pattern.kind {
             PatKind::AscribeUserType {
                 ref subpattern,
@@ -249,6 +244,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                     self.prefix_slice_suffix(match_pairs, &match_pair.place, prefix, slice, suffix);
                     Ok(())
                 } else {
+                    self.prefix_slice_suffix(
+                        &mut match_pair.subpairs,
+                        &match_pair.place,
+                        prefix,
+                        slice,
+                        suffix,
+                    );
+                    self.simplify_match_pairs(&mut match_pair.subpairs, bindings, ascriptions);
                     Err(match_pair)
                 }
             }
@@ -270,6 +273,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                     match_pairs.extend(self.field_match_pairs(place_builder, subpatterns));
                     Ok(())
                 } else {
+                    let downcast_place = match_pair.place.clone().downcast(adt_def, variant_index); // `(x as Variant)`
+                    match_pair.subpairs = self.field_match_pairs(downcast_place, subpatterns);
+                    self.simplify_match_pairs(&mut match_pair.subpairs, bindings, ascriptions);
                     Err(match_pair)
                 }
             }
diff --git a/compiler/rustc_mir_build/src/build/matches/test.rs b/compiler/rustc_mir_build/src/build/matches/test.rs
index 990be30b2d6..d5ae2fcdfa0 100644
--- a/compiler/rustc_mir_build/src/build/matches/test.rs
+++ b/compiler/rustc_mir_build/src/build/matches/test.rs
@@ -589,22 +589,17 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
         // away.)
         let (match_pair_index, match_pair) =
             candidate.match_pairs.iter().enumerate().find(|&(_, mp)| mp.place == *test_place)?;
+        let mut fully_matched = false;
 
-        match (&test.kind, &match_pair.pattern.kind) {
+        let ret = match (&test.kind, &match_pair.pattern.kind) {
             // If we are performing a variant switch, then this
             // informs variant patterns, but nothing else.
             (
                 &TestKind::Switch { adt_def: tested_adt_def, .. },
-                &PatKind::Variant { adt_def, variant_index, ref subpatterns, .. },
+                &PatKind::Variant { adt_def, variant_index, .. },
             ) => {
                 assert_eq!(adt_def, tested_adt_def);
-                self.candidate_after_variant_switch(
-                    match_pair_index,
-                    adt_def,
-                    variant_index,
-                    subpatterns,
-                    candidate,
-                );
+                fully_matched = true;
                 Some(variant_index.as_usize())
             }
 
@@ -618,8 +613,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
             (TestKind::SwitchInt { switch_ty: _, options }, PatKind::Constant { value })
                 if is_switch_ty(match_pair.pattern.ty) =>
             {
+                fully_matched = true;
                 let index = options.get_index_of(value).unwrap();
-                self.candidate_without_match_pair(match_pair_index, candidate);
                 Some(index)
             }
 
@@ -645,13 +640,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                     (Ordering::Equal, &None) => {
                         // on true, min_len = len = $actual_length,
                         // on false, len != $actual_length
-                        self.candidate_after_slice_test(
-                            match_pair_index,
-                            candidate,
-                            prefix,
-                            slice,
-                            suffix,
-                        );
+                        fully_matched = true;
                         Some(0)
                     }
                     (Ordering::Less, _) => {
@@ -683,13 +672,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                     (Ordering::Equal, &Some(_)) => {
                         // $actual_len >= test_len = pat_len,
                         // so we can match.
-                        self.candidate_after_slice_test(
-                            match_pair_index,
-                            candidate,
-                            prefix,
-                            slice,
-                            suffix,
-                        );
+                        fully_matched = true;
                         Some(0)
                     }
                     (Ordering::Less, _) | (Ordering::Equal, &None) => {
@@ -713,13 +696,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
 
             (TestKind::Range(test), PatKind::Range(pat)) => {
                 if test == pat {
-                    self.candidate_without_match_pair(match_pair_index, candidate);
-                    return Some(0);
+                    fully_matched = true;
+                    Some(0)
+                } else {
+                    // If the testing range does not overlap with pattern range,
+                    // the pattern can be matched only if this test fails.
+                    if !test.overlaps(pat, self.tcx, self.param_env)? { Some(1) } else { None }
                 }
-
-                // If the testing range does not overlap with pattern range,
-                // the pattern can be matched only if this test fails.
-                if !test.overlaps(pat, self.tcx, self.param_env)? { Some(1) } else { None }
             }
 
             (TestKind::Range(range), &PatKind::Constant { value }) => {
@@ -751,64 +734,25 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
                 // FIXME(#29623) we can be more clever here
                 let pattern_test = self.test(match_pair);
                 if pattern_test.kind == test.kind {
-                    self.candidate_without_match_pair(match_pair_index, candidate);
+                    fully_matched = true;
                     Some(0)
                 } else {
                     None
                 }
             }
-        }
-    }
-
-    fn candidate_without_match_pair(
-        &mut self,
-        match_pair_index: usize,
-        candidate: &mut Candidate<'_, 'tcx>,
-    ) {
-        candidate.match_pairs.remove(match_pair_index);
-    }
+        };
 
-    fn candidate_after_slice_test<'pat>(
-        &mut self,
-        match_pair_index: usize,
-        candidate: &mut Candidate<'pat, 'tcx>,
-        prefix: &'pat [Box<Pat<'tcx>>],
-        opt_slice: &'pat Option<Box<Pat<'tcx>>>,
-        suffix: &'pat [Box<Pat<'tcx>>],
-    ) {
-        let removed_place = candidate.match_pairs.remove(match_pair_index).place;
-        self.prefix_slice_suffix(
-            &mut candidate.match_pairs,
-            &removed_place,
-            prefix,
-            opt_slice,
-            suffix,
-        );
-    }
+        if fully_matched {
+            // Replace the match pair by its sub-pairs.
+            let match_pair = candidate.match_pairs.remove(match_pair_index);
+            candidate.match_pairs.extend(match_pair.subpairs);
+            // Move or-patterns to the end.
+            candidate
+                .match_pairs
+                .sort_by_key(|pair| matches!(pair.pattern.kind, PatKind::Or { .. }));
+        }
 
-    fn candidate_after_variant_switch<'pat>(
-        &mut self,
-        match_pair_index: usize,
-        adt_def: ty::AdtDef<'tcx>,
-        variant_index: VariantIdx,
-        subpatterns: &'pat [FieldPat<'tcx>],
-        candidate: &mut Candidate<'pat, 'tcx>,
-    ) {
-        let match_pair = candidate.match_pairs.remove(match_pair_index);
-
-        // So, if we have a match-pattern like `x @ Enum::Variant(P1, P2)`,
-        // we want to create a set of derived match-patterns like
-        // `(x as Variant).0 @ P1` and `(x as Variant).1 @ P1`.
-        let downcast_place = match_pair.place.downcast(adt_def, variant_index); // `(x as Variant)`
-        let consequent_match_pairs = subpatterns.iter().map(|subpattern| {
-            // e.g., `(x as Variant).0`
-            let place = downcast_place
-                .clone_project(PlaceElem::Field(subpattern.field, subpattern.pattern.ty));
-            // e.g., `(x as Variant).0 @ P1`
-            MatchPair::new(place, &subpattern.pattern, self)
-        });
-
-        candidate.match_pairs.extend(consequent_match_pairs);
+        ret
     }
 
     fn error_simplifiable<'pat>(&mut self, match_pair: &MatchPair<'pat, 'tcx>) -> ! {
diff --git a/compiler/rustc_mir_build/src/build/matches/util.rs b/compiler/rustc_mir_build/src/build/matches/util.rs
index 5eb853989d0..a426f2593fa 100644
--- a/compiler/rustc_mir_build/src/build/matches/util.rs
+++ b/compiler/rustc_mir_build/src/build/matches/util.rs
@@ -116,6 +116,6 @@ impl<'pat, 'tcx> MatchPair<'pat, 'tcx> {
         if may_need_cast {
             place = place.project(ProjectionElem::OpaqueCast(pattern.ty));
         }
-        MatchPair { place, pattern }
+        MatchPair { place, pattern, subpairs: Vec::new() }
     }
 }