about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRyo Yoshida <low.ryoshida@gmail.com>2022-09-06 17:48:06 +0900
committerRyo Yoshida <low.ryoshida@gmail.com>2022-09-13 02:42:52 +0900
commit77c40f878de2c5bfca8aba5a09fd16d805944727 (patch)
tree63dc4b1dc31ab0855f4924464473c6384713fc1d
parentaeeb9e08b252b24c7cdba0a30e8dd153231c082d (diff)
downloadrust-77c40f878de2c5bfca8aba5a09fd16d805944727.tar.gz
rust-77c40f878de2c5bfca8aba5a09fd16d805944727.zip
Implement type inference for generator and yield expressions
-rw-r--r--crates/hir-ty/src/db.rs6
-rw-r--r--crates/hir-ty/src/infer.rs5
-rw-r--r--crates/hir-ty/src/infer/closure.rs6
-rw-r--r--crates/hir-ty/src/infer/expr.rs67
-rw-r--r--crates/hir-ty/src/mapping.rs12
5 files changed, 78 insertions, 18 deletions
diff --git a/crates/hir-ty/src/db.rs b/crates/hir-ty/src/db.rs
index b385b1cafae..c4f7685cd14 100644
--- a/crates/hir-ty/src/db.rs
+++ b/crates/hir-ty/src/db.rs
@@ -116,6 +116,8 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
     fn intern_impl_trait_id(&self, id: ImplTraitId) -> InternedOpaqueTyId;
     #[salsa::interned]
     fn intern_closure(&self, id: (DefWithBodyId, ExprId)) -> InternedClosureId;
+    #[salsa::interned]
+    fn intern_generator(&self, id: (DefWithBodyId, ExprId)) -> InternedGeneratorId;
 
     #[salsa::invoke(chalk_db::associated_ty_data_query)]
     fn associated_ty_data(&self, id: chalk_db::AssocTypeId) -> Arc<chalk_db::AssociatedTyDatum>;
@@ -218,6 +220,10 @@ impl_intern_key!(InternedOpaqueTyId);
 pub struct InternedClosureId(salsa::InternId);
 impl_intern_key!(InternedClosureId);
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct InternedGeneratorId(salsa::InternId);
+impl_intern_key!(InternedGeneratorId);
+
 /// This exists just for Chalk, because Chalk just has a single `FnDefId` where
 /// we have different IDs for struct and enum variant constructors.
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs
index 10ffde87eef..d351eb599ce 100644
--- a/crates/hir-ty/src/infer.rs
+++ b/crates/hir-ty/src/infer.rs
@@ -332,7 +332,7 @@ pub struct InferenceResult {
     /// unresolved or missing subpatterns or subpatterns of mismatched types.
     pub type_of_pat: ArenaMap<PatId, Ty>,
     type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch>,
-    /// Interned Unknown to return references to.
+    /// Interned common types to return references to.
     standard_types: InternedStandardTypes,
     /// Stores the types which were implicitly dereferenced in pattern binding modes.
     pub pat_adjustments: FxHashMap<PatId, Vec<Ty>>,
@@ -412,6 +412,8 @@ pub(crate) struct InferenceContext<'a> {
     /// closures, but currently this is the only field that will change there,
     /// so it doesn't make sense.
     return_ty: Ty,
+    /// The resume type and the yield type, respectively, of the generator being inferred.
+    resume_yield_tys: Option<(Ty, Ty)>,
     diverges: Diverges,
     breakables: Vec<BreakableContext>,
 }
