about summary refs log tree commit diff
path: root/compiler/rustc_middle/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_middle/src')
-rw-r--r--compiler/rustc_middle/src/ty/codec.rs10
-rw-r--r--compiler/rustc_middle/src/ty/context.rs12
-rw-r--r--compiler/rustc_middle/src/ty/pattern.rs27
-rw-r--r--compiler/rustc_middle/src/ty/relate.rs9
-rw-r--r--compiler/rustc_middle/src/ty/structural_impls.rs1
5 files changed, 59 insertions, 0 deletions
diff --git a/compiler/rustc_middle/src/ty/codec.rs b/compiler/rustc_middle/src/ty/codec.rs
index 23927c112bc..5ff87959a80 100644
--- a/compiler/rustc_middle/src/ty/codec.rs
+++ b/compiler/rustc_middle/src/ty/codec.rs
@@ -442,6 +442,15 @@ impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List<ty::BoundVaria
     }
 }
 
+impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List<ty::Pattern<'tcx>> {
+    fn decode(decoder: &mut D) -> &'tcx Self {
+        let len = decoder.read_usize();
+        decoder.interner().mk_patterns_from_iter(
+            (0..len).map::<ty::Pattern<'tcx>, _>(|_| Decodable::decode(decoder)),
+        )
+    }
+}
+
 impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List<ty::Const<'tcx>> {
     fn decode(decoder: &mut D) -> &'tcx Self {
         let len = decoder.read_usize();
@@ -503,6 +512,7 @@ impl_decodable_via_ref! {
     &'tcx mir::Body<'tcx>,
     &'tcx mir::ConcreteOpaqueTypes<'tcx>,
     &'tcx ty::List<ty::BoundVariableKind>,
+    &'tcx ty::List<ty::Pattern<'tcx>>,
     &'tcx ty::ListWithCachedTypeInfo<ty::Clause<'tcx>>,
     &'tcx ty::List<FieldIdx>,
     &'tcx ty::List<(VariantIdx, FieldIdx)>,
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index e8dad1e056c..f97e13f13b0 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -136,6 +136,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
 
     type AllocId = crate::mir::interpret::AllocId;
     type Pat = Pattern<'tcx>;
+    type PatList = &'tcx List<Pattern<'tcx>>;
     type Safety = hir::Safety;
     type Abi = ExternAbi;
     type Const = ty::Const<'tcx>;
@@ -843,6 +844,7 @@ pub struct CtxtInterners<'tcx> {
     captures: InternedSet<'tcx, List<&'tcx ty::CapturedPlace<'tcx>>>,
     offset_of: InternedSet<'tcx, List<(VariantIdx, FieldIdx)>>,
     valtree: InternedSet<'tcx, ty::ValTreeKind<'tcx>>,
+    patterns: InternedSet<'tcx, List<ty::Pattern<'tcx>>>,
 }
 
 impl<'tcx> CtxtInterners<'tcx> {
@@ -879,6 +881,7 @@ impl<'tcx> CtxtInterners<'tcx> {
             captures: InternedSet::with_capacity(N),
             offset_of: InternedSet::with_capacity(N),
             valtree: InternedSet::with_capacity(N),
+            patterns: InternedSet::with_capacity(N),
         }
     }
 
