about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_mir_build/src/thir/pattern/_match.rs179
-rw-r--r--compiler/rustc_mir_build/src/thir/pattern/mod.rs8
2 files changed, 165 insertions, 22 deletions
diff --git a/compiler/rustc_mir_build/src/thir/pattern/_match.rs b/compiler/rustc_mir_build/src/thir/pattern/_match.rs
index eddd2882406..904524e13ae 100644
--- a/compiler/rustc_mir_build/src/thir/pattern/_match.rs
+++ b/compiler/rustc_mir_build/src/thir/pattern/_match.rs
@@ -139,10 +139,10 @@
 //!
 //!    It is computed as follows. We look at the pattern `p_1` on top of the stack,
 //!    and we have three cases:
-//!         1.1. `p_1 = c(r_1, .., r_a)`. We discard the current stack and return nothing.
-//!         1.2. `p_1 = _`. We return the rest of the stack:
+//!         2.1. `p_1 = c(r_1, .., r_a)`. We discard the current stack and return nothing.
+//!         2.2. `p_1 = _`. We return the rest of the stack:
 //!                 p_2, .., p_n
-//!         1.3. `p_1 = r_1 | r_2`. We expand the OR-pattern and then recurse on each resulting
+//!         2.3. `p_1 = r_1 | r_2`. We expand the OR-pattern and then recurse on each resulting
 //!           stack.
 //!                 D((r_1, p_2, .., p_n))
 //!                 D((r_2, p_2, .., p_n))
@@ -276,7 +276,7 @@ use self::Usefulness::*;
 use self::WitnessPreference::*;
 
 use rustc_data_structures::captures::Captures;
-use rustc_data_structures::fx::FxHashSet;
+use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_index::vec::Idx;
 
 use super::{compare_const_vals, PatternFoldable, PatternFolder};