@@ -476,6 +478,7 @@ impl<'a> InferenceContext<'a> {
             table: unify::InferenceTable::new(db, trait_env.clone()),
             trait_env,
             return_ty: TyKind::Error.intern(Interner), // set in collect_fn_signature
+            resume_yield_tys: None,
             db,
             owner,
             body,
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs
index 3ead929098b..094e460dbf7 100644
--- a/crates/hir-ty/src/infer/closure.rs
+++ b/crates/hir-ty/src/infer/closure.rs
@@ -12,6 +12,7 @@ use crate::{
 use super::{Expectation, InferenceContext};
 
 impl InferenceContext<'_> {
+    // This function handles both closures and generators.
     pub(super) fn deduce_closure_type_from_expectations(
         &mut self,
         closure_expr: ExprId,
@@ -27,6 +28,11 @@ impl InferenceContext<'_> {
         // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
         let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);
 
+        // Generators are not Fn* so return early.
+        if matches!(closure_ty.kind(Interner), TyKind::Generator(..)) {
+            return;
+        }
+
         // Deduction based on the expected `dyn Fn` is done separately.
         if let TyKind::Dyn(dyn_ty) = expected_ty.kind(Interner) {
             if let Some(sig) = self.deduce_sig_from_dyn_ty(dyn_ty) {
diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs
index f0382846a6a..e3d6be23e65 100644
--- a/crates/hir-ty/src/infer/expr.rs
+++ b/crates/hir-ty/src/infer/expr.rs
@@ -10,7 +10,10 @@ use chalk_ir::{
     cast::Cast, fold::Shift, DebruijnIndex, GenericArgData, Mutability, TyVariableKind,
 };
 use hir_def::{
-    expr::{ArithOp, Array, BinaryOp, CmpOp, Expr, ExprId, LabelId, Literal, Statement, UnaryOp},
+    expr::{
+        ArithOp, Array, BinaryOp, ClosureKind, CmpOp, Expr, ExprId, LabelId, Literal, Statement,
+        UnaryOp,
+    },
     generics::TypeOrConstParamData,
     path::{GenericArg, GenericArgs},
     resolver::resolver_for_expr,
@@ -216,7 +219,7 @@ impl<'a> InferenceContext<'a> {
                 self.diverges = Diverges::Maybe;
                 TyBuilder::unit()
             }
-            Expr::Closure { body, args, ret_type, arg_types, closure_kind: _ } => {
+            Expr::Closure { body, args, ret_type, arg_types, closure_kind } => {
                 assert_eq!(args.len(), arg_types.len());
 
                 let mut sig_tys = Vec::new();
@@ -244,20 +247,40 @@ impl<'a> InferenceContext<'a> {
                     ),
                 })
                 .intern(Interner);
-                let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into();
-                let closure_ty =
-                    TyKind::Closure(closure_id, Substitution::from1(Interner, sig_ty.clone()))
-                        .intern(Interner);
+
+                let (ty, resume_yield_tys) = if matches!(closure_kind, ClosureKind::Generator(_)) {
+                    // FIXME: report error when there are more than 1 parameter.
+                    let resume_ty = match sig_tys.first() {
+                        // When `sig_tys.len() == 1` the first type is the return type, not the
+                        // first parameter type.
+                        Some(ty) if sig_tys.len() > 1 => ty.clone(),
+                        _ => self.result.standard_types.unit.clone(),
+                    };
+                    let yield_ty = self.table.new_type_var();
+
+                    let subst = TyBuilder::subst_for_generator(self.db, self.owner)
+                        .push(resume_ty.clone())
+                        .push(yield_ty.clone())
+                        .push(ret_ty.clone())
+                        .build();
+
+                    let generator_id = self.db.intern_generator((self.owner, tgt_expr)).into();
+                    let generator_ty = TyKind::Generator(generator_id, subst).intern(Interner);
+
+                    (generator_ty, Some((resume_ty, yield_ty)))
+                } else {
+                    let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into();
+                    let closure_ty =
+                        TyKind::Closure(closure_id, Substitution::from1(Interner, sig_ty.clone()))
+                            .intern(Interner);
+
+                    (closure_ty, None)
+                };
 
                 // Eagerly try to relate the closure type with the expected
                 // type, otherwise we often won't have enough information to
                 // infer the body.
-                self.deduce_closure_type_from_expectations(
-                    tgt_expr,
-                    &closure_ty,
-                    &sig_ty,
-                    expected,
-                );
+                self.deduce_closure_type_from_expectations(tgt_expr, &ty, &sig_ty, expected);
 
                 // Now go through the argument patterns
                 for (arg_pat, arg_ty) in args.iter().zip(sig_tys) {
@@ -266,6 +289,8 @@ impl<'a> InferenceContext<'a> {
 
                 let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe);
                 let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone());
+                let prev_resume_yield_tys =
+                    mem::replace(&mut self.resume_yield_tys, resume_yield_tys);
 
                 self.with_breakable_ctx(BreakableKind::Border, self.err_ty(), None, |this| {
                     this.infer_expr_coerce(*body, &Expectation::has_type(ret_ty));
@@ -273,8 +298,9 @@ impl<'a> InferenceContext<'a> {
 
                 self.diverges = prev_diverges;
                 self.return_ty = prev_ret_ty;
+                self.resume_yield_tys = prev_resume_yield_tys;
 
-                closure_ty
+                ty
             }
             Expr::Call { callee, args, .. } => {
                 let callee_ty = self.infer_expr(*callee, &Expectation::none());
@@ -423,11 +449,18 @@ impl<'a> InferenceContext<'a> {
                 TyKind::Never.intern(Interner)
             }
             Expr::Yield { expr } => {
-                // FIXME: track yield type for coercion
-                if let Some(expr) = expr {
-                    self.infer_expr(*expr, &Expectation::none());
+                if let Some((resume_ty, yield_ty)) = self.resume_yield_tys.clone() {
+                    if let Some(expr) = expr {
+                        self.infer_expr_coerce(*expr, &Expectation::has_type(yield_ty));
+                    } else {
+                        let unit = self.result.standard_types.unit.clone();
+                        let _ = self.coerce(Some(tgt_expr), &unit, &yield_ty);
+                    }
+                    resume_ty
+                } else {
+                    // FIXME: report error (yield expr in non-generator)
+                    TyKind::Error.intern(Interner)
                 }
-                TyKind::Never.intern(Interner)
             }
             Expr::RecordLit { path, fields, spread, .. } => {
                 let (ty, def_id) = self.resolve_variant(path.as_deref(), false);
diff --git a/crates/hir-ty/src/mapping.rs b/crates/hir-ty/src/mapping.rs
index d765fee0e1f..f80fb39c1f8 100644
--- a/crates/hir-ty/src/mapping.rs
+++ b/crates/hir-ty/src/mapping.rs
@@ -103,6 +103,18 @@ impl From<crate::db::InternedClosureId> for chalk_ir::ClosureId<Interner> {
     }
 }
 
+impl From<chalk_ir::GeneratorId<Interner>> for crate::db::InternedGeneratorId {
+    fn from(id: chalk_ir::GeneratorId<Interner>) -> Self {
+        Self::from_intern_id(id.0)
+    }
+}
+
+impl From<crate::db::InternedGeneratorId> for chalk_ir::GeneratorId<Interner> {
+    fn from(id: crate::db::InternedGeneratorId) -> Self {
+        chalk_ir::GeneratorId(id.as_intern_id())
+    }
+}
+
 pub fn to_foreign_def_id(id: TypeAliasId) -> ForeignDefId {
     chalk_ir::ForeignDefId(salsa::InternKey::as_intern_id(&id))
 }