about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-03-15 13:35:44 +0000
committerbors <bors@rust-lang.org>2023-03-15 13:35:44 +0000
commit8330f8efc6c79455d9217ba4a9cff16091ca8da5 (patch)
treeb75daf66e83bc1bed2ab84f261e73dd8c4f794b8
parentc16f0517922c220e5b0aaac2ff50459578d6dfec (diff)
parent3bf07a5f040be32cafe89977372f271673894a61 (diff)
downloadrust-8330f8efc6c79455d9217ba4a9cff16091ca8da5.tar.gz
rust-8330f8efc6c79455d9217ba4a9cff16091ca8da5.zip
Auto merge of #12958 - zachs18:async_closure, r=Veykril
fix: Fix return type of async closures.

May fix #12957
-rw-r--r--crates/hir-def/src/body/lower.rs2
-rw-r--r--crates/hir-def/src/body/pretty.rs10
-rw-r--r--crates/hir-def/src/expr.rs1
-rw-r--r--crates/hir-ty/src/infer/expr.rs73
-rw-r--r--crates/hir-ty/src/tests/traits.rs71
5 files changed, 114 insertions, 43 deletions
diff --git a/crates/hir-def/src/body/lower.rs b/crates/hir-def/src/body/lower.rs
index 83ce9b6acbb..fedaf395598 100644
--- a/crates/hir-def/src/body/lower.rs
+++ b/crates/hir-def/src/body/lower.rs
@@ -499,6 +499,8 @@ impl ExprCollector<'_> {
                         Movability::Movable
                     };
                     ClosureKind::Generator(movability)
+                } else if e.async_token().is_some() {
+                    ClosureKind::Async
                 } else {
                     ClosureKind::Closure
                 };
