about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/coroutine.rs143
1 files changed, 105 insertions, 38 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index fa56d59dd80..50d244d2831 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -224,7 +224,7 @@ struct SuspensionPoint<'tcx> {
 
 struct TransformVisitor<'tcx> {
     tcx: TyCtxt<'tcx>,
-    is_async_kind: bool,
+    coroutine_kind: hir::CoroutineKind,
     state_adt_ref: AdtDef<'tcx>,
     state_args: GenericArgsRef<'tcx>,
 
@@ -249,6 +249,47 @@ struct TransformVisitor<'tcx> {
 }
 
 impl<'tcx> TransformVisitor<'tcx> {
+    fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
+        let block = BasicBlock::new(body.basic_blocks.len());
+
+        let source_info = SourceInfo::outermost(body.span);
+
+        let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true);
+        assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
+        let statements = vec![Statement {
+            kind: StatementKind::Assign(Box::new((
+                Place::return_place(),
+                Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
+            ))),
+            source_info,
+        }];
+
+        body.basic_blocks_mut().push(BasicBlockData {
+            statements,
+            terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
+            is_cleanup: false,
+        });
+
+        block
+    }
+
+    fn coroutine_state_adt_and_variant_idx(
+        &self,
+        is_return: bool,
+    ) -> (AggregateKind<'tcx>, VariantIdx) {
+        let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
+            (true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete
+            (false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
+            (true, hir::CoroutineKind::Async(_)) => 0,  // Poll::Ready
+            (false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending
+            (true, hir::CoroutineKind::Gen(_)) => 0,    // Option::None
+            (false, hir::CoroutineKind::Gen(_)) => 1,   // Option::Some
+        });
+
+        let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
+        (kind, idx)
+    }
+
     // Make a `CoroutineState` or `Poll` variant assignment.
     //
     // `core::ops::CoroutineState` only has single element tuple variants,
@@ -261,31 +302,44 @@ impl<'tcx> TransformVisitor<'tcx> {
         is_return: bool,
         statements: &mut Vec<Statement<'tcx>>,
     ) {
-        let idx = VariantIdx::new(match (is_return, self.is_async_kind) {
-            (true, false) => 1,  // CoroutineState::Complete
-            (false, false) => 0, // CoroutineState::Yielded
-            (true, true) => 0,   // Poll::Ready
-            (false, true) => 1,  // Poll::Pending
-        });
+        let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return);
 
-        let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
+        match self.coroutine_kind {
+            // `Poll::Pending`
+            CoroutineKind::Async(_) => {
+                if !is_return {
+                    assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
 
-        // `Poll::Pending`
-        if self.is_async_kind && idx == VariantIdx::new(1) {
-            assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
+                    // FIXME(swatinem): assert that `val` is indeed unit?
+                    statements.push(Statement {
+                        kind: StatementKind::Assign(Box::new((
+                            Place::return_place(),
+                            Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
+                        ))),
+                        source_info,
+                    });
+                    return;
+                }
+            }
+            // `Option::None`
+            CoroutineKind::Gen(_) => {
+                if is_return {
+                    assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
 
-            // FIXME(swatinem): assert that `val` is indeed unit?
-            statements.push(Statement {
-                kind: StatementKind::Assign(Box::new((
-                    Place::return_place(),
-                    Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
-                ))),
-                source_info,
-            });
-            return;
+                    statements.push(Statement {
+                        kind: StatementKind::Assign(Box::new((
+                            Place::return_place(),
+                            Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
+                        ))),
+                        source_info,
+                    });
+                    return;
+                }
+            }
+            CoroutineKind::Coroutine => {}
         }
 
-        // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)` or `CoroutineState::Complete(x)`
+        // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
         assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
 
         statements.push(Statement {
@@ -1263,10 +1317,13 @@ fn create_coroutine_resume_function<'tcx>(
     }
 
     if can_return {
-        cases.insert(
-            1,
-            (RETURNED, insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))),
-        );
+        let block = match coroutine_kind {
+            CoroutineKind::Async(_) | CoroutineKind::Coroutine => {
+                insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
+            }
+            CoroutineKind::Gen(_) => transform.insert_none_ret_block(body),
+        };
+        cases.insert(1, (RETURNED, block));
     }
 
     insert_switch(body, cases, &transform, TerminatorKind::Unreachable);
@@ -1439,18 +1496,28 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         };
 
         let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
-        let (state_adt_ref, state_args) = if is_async_kind {
-            // Compute Poll<return_ty>
-            let poll_did = tcx.require_lang_item(LangItem::Poll, None);
-            let poll_adt_ref = tcx.adt_def(poll_did);
-            let poll_args = tcx.mk_args(&[body.return_ty().into()]);
-            (poll_adt_ref, poll_args)
-        } else {
-            // Compute CoroutineState<yield_ty, return_ty>
-            let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
-            let state_adt_ref = tcx.adt_def(state_did);
-            let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
-            (state_adt_ref, state_args)
+        let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
+            CoroutineKind::Async(_) => {
+                // Compute Poll<return_ty>
+                let poll_did = tcx.require_lang_item(LangItem::Poll, None);
+                let poll_adt_ref = tcx.adt_def(poll_did);
+                let poll_args = tcx.mk_args(&[body.return_ty().into()]);
+                (poll_adt_ref, poll_args)
+            }
+            CoroutineKind::Gen(_) => {
+                // Compute Option<yield_ty>
+                let option_did = tcx.require_lang_item(LangItem::Option, None);
+                let option_adt_ref = tcx.adt_def(option_did);
+                let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]);
+                (option_adt_ref, option_args)
+            }
+            CoroutineKind::Coroutine => {
+                // Compute CoroutineState<yield_ty, return_ty>
+                let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
+                let state_adt_ref = tcx.adt_def(state_did);
+                let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
+                (state_adt_ref, state_args)
+            }
         };
         let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
 
@@ -1518,7 +1585,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
         let mut transform = TransformVisitor {
             tcx,
-            is_async_kind,
+            coroutine_kind: body.coroutine_kind().unwrap(),
             state_adt_ref,
             state_args,
             remap,