about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRyo Yoshida <low.ryoshida@gmail.com>2022-05-31 02:34:09 +0900
committerRyo Yoshida <low.ryoshida@gmail.com>2022-06-01 01:22:11 +0900
commit62d6b5a594c4b8fa53eec25e4b2d2fd2580a31f3 (patch)
treea478ff0ca4cf0b8b021ded62f26d0c5dbc26dd8a
parentc1c867506b69410f6c92ddbe2d7c6b81ab9974ab (diff)
downloadrust-62d6b5a594c4b8fa53eec25e4b2d2fd2580a31f3.tar.gz
rust-62d6b5a594c4b8fa53eec25e4b2d2fd2580a31f3.zip
Generalize some inference functions for patterns
-rw-r--r--crates/hir-ty/src/infer.rs25
-rw-r--r--crates/hir-ty/src/infer/pat.rs146
2 files changed, 107 insertions, 64 deletions
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs
index a19ff8bf601..4b80f06a3ae 100644
--- a/crates/hir-ty/src/infer.rs
+++ b/crates/hir-ty/src/infer.rs
@@ -125,6 +125,31 @@ impl Default for BindingMode {
     }
 }
 
+/// Used to generalize patterns and assignee expressions.
+trait PatLike: Into<ExprOrPatId> + Copy {
+    type BindingMode: Copy;
+
+    fn infer(
+        this: &mut InferenceContext,
+        id: Self,
+        expected_ty: &Ty,
+        default_bm: Self::BindingMode,
+    ) -> Ty;
+}
+
+impl PatLike for PatId {
+    type BindingMode = BindingMode;
+
+    fn infer(
+        this: &mut InferenceContext,
+        id: Self,
+        expected_ty: &Ty,
+        default_bm: Self::BindingMode,
+    ) -> Ty {
+        this.infer_pat(id, expected_ty, default_bm)
+    }
+}
+
 #[derive(Debug)]
 pub(crate) struct InferOk<T> {
     value: T,
diff --git a/crates/hir-ty/src/infer/pat.rs b/crates/hir-ty/src/infer/pat.rs
index c06d262f5e3..dc86f696d4f 100644
--- a/crates/hir-ty/src/infer/pat.rs
+++ b/crates/hir-ty/src/infer/pat.rs
@@ -4,7 +4,7 @@ use std::iter::repeat_with;
 
 use chalk_ir::Mutability;
 use hir_def::{
-    expr::{BindingAnnotation, Expr, Literal, Pat, PatId, RecordFieldPat},
+    expr::{BindingAnnotation, Expr, Literal, Pat, PatId},
     path::Path,
     type_ref::ConstScalar,
 };
@@ -17,15 +17,20 @@ use crate::{
     TyKind,
 };
 
+use super::PatLike;
+
 impl<'a> InferenceContext<'a> {
-    fn infer_tuple_struct_pat(
+    /// Infers type for tuple struct pattern or its corresponding assignee expression.
+    ///
+    /// Ellipses found in the original pattern or expression must be filtered out.
+    pub(super) fn infer_tuple_struct_pat_like<T: PatLike>(
         &mut self,
         path: Option<&Path>,
-        subpats: &[PatId],
         expected: &Ty,
-        default_bm: BindingMode,
-        id: PatId,
+        default_bm: T::BindingMode,
+        id: T,
         ellipsis: Option<usize>,
+        subs: &[T],
     ) -> Ty {
         let (ty, def) = self.resolve_variant(path, true);
         let var_data = def.map(|it| it.variant_data(self.db.upcast()));
@@ -39,8 +44,8 @@ impl<'a> InferenceContext<'a> {
 
         let field_tys = def.map(|it| self.db.field_types(it)).unwrap_or_default();
         let (pre, post) = match ellipsis {
-            Some(idx) => subpats.split_at(idx),
-            None => (subpats, &[][..]),
+            Some(idx) => subs.split_at(idx),
+            None => (subs, &[][..]),
         };
         let post_idx_offset = field_tys.iter().count().saturating_sub(post.len());
 
@@ -54,22 +59,22 @@ impl<'a> InferenceContext<'a> {
                     field_tys[field].clone().substitute(Interner, &substs)
                 });
             let expected_ty = self.normalize_associated_types_in(expected_ty);
-            self.infer_pat(subpat, &expected_ty, default_bm);
+            T::infer(self, subpat, &expected_ty, default_bm);
         }
 
         ty
     }
 
-    fn infer_record_pat(
+    /// Infers type for record pattern or its corresponding assignee expression.
+    pub(super) fn infer_record_pat_like<T: PatLike>(
         &mut self,
         path: Option<&Path>,
-        subpats: &[RecordFieldPat],
         expected: &Ty,
-        default_bm: BindingMode,
-        id: PatId,
+        default_bm: T::BindingMode,
+        id: T,
+        subs: impl Iterator<Item = (Name, T)>,
     ) -> Ty {
         let (ty, def) = self.resolve_variant(path, false);
-        let var_data = def.map(|it| it.variant_data(self.db.upcast()));
         if let Some(variant) = def {
             self.write_variant_resolution(id.into(), variant);
         }
@@ -80,18 +85,64 @@ impl<'a> InferenceContext<'a> {
             ty.as_adt().map(|(_, s)| s.clone()).unwrap_or_else(|| Substitution::empty(Interner));
 
         let field_tys = def.map(|it| self.db.field_types(it)).unwrap_or_default();
-        for subpat in subpats {
-            let matching_field = var_data.as_ref().and_then(|it| it.field(&subpat.name));
-            let expected_ty = matching_field.map_or(self.err_ty(), |field| {
-                field_tys[field].clone().substitute(Interner, &substs)
-            });
+        let var_data = def.map(|it| it.variant_data(self.db.upcast()));
+
+        for (name, inner) in subs {
+            let expected_ty = var_data
+                .as_ref()
+                .and_then(|it| it.field(&name))
+                .map_or(self.err_ty(), |f| field_tys[f].clone().substitute(Interner, &substs));
             let expected_ty = self.normalize_associated_types_in(expected_ty);
-            self.infer_pat(subpat.pat, &expected_ty, default_bm);
+
+            T::infer(self, inner, &expected_ty, default_bm);
         }
 
         ty
     }
 
+    /// Infers type for tuple pattern or its corresponding assignee expression.
+    ///
+    /// Ellipses found in the original pattern or expression must be filtered out.
+    pub(super) fn infer_tuple_pat_like<T: PatLike>(
+        &mut self,
+        expected: &Ty,
+        default_bm: T::BindingMode,
+        ellipsis: Option<usize>,
+        subs: &[T],
+    ) -> Ty {
+        let expectations = match expected.as_tuple() {
+            Some(parameters) => &*parameters.as_slice(Interner),
+            _ => &[],
+        };
+
+        let ((pre, post), n_uncovered_patterns) = match ellipsis {
+            Some(idx) => (subs.split_at(idx), expectations.len().saturating_sub(subs.len())),
+            None => ((&subs[..], &[][..]), 0),
+        };
+        let mut expectations_iter = expectations
+            .iter()
+            .cloned()
+            .map(|a| a.assert_ty_ref(Interner).clone())
+            .chain(repeat_with(|| self.table.new_type_var()));
+
+        let mut inner_tys = Vec::with_capacity(n_uncovered_patterns + subs.len());
+
+        inner_tys.extend(expectations_iter.by_ref().take(n_uncovered_patterns + subs.len()));
+
+        // Process pre
+        for (ty, pat) in inner_tys.iter_mut().zip(pre) {
+            *ty = T::infer(self, *pat, ty, default_bm);
+        }
+
+        // Process post
+        for (ty, pat) in inner_tys.iter_mut().skip(pre.len() + n_uncovered_patterns).zip(post) {
+            *ty = T::infer(self, *pat, ty, default_bm);
+        }
+
+        TyKind::Tuple(inner_tys.len(), Substitution::from_iter(Interner, inner_tys))
+            .intern(Interner)
+    }
+
     pub(super) fn infer_pat(
         &mut self,
         pat: PatId,
@@ -129,42 +180,7 @@ impl<'a> InferenceContext<'a> {
 
         let ty = match &self.body[pat] {
             Pat::Tuple { args, ellipsis } => {
-                let expectations = match expected.as_tuple() {
-                    Some(parameters) => &*parameters.as_slice(Interner),
-                    _ => &[],
-                };
-
-                let ((pre, post), n_uncovered_patterns) = match ellipsis {
-                    Some(idx) => {
-                        (args.split_at(*idx), expectations.len().saturating_sub(args.len()))
-                    }
-                    None => ((&args[..], &[][..]), 0),
-                };
-                let mut expectations_iter = expectations
-                    .iter()
-                    .cloned()
-                    .map(|a| a.assert_ty_ref(Interner).clone())
-                    .chain(repeat_with(|| self.table.new_type_var()));
-
-                let mut inner_tys = Vec::with_capacity(n_uncovered_patterns + args.len());
-
-                inner_tys
-                    .extend(expectations_iter.by_ref().take(n_uncovered_patterns + args.len()));
-
-                // Process pre
-                for (ty, pat) in inner_tys.iter_mut().zip(pre) {
-                    *ty = self.infer_pat(*pat, ty, default_bm);
-                }
-
-                // Process post
-                for (ty, pat) in
-                    inner_tys.iter_mut().skip(pre.len() + n_uncovered_patterns).zip(post)
-                {
-                    *ty = self.infer_pat(*pat, ty, default_bm);
-                }
-
-                TyKind::Tuple(inner_tys.len(), Substitution::from_iter(Interner, inner_tys))
-                    .intern(Interner)
+                self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args)
             }
             Pat::Or(pats) => {
                 if let Some((first_pat, rest)) = pats.split_first() {
@@ -191,16 +207,18 @@ impl<'a> InferenceContext<'a> {
                 let subty = self.infer_pat(*pat, &expectation, default_bm);
                 TyKind::Ref(mutability, static_lifetime(), subty).intern(Interner)
             }
-            Pat::TupleStruct { path: p, args: subpats, ellipsis } => self.infer_tuple_struct_pat(
-                p.as_deref(),
-                subpats,
-                &expected,
-                default_bm,
-                pat,
-                *ellipsis,
-            ),
+            Pat::TupleStruct { path: p, args: subpats, ellipsis } => self
+                .infer_tuple_struct_pat_like(
+                    p.as_deref(),
+                    &expected,
+                    default_bm,
+                    pat,
+                    *ellipsis,
+                    subpats,
+                ),
             Pat::Record { path: p, args: fields, ellipsis: _ } => {
-                self.infer_record_pat(p.as_deref(), fields, &expected, default_bm, pat)
+                let subs = fields.iter().map(|f| (f.name.clone(), f.pat));
+                self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat.into(), subs)
             }
             Pat::Path(path) => {
                 // FIXME use correct resolver for the surrounding expression