about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_mir_build/src/lib.rs1
-rw-r--r--compiler/rustc_mir_build/src/thir/pattern/_match.rs130
2 files changed, 63 insertions, 68 deletions
diff --git a/compiler/rustc_mir_build/src/lib.rs b/compiler/rustc_mir_build/src/lib.rs
index 714041ad4e8..0866892265b 100644
--- a/compiler/rustc_mir_build/src/lib.rs
+++ b/compiler/rustc_mir_build/src/lib.rs
@@ -9,6 +9,7 @@
 #![feature(control_flow_enum)]
 #![feature(crate_visibility_modifier)]
 #![feature(bool_to_option)]
+#![feature(once_cell)]
 #![feature(or_patterns)]
 #![recursion_limit = "256"]
 
diff --git a/compiler/rustc_mir_build/src/thir/pattern/_match.rs b/compiler/rustc_mir_build/src/thir/pattern/_match.rs
index c2b0d8f52e3..30529ef5e1b 100644
--- a/compiler/rustc_mir_build/src/thir/pattern/_match.rs
+++ b/compiler/rustc_mir_build/src/thir/pattern/_match.rs
@@ -295,6 +295,7 @@ use self::WitnessPreference::*;
 
 use rustc_data_structures::captures::Captures;
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_data_structures::sync::OnceCell;
 use rustc_index::vec::Idx;
 
 use super::{compare_const_vals, PatternFoldable, PatternFolder};
@@ -346,32 +347,40 @@ impl<'tcx> Pat<'tcx> {
 
 /// A row of a matrix. Rows of len 1 are very common, which is why `SmallVec[_; 2]`
 /// works well.
-#[derive(Debug, Clone, PartialEq)]
-crate struct PatStack<'p, 'tcx>(SmallVec<[&'p Pat<'tcx>; 2]>);
+#[derive(Debug, Clone)]
+crate struct PatStack<'p, 'tcx> {
+    pats: SmallVec<[&'p Pat<'tcx>; 2]>,
+    /// Cache for the constructor of the head
+    head_ctor: OnceCell<Constructor<'tcx>>,
+}
 
 impl<'p, 'tcx> PatStack<'p, 'tcx> {
     crate fn from_pattern(pat: &'p Pat<'tcx>) -> Self {
-        PatStack(smallvec![pat])
+        Self::from_vec(smallvec![pat])
     }
 
     fn from_vec(vec: SmallVec<[&'p Pat<'tcx>; 2]>) -> Self {
-        PatStack(vec)
+        PatStack { pats: vec, head_ctor: OnceCell::new() }
     }
 
     fn is_empty(&self) -> bool {
-        self.0.is_empty()
+        self.pats.is_empty()
     }
 
     fn len(&self) -> usize {
-        self.0.len()
+        self.pats.len()
     }
 
     fn head(&self) -> &'p Pat<'tcx> {
-        self.0[0]
+        self.pats[0]
+    }
+
+    fn head_ctor<'a>(&'a self, cx: &MatchCheckCtxt<'p, 'tcx>) -> &'a Constructor<'tcx> {
+        self.head_ctor.get_or_init(|| pat_constructor(cx, self.head()))
     }
 
     fn iter(&self) -> impl Iterator<Item = &Pat<'tcx>> {
-        self.0.iter().copied()
+        self.pats.iter().copied()
     }
 
     // If the first pattern is an or-pattern, expand this pattern. Otherwise, return `None`.
