about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_ast_lowering/src/expr.rs52
-rw-r--r--compiler/rustc_hir_typeck/src/closure.rs7
-rw-r--r--tests/ui/coroutine/async-gen-yield-ty-is-unit.rs17
-rw-r--r--tests/ui/coroutine/return-types-diverge.rs20
-rw-r--r--tests/ui/coroutine/return-types.rs21
-rw-r--r--tests/ui/coroutine/return-types.stderr31
6 files changed, 125 insertions, 23 deletions
diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs
index 11b5131b8d7..704f124dbcb 100644
--- a/compiler/rustc_ast_lowering/src/expr.rs
+++ b/compiler/rustc_ast_lowering/src/expr.rs
@@ -917,12 +917,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
         let poll_expr = {
             let awaitee = self.expr_ident(span, awaitee_ident, awaitee_pat_hid);
             let ref_mut_awaitee = self.expr_mut_addr_of(span, awaitee);
-            let task_context = if let Some(task_context_hid) = self.task_context {
-                self.expr_ident_mut(span, task_context_ident, task_context_hid)
-            } else {
-                // Use of `await` outside of an async context, we cannot use `task_context` here.
-                self.expr_err(span, self.tcx.sess.span_delayed_bug(span, "no task_context hir id"))
+
+            let Some(task_context_hid) = self.task_context else {
+                unreachable!("use of `await` outside of an async context.");
             };
+
+            let task_context = self.expr_ident_mut(span, task_context_ident, task_context_hid);
+
             let new_unchecked = self.expr_call_lang_item_fn_mut(
                 span,
                 hir::LangItem::PinNewUnchecked,
@@ -991,16 +992,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
             );
             let yield_expr = self.arena.alloc(yield_expr);
 
-            if let Some(task_context_hid) = self.task_context {
-                let lhs = self.expr_ident(span, task_context_ident, task_context_hid);
-                let assign =
-                    self.expr(span, hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span)));
-                self.stmt_expr(span, assign)
-            } else {
-                // Use of `await` outside of an async context. Return `yield_expr` so that we can
-                // proceed with type checking.
-                self.stmt(span, hir::StmtKind::Semi(yield_expr))
-            }
+            let Some(task_context_hid) = self.task_context else {
+                unreachable!("use of `await` outside of an async context.");
+            };
+
+            let lhs = self.expr_ident(span, task_context_ident, task_context_hid);
+            let assign =
+                self.expr(span, hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span)));
+            self.stmt_expr(span, assign)
         };
 
         let loop_block = self.block_all(span, arena_vec![self; inner_match_stmt, yield_stmt], None);
@@ -1635,19 +1634,32 @@ impl<'hir> LoweringContext<'_, 'hir> {
             }
         };
 
-        let mut yielded =
+        let yielded =
             opt_expr.as_ref().map(|x| self.lower_expr(x)).unwrap_or_else(|| self.expr_unit(span));
 
         if is_async_gen {
-            // yield async_gen_ready($expr);
-            yielded = self.expr_call_lang_item_fn(
+            // `yield $expr` is transformed into `task_context = yield async_gen_ready($expr)`.
+            // This ensures that we store our resumed `ResumeContext` correctly, and also that
+            // the apparent value of the `yield` expression is `()`.
+            let wrapped_yielded = self.expr_call_lang_item_fn(
                 span,
                 hir::LangItem::AsyncGenReady,
                 std::slice::from_ref(yielded),
             );
-        }
+            let yield_expr = self.arena.alloc(
+                self.expr(span, hir::ExprKind::Yield(wrapped_yielded, hir::YieldSource::Yield)),
+            );
 
-        hir::ExprKind::Yield(yielded, hir::YieldSource::Yield)
+            let Some(task_context_hid) = self.task_context else {
+                unreachable!("use of `await` outside of an async context.");
+            };
+            let task_context_ident = Ident::with_dummy_span(sym::_task_context);
+            let lhs = self.expr_ident(span, task_context_ident, task_context_hid);
+
+            hir::ExprKind::Assign(lhs, yield_expr, self.lower_span(span))
+        } else {
+            hir::ExprKind::Yield(yielded, hir::YieldSource::Yield)
+        }
     }
 
     /// Desugar `ExprForLoop` from: `[opt_ident]: for <pat> in <head> <body>` into:
diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs
index 7e43d67587b..d19d304128a 100644
--- a/compiler/rustc_hir_typeck/src/closure.rs
+++ b/compiler/rustc_hir_typeck/src/closure.rs
@@ -650,9 +650,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                         },
                     )
                 }
