about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_ast_lowering/src/expr.rs266
-rw-r--r--compiler/rustc_ast_lowering/src/item.rs39
2 files changed, 81 insertions, 224 deletions
diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs
index ccc6644923a..e568da9bbc0 100644
--- a/compiler/rustc_ast_lowering/src/expr.rs
+++ b/compiler/rustc_ast_lowering/src/expr.rs
@@ -183,14 +183,6 @@ impl<'hir> LoweringContext<'_, 'hir> {
                     self.arena.alloc_from_iter(arms.iter().map(|x| self.lower_arm(x))),
                     hir::MatchSource::Normal,
                 ),
-                ExprKind::Gen(capture_clause, block, GenBlockKind::Async) => self.make_async_expr(
-                    *capture_clause,
-                    e.id,
-                    None,
-                    e.span,
-                    hir::CoroutineSource::Block,
-                    |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)),
-                ),
                 ExprKind::Await(expr, await_kw_span) => self.lower_expr_await(*await_kw_span, expr),
                 ExprKind::Closure(box Closure {
                     binder,
@@ -226,6 +218,22 @@ impl<'hir> LoweringContext<'_, 'hir> {
                         *fn_arg_span,
                     ),
                 },
+                ExprKind::Gen(capture_clause, block, genblock_kind) => {
+                    let desugaring_kind = match genblock_kind {
+                        GenBlockKind::Async => hir::CoroutineDesugaring::Async,
+                        GenBlockKind::Gen => hir::CoroutineDesugaring::Gen,
+                        GenBlockKind::AsyncGen => hir::CoroutineDesugaring::AsyncGen,
+                    };
+                    self.make_desugared_coroutine_expr(
+                        *capture_clause,
+                        e.id,
+                        None,
+                        e.span,
+                        desugaring_kind,
+                        hir::CoroutineSource::Block,
+                        |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)),
+                    )
+                }
                 ExprKind::Block(blk, opt_label) => {
                     let opt_label = self.lower_label(*opt_label);
                     hir::ExprKind::Block(self.lower_block(blk, opt_label.is_some()), opt_label)
@@ -313,23 +321,6 @@ impl<'hir> LoweringContext<'_, 'hir> {
                         rest,
                     )
                 }
-                ExprKind::Gen(capture_clause, block, GenBlockKind::Gen) => self.make_gen_expr(
-                    *capture_clause,
-                    e.id,
-                    None,
-                    e.span,
-                    hir::CoroutineSource::Block,
-                    |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)),
-                ),
-                ExprKind::Gen(capture_clause, block, GenBlockKind::AsyncGen) => self
-                    .make_async_gen_expr(
-                        *capture_clause,
-                        e.id,
-                        None,
-                        e.span,
-                        hir::CoroutineSource::Block,
-                        |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)),
-                    ),
                 ExprKind::Yield(opt_expr) => self.lower_expr_yield(e.span, opt_expr.as_deref()),
                 ExprKind::Err => {
                     hir::ExprKind::Err(self.dcx().span_delayed_bug(e.span, "lowered ExprKind::Err"))
@@ -612,213 +603,91 @@ impl<'hir> LoweringContext<'_, 'hir> {
         hir::Arm { hir_id, pat, guard, body, span }
     }
 
-    /// Lower an `async` construct to a coroutine that implements `Future`.
+    /// Lower/desugar a coroutine construct.
     ///
-    /// This results in:
-    ///
-    /// ```text
-    /// static move? |_task_context| -> <ret_ty> {
-    ///     <body>
-    /// }
-    /// ```
-    pub(super) fn make_async_expr(
-        &mut self,
-        capture_clause: CaptureBy,
-        closure_node_id: NodeId,
-        ret_ty: Option<hir::FnRetTy<'hir>>,
-        span: Span,
-        async_coroutine_source: hir::CoroutineSource,
-        body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
-    ) -> hir::ExprKind<'hir> {
-        let output = ret_ty.unwrap_or_else(|| hir::FnRetTy::DefaultReturn(self.lower_span(span)));
-
-        // Resume argument type: `ResumeTy`
-        let unstable_span = self.mark_span_with_reason(
-            DesugaringKind::Async,
-            self.lower_span(span),
-            Some(self.allow_gen_future.clone()),
-        );
-        let resume_ty = self.make_lang_item_qpath(hir::LangItem::ResumeTy, unstable_span);
-        let input_ty = hir::Ty {
-            hir_id: self.next_id(),
-            kind: hir::TyKind::Path(resume_ty),
-            span: unstable_span,
-        };
-
-        // The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`.
-        let fn_decl = self.arena.alloc(hir::FnDecl {
-            inputs: arena_vec![self; input_ty],
-            output,
-            c_variadic: false,
-            implicit_self: hir::ImplicitSelfKind::None,
-            lifetime_elision_allowed: false,
-        });
-
-        // Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
-        let (pat, task_context_hid) = self.pat_ident_binding_mode(
-            span,
-            Ident::with_dummy_span(sym::_task_context),
-            hir::BindingAnnotation::MUT,
-        );
-        let param = hir::Param {
-            hir_id: self.next_id(),
-            pat,
-            ty_span: self.lower_span(span),
-            span: self.lower_span(span),
-        };
-        let params = arena_vec![self; param];
-
-        let coroutine_kind =
-            hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, async_coroutine_source);
-        let body = self.lower_body(move |this| {
-            this.coroutine_kind = Some(coroutine_kind);
-
-            let old_ctx = this.task_context;
-            this.task_context = Some(task_context_hid);
-            let res = body(this);
-            this.task_context = old_ctx;
-            (params, res)
-        });
-
-        // `static |_task_context| -> <ret_ty> { body }`:
-        hir::ExprKind::Closure(self.arena.alloc(hir::Closure {
-            def_id: self.local_def_id(closure_node_id),
-            binder: hir::ClosureBinder::Default,
-            capture_clause,
-            bound_generic_params: &[],
-            fn_decl,
-            body,
-            fn_decl_span: self.lower_span(span),
-            fn_arg_span: None,
-            kind: hir::ClosureKind::Coroutine(coroutine_kind),
-            constness: hir::Constness::NotConst,
-        }))
-    }
-
-    /// Lower a `gen` construct to a generator that implements `Iterator`.
+    /// In particular, this creates the correct async resume argument and `_task_context`.
     ///
     /// This results in:
     ///
     /// ```text