@@ -383,7 +392,7 @@ impl<'p, 'tcx> PatStack<'p, 'tcx> {
                 pats.iter()
                     .map(|pat| {
                         let mut new_patstack = PatStack::from_pattern(pat);
-                        new_patstack.0.extend_from_slice(&self.0[1..]);
+                        new_patstack.pats.extend_from_slice(&self.pats[1..]);
                         new_patstack
                     })
                     .collect(),
@@ -414,16 +423,13 @@ impl<'p, 'tcx> PatStack<'p, 'tcx> {
         is_my_head_ctor: bool,
     ) -> Option<PatStack<'p, 'tcx>> {
         // We return `None` if `ctor` is not covered by `self.head()`. If `ctor` is known to be
-        // derived from `self.head()`, then we don't need to check; otherwise, we compute the
-        // constructor of `self.head()` and check for constructor inclusion.
+        // derived from `self.head()`, then we don't need to check; otherwise, we check for
+        // constructor inclusion.
         // Note that this shortcut is also necessary for correctness: a pattern should always be
         // specializable with its own constructor, even in cases where we refuse to inspect values like
         // opaque constants.
-        if !is_my_head_ctor {
-            let head_ctor = pat_constructor(cx.tcx, cx.param_env, self.head());
-            if !ctor.is_covered_by(cx, &head_ctor, self.head().ty) {
-                return None;
-            }
+        if !is_my_head_ctor && !ctor.is_covered_by(cx, self.head_ctor(cx), self.head().ty) {
+            return None;
         }
         let new_fields = ctor_wild_subpatterns.replace_with_pattern_arguments(self.head());
 
@@ -437,13 +443,19 @@ impl<'p, 'tcx> PatStack<'p, 'tcx> {
 
         // We pop the head pattern and push the new fields extracted from the arguments of
         // `self.head()`.
-        Some(new_fields.push_on_patstack(&self.0[1..]))
+        Some(new_fields.push_on_patstack(&self.pats[1..]))
     }
 }
 
 impl<'p, 'tcx> Default for PatStack<'p, 'tcx> {
     fn default() -> Self {
-        PatStack(smallvec![])
+        Self::from_vec(smallvec![])
+    }
+}
+
+impl<'p, 'tcx> PartialEq for PatStack<'p, 'tcx> {
+    fn eq(&self, other: &Self) -> bool {
+        self.pats == other.pats
     }
 }
 
@@ -452,7 +464,7 @@ impl<'p, 'tcx> FromIterator<&'p Pat<'tcx>> for PatStack<'p, 'tcx> {
     where
         T: IntoIterator<Item = &'p Pat<'tcx>>,
     {
-        PatStack(iter.into_iter().collect())
+        Self::from_vec(iter.into_iter().collect())
     }
 }
 
@@ -570,6 +582,14 @@ impl<'p, 'tcx> Matrix<'p, 'tcx> {
         self.patterns.iter().map(|r| r.head())
     }
 
+    /// Iterate over the first constructor of each row
+    fn head_ctors<'a>(
+        &'a self,
+        cx: &'a MatchCheckCtxt<'p, 'tcx>,
+    ) -> impl Iterator<Item = &'a Constructor<'tcx>> + Captures<'a> + Captures<'p> {
+        self.patterns.iter().map(move |r| r.head_ctor(cx))
+    }
+
     /// This computes `S(constructor, self)`. See top of the file for explanations.
     fn specialize_constructor(
         &self,
@@ -906,10 +926,7 @@ impl Slice {
             _ => return smallvec![Slice(self)],
         };
 
-        let head_ctors = matrix
-            .heads()
-            .map(|p| pat_constructor(cx.tcx, cx.param_env, p))
-            .filter(|c| !c.is_wildcard());
+        let head_ctors = matrix.head_ctors(cx).filter(|c| !c.is_wildcard());
 
         let mut max_prefix_len = self_prefix;
         let mut max_suffix_len = self_suffix;
@@ -1120,7 +1137,7 @@ impl<'tcx> Constructor<'tcx> {
     /// `hir_id` is `None` when we're evaluating the wildcard pattern. In that case we do not want
     /// to lint for overlapping ranges.
     fn split<'p>(
-        self,
+        &self,
         cx: &MatchCheckCtxt<'p, 'tcx>,
         pcx: PatCtxt<'tcx>,
         matrix: &Matrix<'p, 'tcx>,
@@ -1138,7 +1155,7 @@ impl<'tcx> Constructor<'tcx> {
             }
             Slice(slice @ Slice { kind: VarLen(..), .. }) => slice.split(cx, matrix),
             // Any other constructor can be used unchanged.
-            _ => smallvec![self],
+            _ => smallvec![self.clone()],
         }
     }
 
@@ -1991,25 +2008,9 @@ impl<'tcx> IntRange<'tcx> {
         }
     }
 
-    fn from_pat(
-        tcx: TyCtxt<'tcx>,
-        param_env: ty::ParamEnv<'tcx>,
-        pat: &Pat<'tcx>,
-    ) -> Option<IntRange<'tcx>> {
-        // This MUST be kept in sync with `pat_constructor`.
-        match *pat.kind {
-            PatKind::Constant { value } => Self::from_const(tcx, param_env, value, pat.span),
-            PatKind::Range(PatRange { lo, hi, end }) => {
-                let ty = lo.ty;
-                Self::from_range(
-                    tcx,
-                    lo.eval_bits(tcx, param_env, lo.ty),
-                    hi.eval_bits(tcx, param_env, hi.ty),
-                    ty,
-                    &end,
-                    pat.span,
-                )
-            }
+    fn from_ctor<'a>(ctor: &'a Constructor<'tcx>) -> Option<&'a IntRange<'tcx>> {
+        match ctor {
+            IntRange(range) => Some(range),
             _ => None,
         }
     }
