about summary refs log tree commit diff
diff options
context:
space:
mode:
authorNadrieril <nadrieril+git@gmail.com>2020-10-25 22:51:50 +0000
committerNadrieril <nadrieril+git@gmail.com>2020-10-27 00:46:32 +0000
commitc511955a9ff89089bff313cd8a87a6e62e2783f5 (patch)
treeccf736918abcb6547af4b5dcbcd875dd1afc01a3
parent6ad9f44a50fdfdb2388850fbd9b07db66f452ec7 (diff)
downloadrust-c511955a9ff89089bff313cd8a87a6e62e2783f5.tar.gz
rust-c511955a9ff89089bff313cd8a87a6e62e2783f5.zip
Factor out the two specialization steps
-rw-r--r--compiler/rustc_mir_build/src/thir/pattern/_match.rs187
1 files changed, 107 insertions, 80 deletions
diff --git a/compiler/rustc_mir_build/src/thir/pattern/_match.rs b/compiler/rustc_mir_build/src/thir/pattern/_match.rs
index aee1320ec89..7dbde71b2a6 100644
--- a/compiler/rustc_mir_build/src/thir/pattern/_match.rs
+++ b/compiler/rustc_mir_build/src/thir/pattern/_match.rs
@@ -1132,6 +1132,66 @@ impl<'tcx> Constructor<'tcx> {
         }
     }
 
+    /// Returns whether `self` is covered by `other`, ie whether `self` is a subset of `other`. For
+    /// the simple cases, this is simply checking for equality. For the "grouped" constructors,
+    /// this checks for inclusion.
+    fn is_covered_by<'p>(
+        &self,
+        cx: &MatchCheckCtxt<'p, 'tcx>,
+        other: &Constructor<'tcx>,
+        ty: Ty<'tcx>,
+    ) -> bool {
+        match (self, other) {
+            (Single, Single) => true,
+            (Variant(self_id), Variant(other_id)) => self_id == other_id,
+
+            (IntRange(self_range), IntRange(other_range)) => {
+                if self_range.intersection(cx.tcx, other_range).is_some() {
+                    // Constructor splitting should ensure that all intersections we encounter
+                    // are actually inclusions.
+                    assert!(self_range.is_subrange(other_range));
+                    true
+                } else {
+                    false
+                }
+            }
+            (
+                FloatRange(self_from, self_to, self_end),
+                FloatRange(other_from, other_to, other_end),
+            ) => {
+                match (
+                    compare_const_vals(cx.tcx, self_to, other_to, cx.param_env, ty),
+                    compare_const_vals(cx.tcx, self_from, other_from, cx.param_env, ty),
+                ) {
+                    (Some(to), Some(from)) => {
+                        (from == Ordering::Greater || from == Ordering::Equal)
+                            && (to == Ordering::Less
+                                || (other_end == self_end && to == Ordering::Equal))
+                    }
+                    _ => false,
+                }
+            }
+            (Str(self_val), Str(other_val)) => {
+                // FIXME: there's probably a more direct way of comparing for equality
+                match compare_const_vals(cx.tcx, self_val, other_val, cx.param_env, ty) {
+                    Some(comparison) => comparison == Ordering::Equal,
+                    None => false,
+                }
+            }
+
+            (Slice(self_slice), Slice(other_slice)) => {
+                other_slice.pattern_kind().covers_length(self_slice.arity())
+            }
+
+            // We are trying to inspect an opaque constant. Thus we skip the row.
+            (Opaque, _) | (_, Opaque) => false,
+            // Only a wildcard pattern can match the special extra constructor.
+            (NonExhaustive, _) => false,
+
+            _ => bug!("trying to compare incompatible constructors {:?} and {:?}", self, other),
+        }
+    }
+
     /// Apply a constructor to a list of patterns, yielding a new pattern. `pats`
     /// must have as many elements as this constructor's arity.
     ///
@@ -1461,6 +1521,41 @@ impl<'p, 'tcx> Fields<'p, 'tcx> {
         }
     }
 