@@ -416,7 +416,7 @@ impl<'tcx> Pat<'tcx> {
 
 /// A row of a matrix. Rows of len 1 are very common, which is why `SmallVec[_; 2]`
 /// works well.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq)]
 crate struct PatStack<'p, 'tcx>(SmallVec<[&'p Pat<'tcx>; 2]>);
 
 impl<'p, 'tcx> PatStack<'p, 'tcx> {
@@ -504,13 +504,36 @@ impl<'p, 'tcx> FromIterator<&'p Pat<'tcx>> for PatStack<'p, 'tcx> {
     }
 }
 
+/// Depending on the match patterns, the specialization process might be able to use a fast path.
+/// Tracks whether we can use the fast path and the lookup table needed in those cases.
+#[derive(Clone, Debug, PartialEq)]
+enum SpecializationCache {
+    /// Patterns consist of only enum variants.
+    /// Variant patterns does not intersect with each other (in contrast to range patterns),
+    /// so it is possible to precompute the result of `Matrix::specialize_constructor` at a
+    /// lower computational complexity.
+    /// `lookup` is responsible for holding the precomputed result of
+    /// `Matrix::specialize_constructor`, while `wilds` is used for two purposes: the first one is
+    /// the precomputed result of `Matrix::specialize_wildcard`, and the second is to be used as a
+    /// fallback for `Matrix::specialize_constructor` when it tries to apply a constructor that
+    /// has not been seen in the `Matrix`. See `update_cache` for further explanations.
+    Variants { lookup: FxHashMap<DefId, SmallVec<[usize; 1]>>, wilds: SmallVec<[usize; 1]> },
+    /// Does not belong to the cases above, use the slow path.
+    Incompatible,
+}
+
 /// A 2D matrix.
-#[derive(Clone)]
-crate struct Matrix<'p, 'tcx>(Vec<PatStack<'p, 'tcx>>);
+#[derive(Clone, PartialEq)]
+crate struct Matrix<'p, 'tcx> {
+    patterns: Vec<PatStack<'p, 'tcx>>,
+    cache: SpecializationCache,
+}
 
 impl<'p, 'tcx> Matrix<'p, 'tcx> {
     crate fn empty() -> Self {
-        Matrix(vec![])
+        // Use `SpecializationCache::Incompatible` as a placeholder; we will initialize it on the
+        // first call to `push`. See the first half of `update_cache`.
+        Matrix { patterns: vec![], cache: SpecializationCache::Incompatible }
     }
 
     /// Pushes a new row to the matrix. If the row starts with an or-pattern, this expands it.
@@ -522,18 +545,101 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> {
                 self.push(row)
             }
         } else {
-            self.0.push(row);
+            self.patterns.push(row);
+            self.update_cache(self.patterns.len() - 1);
+        }
+    }
+
+    fn update_cache(&mut self, idx: usize) {
+        let row = &self.patterns[idx];
+        // We don't know which kind of cache could be used until we see the first row; therefore an
+        // empty `Matrix` is initialized with `SpecializationCache::Empty`, then the cache is
+        // assigned the appropriate variant below on the first call to `push`.
+        if self.patterns.is_empty() {
+            self.cache = if row.is_empty() {
+                SpecializationCache::Incompatible
+            } else {
+                match *row.head().kind {
+                    PatKind::Variant { .. } => SpecializationCache::Variants {
+                        lookup: FxHashMap::default(),
+                        wilds: SmallVec::new(),
+                    },
+                    // Note: If the first pattern is a wildcard, then all patterns after that is not
+                    // useful. The check is simple enough so we treat it as the same as unsupported
+                    // patterns.
+                    _ => SpecializationCache::Incompatible,
+                }
+            };
+        }
+        // Update the cache.
+        match &mut self.cache {
+            SpecializationCache::Variants { ref mut lookup, ref mut wilds } => {
+                let head = row.head();
+                match *head.kind {
+                    _ if head.is_wildcard() => {
+                        // Per rule 1.3 in the top-level comments, a wildcard pattern is included in
+                        // the result of `specialize_constructor` for *any* `Constructor`.
+                        // We push the wildcard pattern to the precomputed result for constructors
+                        // that we have seen before; results for constructors we have not yet seen
+                        // defaults to `wilds`, which is updated right below.
+                        for (_, v) in lookup.iter_mut() {
+                            v.push(idx);
+                        }
+                        // Per rule 2.1 and 2.2 in the top-level comments, only wildcard patterns
+                        // are included in the result of `specialize_wildcard`.
+                        // What we do here is to track the wildcards we have seen; so in addition to
+                        // acting as the precomputed result of `specialize_wildcard`, `wilds` also
+                        // serves as the default value of `specialize_constructor` for constructors
+                        // that are not in `lookup`.
+                        wilds.push(idx);
+                    }
+                    PatKind::Variant { adt_def, variant_index, .. } => {
+                        // Handle the cases of rule 1.1 and 1.2 in the top-level comments.
+                        // A variant pattern can only be included in the results of
+                        // `specialize_constructor` for a particular constructor, therefore we are
+                        // using a HashMap to track that.
+                        lookup
+                            .entry(adt_def.variants[variant_index].def_id)
+                            // Default to `wilds` for absent keys. See above for an explanation.
+                            .or_insert_with(|| wilds.clone())
+                            .push(idx);
+                    }
+                    _ => {
+                        self.cache = SpecializationCache::Incompatible;
+                    }
+                }
+            }
+            SpecializationCache::Incompatible => {}
         }
     }
 
     /// Iterate over the first component of each row
     fn heads<'a>(&'a self) -> impl Iterator<Item = &'a Pat<'tcx>> + Captures<'p> {
-        self.0.iter().map(|r| r.head())
+        self.patterns.iter().map(|r| r.head())
     }
 
     /// This computes `D(self)`. See top of the file for explanations.
     fn specialize_wildcard(&self) -> Self {
-        self.0.iter().filter_map(|r| r.specialize_wildcard()).collect()
+        match &self.cache {
+            SpecializationCache::Variants { wilds, .. } => {
+                let result =
+                    wilds.iter().filter_map(|&i| self.patterns[i].specialize_wildcard()).collect();
+                // When debug assertions are enabled, check the results against the "slow path"
+                // result.
+                debug_assert_eq!(
+                    result,
+                    Self {
+                        patterns: self.patterns.clone(),
+                        cache: SpecializationCache::Incompatible
+                    }
+                    .specialize_wildcard()
+                );
+                result
+            }
+            SpecializationCache::Incompatible => {
+                self.patterns.iter().filter_map(|r| r.specialize_wildcard()).collect()
+            }
+        }
     }
 
     /// This computes `S(constructor, self)`. See top of the file for explanations.