diff --git a/crates/hir-def/src/body/pretty.rs b/crates/hir-def/src/body/pretty.rs
index f8b159797e4..5a9b825a253 100644
--- a/crates/hir-def/src/body/pretty.rs
+++ b/crates/hir-def/src/body/pretty.rs
@@ -360,8 +360,14 @@ impl<'a> Printer<'a> {
                 w!(self, "]");
             }
             Expr::Closure { args, arg_types, ret_type, body, closure_kind } => {
-                if let ClosureKind::Generator(Movability::Static) = closure_kind {
-                    w!(self, "static ");
+                match closure_kind {
+                    ClosureKind::Generator(Movability::Static) => {
+                        w!(self, "static ");
+                    }
+                    ClosureKind::Async => {
+                        w!(self, "async ");
+                    }
+                    _ => (),
                 }
                 w!(self, "|");
                 for (i, (pat, ty)) in args.iter().zip(arg_types.iter()).enumerate() {
diff --git a/crates/hir-def/src/expr.rs b/crates/hir-def/src/expr.rs
index bbea608c55e..19fa6b25419 100644
--- a/crates/hir-def/src/expr.rs
+++ b/crates/hir-def/src/expr.rs
@@ -245,6 +245,7 @@ pub enum Expr {
 pub enum ClosureKind {
     Closure,
     Generator(Movability),
+    Async,
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs
index 535189ff028..ee186673ee1 100644
--- a/crates/hir-ty/src/infer/expr.rs
+++ b/crates/hir-ty/src/infer/expr.rs
@@ -275,7 +275,23 @@ impl<'a> InferenceContext<'a> {
                     Some(type_ref) => self.make_ty(type_ref),
                     None => self.table.new_type_var(),
                 };
-                sig_tys.push(ret_ty.clone());
+                if let ClosureKind::Async = closure_kind {
+                    // Use the first type parameter as the output type of future.
+                    // existential type AsyncBlockImplTrait<InnerType>: Future<Output = InnerType>
+                    let impl_trait_id =
+                        crate::ImplTraitId::AsyncBlockTypeImplTrait(self.owner, *body);
+                    let opaque_ty_id = self.db.intern_impl_trait_id(impl_trait_id).into();
+                    sig_tys.push(
+                        TyKind::OpaqueType(
+                            opaque_ty_id,
+                            Substitution::from1(Interner, ret_ty.clone()),
+                        )
+                        .intern(Interner),
+                    );
+                } else {
+                    sig_tys.push(ret_ty.clone());
+                }
+
                 let sig_ty = TyKind::Function(FnPointer {
                     num_binders: 0,
                     sig: FnSig { abi: (), safety: chalk_ir::Safety::Safe, variadic: false },
@@ -286,33 +302,38 @@ impl<'a> InferenceContext<'a> {
                 })
                 .intern(Interner);
 
-                let (ty, resume_yield_tys) = if matches!(closure_kind, ClosureKind::Generator(_)) {
-                    // FIXME: report error when there are more than 1 parameter.
-                    let resume_ty = match sig_tys.first() {
-                        // When `sig_tys.len() == 1` the first type is the return type, not the
-                        // first parameter type.
-                        Some(ty) if sig_tys.len() > 1 => ty.clone(),
-                        _ => self.result.standard_types.unit.clone(),
-                    };
-                    let yield_ty = self.table.new_type_var();
-
-                    let subst = TyBuilder::subst_for_generator(self.db, self.owner)
-                        .push(resume_ty.clone())
-                        .push(yield_ty.clone())
-                        .push(ret_ty.clone())
-                        .build();
+                let (ty, resume_yield_tys) = match closure_kind {
+                    ClosureKind::Generator(_) => {
+                        // FIXME: report error when there are more than 1 parameter.
+                        let resume_ty = match sig_tys.first() {
+                            // When `sig_tys.len() == 1` the first type is the return type, not the
+                            // first parameter type.
+                            Some(ty) if sig_tys.len() > 1 => ty.clone(),
+                            _ => self.result.standard_types.unit.clone(),
+                        };
+                        let yield_ty = self.table.new_type_var();
+
+                        let subst = TyBuilder::subst_for_generator(self.db, self.owner)
+                            .push(resume_ty.clone())
+                            .push(yield_ty.clone())
+                            .push(ret_ty.clone())
+                            .build();
 
-                    let generator_id = self.db.intern_generator((self.owner, tgt_expr)).into();
-                    let generator_ty = TyKind::Generator(generator_id, subst).intern(Interner);
+                        let generator_id = self.db.intern_generator((self.owner, tgt_expr)).into();
+                        let generator_ty = TyKind::Generator(generator_id, subst).intern(Interner);
 
-                    (generator_ty, Some((resume_ty, yield_ty)))
-                } else {
-                    let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into();
-                    let closure_ty =
-                        TyKind::Closure(closure_id, Substitution::from1(Interner, sig_ty.clone()))
-                            .intern(Interner);
+                        (generator_ty, Some((resume_ty, yield_ty)))
+                    }
+                    ClosureKind::Closure | ClosureKind::Async => {
+                        let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into();
+                        let closure_ty = TyKind::Closure(
+                            closure_id,
+                            Substitution::from1(Interner, sig_ty.clone()),
+                        )
+                        .intern(Interner);
 
-                    (closure_ty, None)
+                        (closure_ty, None)
+                    }
                 };
 
                 // Eagerly try to relate the closure type with the expected
@@ -321,7 +342,7 @@ impl<'a> InferenceContext<'a> {
                 self.deduce_closure_type_from_expectations(tgt_expr, &ty, &sig_ty, expected);
 
                 // Now go through the argument patterns
-                for (arg_pat, arg_ty) in args.iter().zip(sig_tys) {
+                for (arg_pat, arg_ty) in args.iter().zip(&sig_tys) {
                     self.infer_top_pat(*arg_pat, &arg_ty);
                 }
 
diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs
index 015085bde45..da76d7fd83f 100644
--- a/crates/hir-ty/src/tests/traits.rs
+++ b/crates/hir-ty/src/tests/traits.rs
@@ -83,6 +83,46 @@ async fn test() {
 }
 
 #[test]
+fn infer_async_closure() {
+    check_types(
+        r#"
+//- minicore: future, option
+async fn test() {
+    let f = async move |x: i32| x + 42;
+    f;
+//  ^ |i32| -> impl Future<Output = i32>
+    let a = f(4);
+    a;
+//  ^ impl Future<Output = i32>
+    let x = a.await;
+    x;
+//  ^ i32
+    let f = async move || 42;
+    f;
+//  ^ || -> impl Future<Output = i32>
+    let a = f();
+    a;
+//  ^ impl Future<Output = i32>
+    let x = a.await;
+    x;
+//  ^ i32
+    let b = ((async move || {})()).await;
+    b;
+//  ^ ()
+    let c = async move || {
+        let y = None;
+        y
+    //  ^ Option<u64>
+    };
+    let _: Option<u64> = c().await;
+    c;
+//  ^ || -> impl Future<Output = Option<u64>>
+}
+"#,
+    );
+}
+
+#[test]
 fn auto_sized_async_block() {
     check_no_mismatches(
         r#"
@@ -493,29 +533,30 @@ fn tuple_struct_with_fn() {
         r#"
 struct S(fn(u32) -> u64);
 fn test() -> u64 {
-    let a = S(|i| 2*i);
+    let a = S(|i| 2*i as u64);
     let b = a.0(4);
     a.0(2)
 }"#,
         expect![[r#"
-            43..101 '{     ...0(2) }': u64
+            43..108 '{     ...0(2) }': u64
             53..54 'a': S
             57..58 'S': S(fn(u32) -> u64) -> S
-            57..67 'S(|i| 2*i)': S
-            59..66 '|i| 2*i': |u32| -> u64
+            57..74 'S(|i| ...s u64)': S
+            59..73 '|i| 2*i as u64': |u32| -> u64
             60..61 'i': u32
-            63..64 '2': u32
-            63..66 '2*i': u32
+            63..64 '2': u64
+            63..73 '2*i as u64': u64
             65..66 'i': u32
-            77..78 'b': u64
-            81..82 'a': S
-            81..84 'a.0': fn(u32) -> u64
-            81..87 'a.0(4)': u64
-            85..86 '4': u32
-            93..94 'a': S
-            93..96 'a.0': fn(u32) -> u64
-            93..99 'a.0(2)': u64
-            97..98 '2': u32
+            65..73 'i as u64': u64
+            84..85 'b': u64
+            88..89 'a': S
+            88..91 'a.0': fn(u32) -> u64
+            88..94 'a.0(4)': u64
+            92..93 '4': u32
+            100..101 'a': S
+            100..103 'a.0': fn(u32) -> u64
+            100..106 'a.0(2)': u64
+            104..105 '2': u32
         "#]],
     );
 }