-                // For a `gen {}` block created as a `gen fn` body, we need the return type to be
-                // ().
-                Some(hir::CoroutineKind::Gen(hir::CoroutineSource::Fn)) => self.tcx.types.unit,
+                // All `gen {}` and `async gen {}` must return unit.
+                Some(hir::CoroutineKind::Gen(_) | hir::CoroutineKind::AsyncGen(_)) => {
+                    self.tcx.types.unit
+                }
 
                 _ => astconv.ty_infer(None, decl.output.span()),
             },
diff --git a/tests/ui/coroutine/async-gen-yield-ty-is-unit.rs b/tests/ui/coroutine/async-gen-yield-ty-is-unit.rs
new file mode 100644
index 00000000000..aac74d3eacb
--- /dev/null
+++ b/tests/ui/coroutine/async-gen-yield-ty-is-unit.rs
@@ -0,0 +1,17 @@
+// compile-flags: --edition 2024 -Zunstable-options
+// check-pass
+
+#![feature(async_iterator, gen_blocks, noop_waker)]
+
+use std::{async_iter::AsyncIterator, pin::pin, task::{Context, Waker}};
+
+async gen fn gen_fn() -> &'static str {
+    yield "hello"
+}
+
+pub fn main() {
+    let async_iterator = pin!(gen_fn());
+    let waker = Waker::noop();
+    let ctx = &mut Context::from_waker(&waker);
+    async_iterator.poll_next(ctx);
+}
diff --git a/tests/ui/coroutine/return-types-diverge.rs b/tests/ui/coroutine/return-types-diverge.rs
new file mode 100644
index 00000000000..5f21c8cbf34
--- /dev/null
+++ b/tests/ui/coroutine/return-types-diverge.rs
@@ -0,0 +1,20 @@
+// compile-flags: --edition 2024 -Zunstable-options
+// check-pass
+
+#![feature(gen_blocks)]
+
+fn diverge() -> ! { loop {} }
+
+async gen fn async_gen_fn() -> i32 { diverge() }
+
+gen fn gen_fn() -> i32 { diverge() }
+
+fn async_gen_block() {
+    async gen { yield (); diverge() };
+}
+
+fn gen_block() {
+    gen { yield (); diverge() };
+}
+
+fn main() {}
diff --git a/tests/ui/coroutine/return-types.rs b/tests/ui/coroutine/return-types.rs
new file mode 100644
index 00000000000..3543d6293f7
--- /dev/null
+++ b/tests/ui/coroutine/return-types.rs
@@ -0,0 +1,21 @@
+// compile-flags: --edition 2024 -Zunstable-options
+
+#![feature(gen_blocks)]
+
+async gen fn async_gen_fn() -> i32 { 0 }
+//~^ ERROR mismatched types
+
+gen fn gen_fn() -> i32 { 0 }
+//~^ ERROR mismatched types
+
+fn async_gen_block() {
+    async gen { yield (); 1 };
+    //~^ ERROR mismatched types
+}
+
+fn gen_block() {
+    gen { yield (); 1 };
+    //~^ ERROR mismatched types
+}
+
+fn main() {}
diff --git a/tests/ui/coroutine/return-types.stderr b/tests/ui/coroutine/return-types.stderr
new file mode 100644
index 00000000000..7be96e538d9
--- /dev/null
+++ b/tests/ui/coroutine/return-types.stderr
@@ -0,0 +1,31 @@
+error[E0308]: mismatched types
+  --> $DIR/return-types.rs:5:38
+   |
+LL | async gen fn async_gen_fn() -> i32 { 0 }
+   |                                ---   ^ expected `()`, found integer
+   |                                |
+   |                                expected `()` because of return type
+
+error[E0308]: mismatched types
+  --> $DIR/return-types.rs:8:26
+   |
+LL | gen fn gen_fn() -> i32 { 0 }
+   |                    ---   ^ expected `()`, found integer
+   |                    |
+   |                    expected `()` because of return type
+
+error[E0308]: mismatched types
+  --> $DIR/return-types.rs:12:27
+   |
+LL |     async gen { yield (); 1 };
+   |                           ^ expected `()`, found integer
+
+error[E0308]: mismatched types
+  --> $DIR/return-types.rs:17:21
+   |
+LL |     gen { yield (); 1 };
+   |                     ^ expected `()`, found integer
+
+error: aborting due to 4 previous errors
+
+For more information about this error, try `rustc --explain E0308`.