-    /// static move? |()| -> () {
+    /// static move? |<_task_context?>| -> <return_ty> {
     ///     <body>
     /// }
     /// ```
-    pub(super) fn make_gen_expr(
+    pub(super) fn make_desugared_coroutine_expr(
         &mut self,
         capture_clause: CaptureBy,
         closure_node_id: NodeId,
-        _yield_ty: Option<hir::FnRetTy<'hir>>,
+        return_ty: Option<hir::FnRetTy<'hir>>,
         span: Span,
+        desugaring_kind: hir::CoroutineDesugaring,
         coroutine_source: hir::CoroutineSource,
         body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
     ) -> hir::ExprKind<'hir> {
-        let output = hir::FnRetTy::DefaultReturn(self.lower_span(span));
-
-        // The closure/generator `FnDecl` takes a single (resume) argument of type `input_ty`.
-        let fn_decl = self.arena.alloc(hir::FnDecl {
-            inputs: &[],
-            output,
-            c_variadic: false,
-            implicit_self: hir::ImplicitSelfKind::None,
-            lifetime_elision_allowed: false,
-        });
-
-        let coroutine_kind =
-            hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, coroutine_source);
-        let body = self.lower_body(move |this| {
-            this.coroutine_kind = Some(coroutine_kind);
-
-            let res = body(this);
-            (&[], res)
-        });
-
-        // `static |()| -> () { body }`:
-        hir::ExprKind::Closure(self.arena.alloc(hir::Closure {
-            def_id: self.local_def_id(closure_node_id),
-            binder: hir::ClosureBinder::Default,
-            capture_clause,
-            bound_generic_params: &[],
-            fn_decl,
-            body,
-            fn_decl_span: self.lower_span(span),
-            fn_arg_span: None,
-            kind: hir::ClosureKind::Coroutine(coroutine_kind),
-            constness: hir::Constness::NotConst,
-        }))
-    }
+        let coroutine_kind = hir::CoroutineKind::Desugared(desugaring_kind, coroutine_source);
+
+        // The `async` desugaring takes a resume argument and maintains a `task_context`,
+        // whereas a generator does not.
+        let (inputs, params, task_context): (&[_], &[_], _) = match desugaring_kind {
+            hir::CoroutineDesugaring::Async | hir::CoroutineDesugaring::AsyncGen => {
+                // Resume argument type: `ResumeTy`
+                let unstable_span = self.mark_span_with_reason(
+                    DesugaringKind::Async,
+                    self.lower_span(span),
+                    Some(self.allow_gen_future.clone()),
+                );
+                let resume_ty = self.make_lang_item_qpath(hir::LangItem::ResumeTy, unstable_span);
+                let input_ty = hir::Ty {
+                    hir_id: self.next_id(),
+                    kind: hir::TyKind::Path(resume_ty),
+                    span: unstable_span,
+                };
+                let inputs = arena_vec![self; input_ty];
 
-    /// Lower a `async gen` construct to a generator that implements `AsyncIterator`.
-    ///
-    /// This results in:
-    ///
-    /// ```text
-    /// static move? |_task_context| -> () {
-    ///     <body>
-    /// }
-    /// ```
-    pub(super) fn make_async_gen_expr(
-        &mut self,
-        capture_clause: CaptureBy,
-        closure_node_id: NodeId,
-        _yield_ty: Option<hir::FnRetTy<'hir>>,
-        span: Span,
-        async_coroutine_source: hir::CoroutineSource,
-        body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
-    ) -> hir::ExprKind<'hir> {
-        let output = hir::FnRetTy::DefaultReturn(self.lower_span(span));
+                // Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
+                let (pat, task_context_hid) = self.pat_ident_binding_mode(
+                    span,
+                    Ident::with_dummy_span(sym::_task_context),
+                    hir::BindingAnnotation::MUT,
+                );
+                let param = hir::Param {
+                    hir_id: self.next_id(),
+                    pat,
+                    ty_span: self.lower_span(span),
+                    span: self.lower_span(span),
+                };
+                let params = arena_vec![self; param];
 
-        // Resume argument type: `ResumeTy`
-        let unstable_span = self.mark_span_with_reason(
-            DesugaringKind::Async,
-            self.lower_span(span),
-            Some(self.allow_gen_future.clone()),
-        );
-        let resume_ty = self.make_lang_item_qpath(hir::LangItem::ResumeTy, unstable_span);
-        let input_ty = hir::Ty {
-            hir_id: self.next_id(),
-            kind: hir::TyKind::Path(resume_ty),
-            span: unstable_span,
+                (inputs, params, Some(task_context_hid))
+            }
+            hir::CoroutineDesugaring::Gen => (&[], &[], None),
         };
 
-        // The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`.
+        let output =
+            return_ty.unwrap_or_else(|| hir::FnRetTy::DefaultReturn(self.lower_span(span)));
+
         let fn_decl = self.arena.alloc(hir::FnDecl {
-            inputs: arena_vec![self; input_ty],
+            inputs,
             output,
             c_variadic: false,
             implicit_self: hir::ImplicitSelfKind::None,
             lifetime_elision_allowed: false,
         });
 
-        // Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
-        let (pat, task_context_hid) = self.pat_ident_binding_mode(
-            span,
-            Ident::with_dummy_span(sym::_task_context),
-            hir::BindingAnnotation::MUT,
-        );
-        let param = hir::Param {
-            hir_id: self.next_id(),
-            pat,
-            ty_span: self.lower_span(span),
-            span: self.lower_span(span),
-        };
-        let params = arena_vec![self; param];
-
-        let coroutine_kind = hir::CoroutineKind::Desugared(
-            hir::CoroutineDesugaring::AsyncGen,
-            async_coroutine_source,
-        );
         let body = self.lower_body(move |this| {
             this.coroutine_kind = Some(coroutine_kind);
 
             let old_ctx = this.task_context;
-            this.task_context = Some(task_context_hid);
+            if task_context.is_some() {
+                this.task_context = task_context;
+            }
             let res = body(this);
             this.task_context = old_ctx;
+
             (params, res)
         });
 
-        // `static |_task_context| -> <ret_ty> { body }`:
+        // `static |<_task_context?>| -> <return_ty> { <body> }`:
         hir::ExprKind::Closure(self.arena.alloc(hir::Closure {
             def_id: self.local_def_id(closure_node_id),
             binder: hir::ClosureBinder::Default,
@@ -1203,11 +1072,12 @@ impl<'hir> LoweringContext<'_, 'hir> {
                     None
                 };
 
-                let async_body = this.make_async_expr(
+                let async_body = this.make_desugared_coroutine_expr(
                     capture_clause,
                     inner_closure_id,
                     async_ret_ty,
                     body.span,
+                    hir::CoroutineDesugaring::Async,
                     hir::CoroutineSource::Closure,
                     |this| this.with_new_scopes(fn_decl_span, |this| this.lower_expr_mut(body)),
                 );
diff --git a/compiler/rustc_ast_lowering/src/item.rs b/compiler/rustc_ast_lowering/src/item.rs
index 3848f3b7782..45357aca533 100644
--- a/compiler/rustc_ast_lowering/src/item.rs
+++ b/compiler/rustc_ast_lowering/src/item.rs
@@ -1209,33 +1209,20 @@ impl<'hir> LoweringContext<'_, 'hir> {
 
                 this.expr_block(body)
             };
-            // FIXME(gen_blocks): Consider unifying the `make_*_expr` functions.
-            let coroutine_expr = match coroutine_kind {
-                CoroutineKind::Async { .. } => this.make_async_expr(
-                    CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
-                    closure_id,
-                    None,
-                    body.span,
-                    hir::CoroutineSource::Fn,
-                    mkbody,
-                ),
-                CoroutineKind::Gen { .. } => this.make_gen_expr(
-                    CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
-                    closure_id,
-                    None,
-                    body.span,
-                    hir::CoroutineSource::Fn,
-                    mkbody,
-                ),
-                CoroutineKind::AsyncGen { .. } => this.make_async_gen_expr(
-                    CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
-                    closure_id,
-                    None,
-                    body.span,
-                    hir::CoroutineSource::Fn,
-                    mkbody,
-                ),
+            let desugaring_kind = match coroutine_kind {
+                CoroutineKind::Async { .. } => hir::CoroutineDesugaring::Async,
+                CoroutineKind::Gen { .. } => hir::CoroutineDesugaring::Gen,
+                CoroutineKind::AsyncGen { .. } => hir::CoroutineDesugaring::AsyncGen,
             };
+            let coroutine_expr = this.make_desugared_coroutine_expr(
+                CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
+                closure_id,
+                None,
+                body.span,
+                desugaring_kind,
+                hir::CoroutineSource::Fn,
+                mkbody,
+            );
 
             let hir_id = this.lower_node_id(closure_id);
             this.maybe_forward_track_caller(body.span, fn_id, hir_id);