@@ -543,10 +649,47 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> {
         constructor: &Constructor<'tcx>,
         ctor_wild_subpatterns: &Fields<'p, 'tcx>,
     ) -> Matrix<'p, 'tcx> {
-        self.0
-            .iter()
-            .filter_map(|r| r.specialize_constructor(cx, constructor, ctor_wild_subpatterns))
-            .collect()
+        match &self.cache {
+            SpecializationCache::Variants { lookup, wilds } => {
+                let result: Self = if let Constructor::Variant(id) = constructor {
+                    lookup
+                        .get(id)
+                        // Default to `wilds` for absent keys. See `update_cache` for an explanation.
+                        .unwrap_or(&wilds)
+                        .iter()
+                        .filter_map(|&i| {
+                            self.patterns[i].specialize_constructor(
+                                cx,
+                                constructor,
+                                ctor_wild_subpatterns,
+                            )
+                        })
+                        .collect()
+                } else {
+                    unreachable!()
+                };
+                // When debug assertions are enabled, check the results against the "slow path"
+                // result.
+                debug_assert_eq!(
+                    result,
+                    Matrix {
+                        patterns: self.patterns.clone(),
+                        cache: SpecializationCache::Incompatible
+                    }
+                    .specialize_constructor(
+                        cx,
+                        constructor,
+                        ctor_wild_subpatterns
+                    )
+                );
+                result
+            }
+            SpecializationCache::Incompatible => self
+                .patterns
+                .iter()
+                .filter_map(|r| r.specialize_constructor(cx, constructor, ctor_wild_subpatterns))
+                .collect(),
+        }
     }
 }
 
@@ -568,7 +711,7 @@ impl<'p, 'tcx> fmt::Debug for Matrix<'p, 'tcx> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "\n")?;
 
-        let &Matrix(ref m) = self;
+        let Matrix { patterns: m, .. } = self;
         let pretty_printed_matrix: Vec<Vec<String>> =
             m.iter().map(|row| row.iter().map(|pat| format!("{:?}", pat)).collect()).collect();
 
@@ -1824,7 +1967,7 @@ crate fn is_useful<'p, 'tcx>(
     is_under_guard: bool,
     is_top_level: bool,
 ) -> Usefulness<'tcx> {
-    let &Matrix(ref rows) = matrix;
+    let Matrix { patterns: rows, .. } = matrix;
     debug!("is_useful({:#?}, {:#?})", matrix, v);
 
     // The base case. We are pattern-matching on () and the return value is
@@ -2266,7 +2409,7 @@ fn split_grouped_constructors<'p, 'tcx>(
                 // `borders` is the set of borders between equivalence classes: each equivalence
                 // class lies between 2 borders.
                 let row_borders = matrix
-                    .0
+                    .patterns
                     .iter()
                     .flat_map(|row| {
                         IntRange::from_pat(tcx, param_env, row.head()).map(|r| (r, row.len()))
diff --git a/compiler/rustc_mir_build/src/thir/pattern/mod.rs b/compiler/rustc_mir_build/src/thir/pattern/mod.rs
index d617f4a6aa6..718ed78889f 100644
--- a/compiler/rustc_mir_build/src/thir/pattern/mod.rs
+++ b/compiler/rustc_mir_build/src/thir/pattern/mod.rs
@@ -39,19 +39,19 @@ crate enum PatternError {
     NonConstPath(Span),
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, PartialEq)]
 crate enum BindingMode {
     ByValue,
     ByRef(BorrowKind),
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 crate struct FieldPat<'tcx> {
     crate field: Field,
     crate pattern: Pat<'tcx>,
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 crate struct Pat<'tcx> {
     crate ty: Ty<'tcx>,
     crate span: Span,
@@ -116,7 +116,7 @@ crate struct Ascription<'tcx> {
     crate user_ty_span: Span,
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 crate enum PatKind<'tcx> {
     Wild,