about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/coroutine.rs30
1 files changed, 28 insertions, 2 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index e0cfd7d0fcb..2b591abb05d 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -355,6 +355,26 @@ impl<'tcx> TransformVisitor<'tcx> {
                     )
                 }
             }
+            CoroutineKind::AsyncGen(_) => {
+                if is_return {
+                    let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
+                    let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
+                    let yield_ty = args.type_at(0);
+                    Rvalue::Use(Operand::Constant(Box::new(ConstOperand {
+                        span: source_info.span,
+                        const_: Const::Unevaluated(
+                            UnevaluatedConst::new(
+                                self.tcx.require_lang_item(LangItem::AsyncGenFinished, None),
+                                self.tcx.mk_args(&[yield_ty.into()]),
+                            ),
+                            self.old_yield_ty,
+                        ),
+                        user_ty: None,
+                    })))
+                } else {
+                    Rvalue::Use(val)
+                }
+            }
             CoroutineKind::Coroutine => {
                 let coroutine_state_def_id =
                     self.tcx.require_lang_item(LangItem::CoroutineState, None);
@@ -1373,7 +1393,8 @@ fn create_coroutine_resume_function<'tcx>(
 
     if can_return {
         let block = match coroutine_kind {
-            CoroutineKind::Async(_) | CoroutineKind::Coroutine => {
+            // FIXME(gen_blocks): Should `async gen` yield `None` when resumed once again?
+            CoroutineKind::Async(_) | CoroutineKind::AsyncGen(_) | CoroutineKind::Coroutine => {
                 insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
             }
             CoroutineKind::Gen(_) => transform.insert_none_ret_block(body),
@@ -1562,6 +1583,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         };
 
         let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
+        let is_async_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::AsyncGen(_)));
         let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
         let new_ret_ty = match body.coroutine_kind().unwrap() {
             CoroutineKind::Async(_) => {
@@ -1578,6 +1600,10 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
                 let option_args = tcx.mk_args(&[old_yield_ty.into()]);
                 Ty::new_adt(tcx, option_adt_ref, option_args)
             }
+            CoroutineKind::AsyncGen(_) => {
+                // The yield ty is already `Poll<Option<yield_ty>>`
+                old_yield_ty
+            }
             CoroutineKind::Coroutine => {
                 // Compute CoroutineState<yield_ty, return_ty>
                 let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
@@ -1592,7 +1618,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
 
         // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
-        if is_async_kind {
+        if is_async_kind || is_async_gen_kind {
             transform_async_context(tcx, body);
         }