+    /// Replaces contained fields with the arguments of the given pattern. Only use on a pattern
+    /// that is compatible with the constructor used to build `self`.
+    /// This is meant to be used on the result of `Fields::wildcards()`. The idea is that
+    /// `wildcards` constructs a list of fields where all entries are wildcards, and the pattern
+    /// provided to this function fills some of the fields with non-wildcards.
+    /// In the following example `Fields::wildcards` would return `[_, _, _, _]`. If we call
+    /// `replace_with_pattern_arguments` on it with the pattern, the result will be `[Some(0), _,
+    /// _, _]`.
+    /// ```rust
+    /// let x: [Option<u8>; 4] = foo();
+    /// match x {
+    ///     [Some(0), ..] => {}
+    /// }
+    /// ```
+    fn replace_with_pattern_arguments(&self, pat: &'p Pat<'tcx>) -> Self {
+        match pat.kind.as_ref() {
+            PatKind::Deref { subpattern } => Self::from_single_pattern(subpattern),
+            PatKind::Leaf { subpatterns } | PatKind::Variant { subpatterns, .. } => {
+                self.replace_with_fieldpats(subpatterns)
+            }
+            PatKind::Array { prefix, suffix, .. } | PatKind::Slice { prefix, suffix, .. } => {
+                // Number of subpatterns for the constructor
+                let ctor_arity = self.len();
+
+                // Replace the prefix and the suffix with the given patterns, leaving wildcards in
+                // the middle if there was a subslice pattern `..`.
+                let prefix = prefix.iter().enumerate();
+                let suffix =
+                    suffix.iter().enumerate().map(|(i, p)| (ctor_arity - suffix.len() + i, p));
+                self.replace_fields_indexed(prefix.chain(suffix))
+            }
+            _ => self.clone(),
+        }
+    }
+
     fn push_on_patstack(self, stack: &[&'p Pat<'tcx>]) -> PatStack<'p, 'tcx> {
         let pats: SmallVec<_> = match self {
             Fields::Slice(pats) => pats.iter().chain(stack.iter().copied()).collect(),
@@ -2535,89 +2630,21 @@ fn specialize_one_pattern<'p, 'tcx>(
         return Some(ctor_wild_subpatterns.clone());
     }
 
-    let ty = pat.ty;
-    // `unwrap` is safe because `pat` is not a wildcard.
-    let pat_ctor = pat_constructor(cx.tcx, cx.param_env, pat).unwrap();
-
-    let ctor_covered_by_pat = match (ctor, &pat_ctor) {
-        (Single, Single) => true,
-        (Variant(ctor_id), Variant(pat_id)) => ctor_id == pat_id,
-
-        (IntRange(ctor_range), IntRange(pat_range)) => {
-            if ctor_range.intersection(cx.tcx, pat_range).is_some() {
-                // Constructor splitting should ensure that all intersections we encounter
-                // are actually inclusions.
-                assert!(ctor_range.is_subrange(pat_range));
-                true
-            } else {
-                false
-            }
+    // We return `None` if `ctor` is not covered by `pat`. If `ctor` is known to be derived from
+    // `pat` then we don't need to check; otherwise, we compute the constructor of `pat` and check
+    // for constructor inclusion.
+    // Note that this shortcut is also necessary for correctness: a pattern should always be
+    // specializable with its own constructor, even in cases where we refuse to inspect values like
+    // opaque constants.
+    if !is_its_own_ctor {
+        // `unwrap` is safe because `pat` is not a wildcard.
+        let pat_ctor = pat_constructor(cx.tcx, cx.param_env, pat).unwrap();
+        if !ctor.is_covered_by(cx, &pat_ctor, pat.ty) {
+            return None;
         }
-        (FloatRange(ctor_from, ctor_to, ctor_end), FloatRange(pat_from, pat_to, pat_end)) => {
-            let to = compare_const_vals(cx.tcx, ctor_to, pat_to, cx.param_env, ty)?;
-            let from = compare_const_vals(cx.tcx, ctor_from, pat_from, cx.param_env, ty)?;
-            (from == Ordering::Greater || from == Ordering::Equal)
-                && (to == Ordering::Less || (pat_end == ctor_end && to == Ordering::Equal))
-        }
-        (Str(ctor_val), Str(pat_val)) => {
-            // FIXME: there's probably a more direct way of comparing for equality
-            let comparison = compare_const_vals(cx.tcx, ctor_val, pat_val, cx.param_env, ty)?;
-            comparison == Ordering::Equal
-        }
-
-        (Slice(ctor_slice), Slice(pat_slice)) => {
-            pat_slice.pattern_kind().covers_length(ctor_slice.arity())
-        }
-
-        // Only a wildcard pattern can match an opaque constant, unless we're specializing the
-        // value against its own constructor. That happens when we call
-        // `v.specialize_constructor(ctor)` with `ctor` obtained from `pat_constructor(v.head())`.
-        // For example, in the following match, when we are dealing with the third branch, we will
-        // specialize with an `Opaque` ctor. We want to ignore the second branch because opaque
-        // constants should not be inspected, but we don't want to ignore the current (third)
-        // branch, as that would cause us to always conclude that such a branch is unreachable.
-        // ```rust
-        // #[derive(PartialEq)]
-        // struct Foo(i32);
-        // impl Eq for Foo {}
-        // const FOO: Foo = Foo(42);
-        //
-        // match (Foo(0), true) {
-        //     (_, true) => {}
-        //     (FOO, true) => {}
-        //     (FOO, false) => {}
-        // }
-        // ```
-        (Opaque, Opaque) if is_its_own_ctor => true,
-        // We are trying to inspect an opaque constant. Thus we skip the row.
-        (Opaque, _) | (_, Opaque) => false,
-        // Only a wildcard pattern can match the special extra constructor.
-        (NonExhaustive, _) => false,
-
-        _ => bug!("trying to specialize pattern {:?} with constructor {:?}", pat, ctor),
-    };
-
-    if !ctor_covered_by_pat {
-        return None;
     }
 
-    let fields = match pat.kind.as_ref() {
-        PatKind::Deref { subpattern } => Fields::from_single_pattern(subpattern),
-        PatKind::Leaf { subpatterns } | PatKind::Variant { subpatterns, .. } => {
-            ctor_wild_subpatterns.replace_with_fieldpats(subpatterns)
-        }
-        PatKind::Array { prefix, suffix, .. } | PatKind::Slice { prefix, suffix, .. } => {
-            // Number of subpatterns for the constructor
-            let ctor_arity = ctor_wild_subpatterns.len();
-
-            // Replace the prefix and the suffix with the given patterns, leaving wildcards in
-            // the middle if there was a subslice pattern `..`.
-            let prefix = prefix.iter().enumerate();
-            let suffix = suffix.iter().enumerate().map(|(i, p)| (ctor_arity - suffix.len() + i, p));
-            ctor_wild_subpatterns.replace_fields_indexed(prefix.chain(suffix))
-        }
-        _ => ctor_wild_subpatterns.clone(),
-    };
+    let fields = ctor_wild_subpatterns.replace_with_pattern_arguments(pat);
 
     debug!("specialize({:#?}, {:#?}, {:#?}) = {:#?}", pat, ctor, ctor_wild_subpatterns, fields);