@@ -2145,7 +2146,7 @@ impl<'tcx> IntRange<'tcx> {
     /// between every pair of boundary points. (This essentially sums up to performing the intuitive
     /// merging operation depicted above.)
     fn split<'p>(
-        self,
+        &self,
         cx: &MatchCheckCtxt<'p, 'tcx>,
         pcx: PatCtxt<'tcx>,
         matrix: &Matrix<'p, 'tcx>,
@@ -2176,15 +2177,13 @@ impl<'tcx> IntRange<'tcx> {
         // Collect the span and range of all the intersecting ranges to lint on likely
         // incorrect range patterns. (#63987)
         let mut overlaps = vec![];
+        let row_len = matrix.patterns.get(0).map(|r| r.len()).unwrap_or(0);
         // `borders` is the set of borders between equivalence classes: each equivalence
         // class lies between 2 borders.
         let row_borders = matrix
-            .patterns
-            .iter()
-            .flat_map(|row| {
-                IntRange::from_pat(cx.tcx, cx.param_env, row.head()).map(|r| (r, row.len()))
-            })
-            .flat_map(|(range, row_len)| {
+            .head_ctors(cx)
+            .filter_map(|ctor| IntRange::from_ctor(ctor))
+            .filter_map(|range| {
                 let intersection = self.intersection(cx.tcx, &range);
                 let should_lint = self.suspicious_intersection(&range);
                 if let (Some(range), 1, true) = (&intersection, row_len, should_lint) {
@@ -2229,7 +2228,7 @@ impl<'tcx> IntRange<'tcx> {
     }
 
     fn lint_overlapping_patterns(
-        self,
+        &self,
         tcx: TyCtxt<'tcx>,
         hir_id: Option<HirId>,
         ty: Ty<'tcx>,
@@ -2412,7 +2411,7 @@ crate fn is_useful<'p, 'tcx>(
 
     debug!("is_useful_expand_first_col: pcx={:#?}, expanding {:#?}", pcx, v.head());
 
-    let constructor = pat_constructor(cx.tcx, cx.param_env, v.head());
+    let constructor = v.head_ctor(cx);
     let ret = if !constructor.is_wildcard() {
         debug!("is_useful - expanding constructor: {:#?}", constructor);
         constructor
@@ -2435,11 +2434,8 @@ crate fn is_useful<'p, 'tcx>(
     } else {
         debug!("is_useful - expanding wildcard");
 
-        let used_ctors: Vec<Constructor<'_>> = matrix
-            .heads()
-            .map(|p| pat_constructor(cx.tcx, cx.param_env, p))
-            .filter(|c| !c.is_wildcard())
-            .collect();
+        let used_ctors: Vec<Constructor<'_>> =
+            matrix.head_ctors(cx).cloned().filter(|c| !c.is_wildcard()).collect();
         debug!("is_useful_used_ctors = {:#?}", used_ctors);
         // `all_ctors` are all the constructors for the given type, which
         // should all be represented (or caught with the wild pattern `_`).
@@ -2563,12 +2559,10 @@ fn is_useful_specialized<'p, 'tcx>(
 
 /// Determines the constructor that the given pattern can be specialized to.
 /// Returns `None` in case of a catch-all, which can't be specialized.
-fn pat_constructor<'tcx>(
-    tcx: TyCtxt<'tcx>,
-    param_env: ty::ParamEnv<'tcx>,
-    pat: &Pat<'tcx>,
+fn pat_constructor<'p, 'tcx>(
+    cx: &MatchCheckCtxt<'p, 'tcx>,
+    pat: &'p Pat<'tcx>,
 ) -> Constructor<'tcx> {
-    // This MUST be kept in sync with `IntRange::from_pat`.
     match *pat.kind {
         PatKind::AscribeUserType { .. } => bug!(), // Handled by `expand_pattern`
         PatKind::Binding { .. } | PatKind::Wild => Wildcard,
@@ -2577,7 +2571,7 @@ fn pat_constructor<'tcx>(
             Variant(adt_def.variants[variant_index].def_id)
         }
         PatKind::Constant { value } => {
-            if let Some(int_range) = IntRange::from_const(tcx, param_env, value, pat.span) {
+            if let Some(int_range) = IntRange::from_const(cx.tcx, cx.param_env, value, pat.span) {
                 IntRange(int_range)
             } else {
                 match value.ty.kind() {
@@ -2593,9 +2587,9 @@ fn pat_constructor<'tcx>(
         PatKind::Range(PatRange { lo, hi, end }) => {
             let ty = lo.ty;
             if let Some(int_range) = IntRange::from_range(
-                tcx,
-                lo.eval_bits(tcx, param_env, lo.ty),
-                hi.eval_bits(tcx, param_env, hi.ty),
+                cx.tcx,
+                lo.eval_bits(cx.tcx, cx.param_env, lo.ty),
+                hi.eval_bits(cx.tcx, cx.param_env, hi.ty),
                 ty,
                 &end,
                 pat.span,
@@ -2608,7 +2602,7 @@ fn pat_constructor<'tcx>(
         PatKind::Array { ref prefix, ref slice, ref suffix }
         | PatKind::Slice { ref prefix, ref slice, ref suffix } => {
             let array_len = match pat.ty.kind() {
-                ty::Array(_, length) => Some(length.eval_usize(tcx, param_env)),
+                ty::Array(_, length) => Some(length.eval_usize(cx.tcx, cx.param_env)),
                 ty::Slice(_) => None,
                 _ => span_bug!(pat.span, "bad ty {:?} for slice pattern", pat.ty),
             };