@@ -2659,6 +2662,7 @@ slice_interners!(
     local_def_ids: intern_local_def_ids(LocalDefId),
     captures: intern_captures(&'tcx ty::CapturedPlace<'tcx>),
     offset_of: pub mk_offset_of((VariantIdx, FieldIdx)),
+    patterns: pub mk_patterns(Pattern<'tcx>),
 );
 
 impl<'tcx> TyCtxt<'tcx> {
@@ -2932,6 +2936,14 @@ impl<'tcx> TyCtxt<'tcx> {
         self.intern_local_def_ids(def_ids)
     }
 
+    pub fn mk_patterns_from_iter<I, T>(self, iter: I) -> T::Output
+    where
+        I: Iterator<Item = T>,
+        T: CollectAndApply<ty::Pattern<'tcx>, &'tcx List<ty::Pattern<'tcx>>>,
+    {
+        T::collect_and_apply(iter, |xs| self.mk_patterns(xs))
+    }
+
     pub fn mk_local_def_ids_from_iter<I, T>(self, iter: I) -> T::Output
     where
         I: Iterator<Item = T>,
diff --git a/compiler/rustc_middle/src/ty/pattern.rs b/compiler/rustc_middle/src/ty/pattern.rs
index 758adc42e3e..5af9b17dd77 100644
--- a/compiler/rustc_middle/src/ty/pattern.rs
+++ b/compiler/rustc_middle/src/ty/pattern.rs
@@ -23,6 +23,13 @@ impl<'tcx> Flags for Pattern<'tcx> {
                 FlagComputation::for_const_kind(&start.kind()).flags
                     | FlagComputation::for_const_kind(&end.kind()).flags
             }
+            ty::PatternKind::Or(pats) => {
+                let mut flags = pats[0].flags();
+                for pat in pats[1..].iter() {
+                    flags |= pat.flags();
+                }
+                flags
+            }
         }
     }
 
@@ -31,6 +38,13 @@ impl<'tcx> Flags for Pattern<'tcx> {
             ty::PatternKind::Range { start, end } => {
                 start.outer_exclusive_binder().max(end.outer_exclusive_binder())
             }
+            ty::PatternKind::Or(pats) => {
+                let mut idx = pats[0].outer_exclusive_binder();
+                for pat in pats[1..].iter() {
+                    idx = idx.max(pat.outer_exclusive_binder());
+                }
+                idx
+            }
         }
     }
 }
@@ -77,6 +91,19 @@ impl<'tcx> IrPrint<PatternKind<'tcx>> for TyCtxt<'tcx> {
 
                 write!(f, "..={end}")
             }
+            PatternKind::Or(patterns) => {
+                write!(f, "(")?;
+                let mut first = true;
+                for pat in patterns {
+                    if first {
+                        first = false
+                    } else {
+                        write!(f, " | ")?;
+                    }
+                    write!(f, "{pat:?}")?;
+                }
+                write!(f, ")")
+            }
         }
     }
 
diff --git a/compiler/rustc_middle/src/ty/relate.rs b/compiler/rustc_middle/src/ty/relate.rs
index c3ee72bcaed..6ad4e5276b2 100644
--- a/compiler/rustc_middle/src/ty/relate.rs
+++ b/compiler/rustc_middle/src/ty/relate.rs
@@ -59,6 +59,15 @@ impl<'tcx> Relate<TyCtxt<'tcx>> for ty::Pattern<'tcx> {
                 let end = relation.relate(end_a, end_b)?;
                 Ok(tcx.mk_pat(ty::PatternKind::Range { start, end }))
             }
+            (&ty::PatternKind::Or(a), &ty::PatternKind::Or(b)) => {
+                if a.len() != b.len() {
+                    return Err(TypeError::Mismatch);
+                }
+                let v = iter::zip(a, b).map(|(a, b)| relation.relate(a, b));
+                let patterns = tcx.mk_patterns_from_iter(v)?;
+                Ok(tcx.mk_pat(ty::PatternKind::Or(patterns)))
+            }
+            (ty::PatternKind::Range { .. } | ty::PatternKind::Or(_), _) => Err(TypeError::Mismatch),
         }
     }
 }
diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs
index 26861666c1d..2fcb2a1572a 100644
--- a/compiler/rustc_middle/src/ty/structural_impls.rs
+++ b/compiler/rustc_middle/src/ty/structural_impls.rs
@@ -779,5 +779,6 @@ list_fold! {
     ty::Clauses<'tcx> : mk_clauses,
     &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> : mk_poly_existential_predicates,
     &'tcx ty::List<PlaceElem<'tcx>> : mk_place_elems,
+    &'tcx ty::List<ty::Pattern<'tcx>> : mk_patterns,
     CanonicalVarInfos<'tcx> : mk_canonical_var_infos,
 }