about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Holk <ericholk@microsoft.com>2023-11-29 12:07:43 -0800
committerEric Holk <ericholk@microsoft.com>2023-12-04 11:23:05 -0800
commitc104f3b629cfcac35802a899478756abf24ee7c1 (patch)
tree3519b9e24059a25b90c68fc1c42697e7b11b2b8c
parentbc0d10d4b0fefccda6aae0338a1935d76314736b (diff)
downloadrust-c104f3b629cfcac35802a899478756abf24ee7c1.tar.gz
rust-c104f3b629cfcac35802a899478756abf24ee7c1.zip
Lower return types for gen fn to impl Iterator
-rw-r--r--compiler/rustc_ast_lowering/src/item.rs146
-rw-r--r--compiler/rustc_ast_lowering/src/lib.rs75
-rw-r--r--compiler/rustc_ast_lowering/src/path.rs9
-rw-r--r--compiler/rustc_hir/src/hir.rs2
-rw-r--r--compiler/rustc_hir_typeck/src/closure.rs6
-rw-r--r--compiler/rustc_parse/src/parser/item.rs4
-rw-r--r--compiler/rustc_resolve/src/def_collector.rs5
7 files changed, 167 insertions, 80 deletions
diff --git a/compiler/rustc_ast_lowering/src/item.rs b/compiler/rustc_ast_lowering/src/item.rs
index f0f3e2c3c74..852555048fe 100644
--- a/compiler/rustc_ast_lowering/src/item.rs
+++ b/compiler/rustc_ast_lowering/src/item.rs
@@ -1,3 +1,5 @@
+use crate::FnReturnTransformation;
+
 use super::errors::{InvalidAbi, InvalidAbiReason, InvalidAbiSuggestion, MisplacedRelaxTraitBound};
 use super::ResolverAstLoweringExt;
 use super::{AstOwner, ImplTraitContext, ImplTraitPosition};
@@ -207,13 +209,33 @@ impl<'hir> LoweringContext<'_, 'hir> {
                     // only cares about the input argument patterns in the function
                     // declaration (decl), not the return types.
                     let asyncness = header.asyncness;
-                    let body_id =
-                        this.lower_maybe_async_body(span, hir_id, decl, asyncness, body.as_deref());
+                    let genness = header.genness;
+                    let body_id = this.lower_maybe_coroutine_body(
+                        span,
+                        hir_id,
+                        decl,
+                        asyncness,
+                        genness,
+                        body.as_deref(),
+                    );
 
                     let itctx = ImplTraitContext::Universal;
                     let (generics, decl) =
                         this.lower_generics(generics, header.constness, id, &itctx, |this| {
-                            let ret_id = asyncness.opt_return_id();
+                            let ret_id = asyncness
+                                .opt_return_id()
+                                .map(|(node_id, span)| {
+                                    crate::FnReturnTransformation::Async(node_id, span)
+                                })
+                                .or_else(|| match genness {
+                                    Gen::Yes { span, closure_id: _, return_impl_trait_id } => {
+                                        Some(crate::FnReturnTransformation::Iterator(
+                                            return_impl_trait_id,
+                                            span,
+                                        ))
+                                    }
+                                    _ => None,
+                                });
                             this.lower_fn_decl(decl, id, *fn_sig_span, FnDeclKind::Fn, ret_id)
                         });
                     let sig = hir::FnSig {
@@ -732,20 +754,31 @@ impl<'hir> LoweringContext<'_, 'hir> {
                     sig,
                     i.id,
                     FnDeclKind::Trait,
-                    asyncness.opt_return_id(),
+                    asyncness
+                        .opt_return_id()
+                        .map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
                 );
                 (generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Required(names)), false)
             }
             AssocItemKind::Fn(box Fn { sig, generics, body: Some(body), .. }) => {
                 let asyncness = sig.header.asyncness;
-                let body_id =
-                    self.lower_maybe_async_body(i.span, hir_id, &sig.decl, asyncness, Some(body));
+                let genness = sig.header.genness;
+                let body_id = self.lower_maybe_coroutine_body(
+                    i.span,
+                    hir_id,
+                    &sig.decl,
+                    asyncness,
+                    genness,
+                    Some(body),
+                );
                 let (generics, sig) = self.lower_method_sig(
                     generics,
                     sig,
                     i.id,
                     FnDeclKind::Trait,
-                    asyncness.opt_return_id(),
+                    asyncness
+                        .opt_return_id()
+                        .map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
                 );
                 (generics, hir::TraitItemKind::Fn(sig, hir::TraitFn::Provided(body_id)), true)
             }
@@ -835,11 +868,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
             ),
             AssocItemKind::Fn(box Fn { sig, generics, body, .. }) => {
                 let asyncness = sig.header.asyncness;
-                let body_id = self.lower_maybe_async_body(
+                let genness = sig.header.genness;
+                let body_id = self.lower_maybe_coroutine_body(
                     i.span,
                     hir_id,
                     &sig.decl,
                     asyncness,
+                    genness,
                     body.as_deref(),
                 );
                 let (generics, sig) = self.lower_method_sig(
@@ -847,7 +882,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
                     sig,
                     i.id,
                     if self.is_in_trait_impl { FnDeclKind::Impl } else { FnDeclKind::Inherent },
-                    asyncness.opt_return_id(),
+                    asyncness
+                        .opt_return_id()
+                        .map(|(node_id, span)| crate::FnReturnTransformation::Async(node_id, span)),
                 );
 
                 (generics, hir::ImplItemKind::Fn(sig, body_id))
@@ -1011,16 +1048,22 @@ impl<'hir> LoweringContext<'_, 'hir> {
         })
     }
 
-    fn lower_maybe_async_body(
+    /// Takes what may be the body of an `async fn` or a `gen fn` and wraps it in an `async {}` or
+    /// `gen {}` block as appropriate.
+    fn lower_maybe_coroutine_body(
         &mut self,
         span: Span,
         fn_id: hir::HirId,
         decl: &FnDecl,
         asyncness: Async,
+        genness: Gen,
         body: Option<&Block>,
     ) -> hir::BodyId {
-        let (closure_id, body) = match (asyncness, body) {
-            (Async::Yes { closure_id, .. }, Some(body)) => (closure_id, body),
+        let (closure_id, body) = match (asyncness, genness, body) {
+            // FIXME(eholk): do something reasonable for `async gen fn`. Probably that's an error
+            // for now since it's not supported.
+            (Async::Yes { closure_id, .. }, _, Some(body))
+            | (_, Gen::Yes { closure_id, .. }, Some(body)) => (closure_id, body),
             _ => return self.lower_fn_body_block(span, decl, body),
         };
 
@@ -1163,44 +1206,55 @@ impl<'hir> LoweringContext<'_, 'hir> {
                 parameters.push(new_parameter);
             }
 
-            let async_expr = this.make_async_expr(
-                CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
-                closure_id,
-                None,
-                body.span,
-                hir::CoroutineSource::Fn,
-                |this| {
-                    // Create a block from the user's function body:
-                    let user_body = this.lower_block_expr(body);
+            let mkbody = |this: &mut LoweringContext<'_, 'hir>| {
+                // Create a block from the user's function body:
+                let user_body = this.lower_block_expr(body);
 
-                    // Transform into `drop-temps { <user-body> }`, an expression:
-                    let desugared_span =
-                        this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None);
-                    let user_body =
-                        this.expr_drop_temps(desugared_span, this.arena.alloc(user_body));
+                // Transform into `drop-temps { <user-body> }`, an expression:
+                let desugared_span =
+                    this.mark_span_with_reason(DesugaringKind::Async, user_body.span, None);
+                let user_body = this.expr_drop_temps(desugared_span, this.arena.alloc(user_body));
 
-                    // As noted above, create the final block like
-                    //
-                    // ```
-                    // {
-                    //   let $param_pattern = $raw_param;
-                    //   ...
-                    //   drop-temps { <user-body> }
-                    // }
-                    // ```
-                    let body = this.block_all(
-                        desugared_span,
-                        this.arena.alloc_from_iter(statements),
-                        Some(user_body),
-                    );
+                // As noted above, create the final block like
+                //
+                // ```
+                // {
+                //   let $param_pattern = $raw_param;
+                //   ...
+                //   drop-temps { <user-body> }
+                // }
+                // ```
+                let body = this.block_all(
+                    desugared_span,
+                    this.arena.alloc_from_iter(statements),
+                    Some(user_body),
+                );
 
-                    this.expr_block(body)
-                },
-            );
+                this.expr_block(body)
+            };
+            let coroutine_expr = match (asyncness, genness) {
+                (Async::Yes { .. }, _) => this.make_async_expr(
+                    CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
+                    closure_id,
+                    None,
+                    body.span,
+                    hir::CoroutineSource::Fn,
+                    mkbody,
+                ),
+                (_, Gen::Yes { .. }) => this.make_gen_expr(
+                    CaptureBy::Value { move_kw: rustc_span::DUMMY_SP },
+                    closure_id,
+                    None,
+                    body.span,
+                    hir::CoroutineSource::Fn,
+                    mkbody,
+                ),
+                _ => unreachable!("we must have either an async fn or a gen fn"),
+            };
 
             let hir_id = this.lower_node_id(closure_id);
             this.maybe_forward_track_caller(body.span, fn_id, hir_id);
-            let expr = hir::Expr { hir_id, kind: async_expr, span: this.lower_span(body.span) };
+            let expr = hir::Expr { hir_id, kind: coroutine_expr, span: this.lower_span(body.span) };
 
             (this.arena.alloc_from_iter(parameters), expr)
         })
@@ -1212,13 +1266,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
         sig: &FnSig,
         id: NodeId,
         kind: FnDeclKind,
-        is_async: Option<(NodeId, Span)>,
+        transform_return_type: Option<FnReturnTransformation>,
     ) -> (&'hir hir::Generics<'hir>, hir::FnSig<'hir>) {
         let header = self.lower_fn_header(sig.header);
         let itctx = ImplTraitContext::Universal;
         let (generics, decl) =
             self.lower_generics(generics, sig.header.constness, id, &itctx, |this| {
-                this.lower_fn_decl(&sig.decl, id, sig.span, kind, is_async)
+                this.lower_fn_decl(&sig.decl, id, sig.span, kind, transform_return_type)
             });
         (generics, hir::FnSig { header, decl, span: self.lower_span(sig.span) })
     }
diff --git a/compiler/rustc_ast_lowering/src/lib.rs b/compiler/rustc_ast_lowering/src/lib.rs
index aa8ad978451..96a413e9f73 100644
--- a/compiler/rustc_ast_lowering/src/lib.rs
+++ b/compiler/rustc_ast_lowering/src/lib.rs
@@ -493,6 +493,21 @@ enum ParenthesizedGenericArgs {
     Err,
 }
 
+/// Describes a return type transformation that can be performed by `LoweringContext::lower_fn_decl`
+#[derive(Debug)]
+enum FnReturnTransformation {
+    /// Replaces a return type `T` with `impl Future<Output = T>`.
+    ///
+    /// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the
+    /// `async` keyword.
+    Async(NodeId, Span),
+    /// Replaces a return type `T` with `impl Iterator<Item = T>`.
+    ///
+    /// The `NodeId` is the ID of the return type `impl Trait` item, and the `Span` points to the
+    /// `gen` keyword.
+    Iterator(NodeId, Span),
+}
+
 impl<'a, 'hir> LoweringContext<'a, 'hir> {
     fn create_def(
         &mut self,
@@ -1778,13 +1793,15 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         }))
     }
 
-    // Lowers a function declaration.
-    //
-    // `decl`: the unlowered (AST) function declaration.
-    // `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given `NodeId`.
-    // `make_ret_async`: if `Some`, converts `-> T` into `-> impl Future<Output = T>` in the
-    //      return type. This is used for `async fn` declarations. The `NodeId` is the ID of the
-    //      return type `impl Trait` item, and the `Span` points to the `async` keyword.
+    /// Lowers a function declaration.
+    ///
+    /// `decl`: the unlowered (AST) function declaration.
+    ///
+    /// `fn_node_id`: `impl Trait` arguments are lowered into generic parameters on the given
+    /// `NodeId`.
+    ///
+    /// `transform_return_type`: if `Some`, applies some conversion to the return type, such as is
+    /// needed for `async fn` and `gen fn`. See [`FnReturnTransformation`] for more details.
     #[instrument(level = "debug", skip(self))]
     fn lower_fn_decl(
         &mut self,
@@ -1792,7 +1809,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         fn_node_id: NodeId,
         fn_span: Span,
         kind: FnDeclKind,
-        make_ret_async: Option<(NodeId, Span)>,
+        transform_return_type: Option<FnReturnTransformation>,
     ) -> &'hir hir::FnDecl<'hir> {
         let c_variadic = decl.c_variadic();
 
@@ -1821,11 +1838,12 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             self.lower_ty_direct(&param.ty, &itctx)
         }));
 
-        let output = if let Some((ret_id, _span)) = make_ret_async {
-            let fn_def_id = self.local_def_id(fn_node_id);
-            self.lower_async_fn_ret_ty(&decl.output, fn_def_id, ret_id, kind, fn_span)
-        } else {
-            match &decl.output {
+        let output = match transform_return_type {
+            Some(transform) => {
+                let fn_def_id = self.local_def_id(fn_node_id);
+                self.lower_coroutine_fn_ret_ty(&decl.output, fn_def_id, transform, kind, fn_span)
+            }
+            None => match &decl.output {
                 FnRetTy::Ty(ty) => {
                     let context = if kind.return_impl_trait_allowed() {
                         let fn_def_id = self.local_def_id(fn_node_id);
@@ -1849,7 +1867,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
                     hir::FnRetTy::Return(self.lower_ty(ty, &context))
                 }
                 FnRetTy::Default(span) => hir::FnRetTy::DefaultReturn(self.lower_span(*span)),
-            }
+            },
         };
 
         self.arena.alloc(hir::FnDecl {
@@ -1888,17 +1906,22 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
     // `fn_node_id`: `NodeId` of the parent function (used to create child impl trait definition)
     // `opaque_ty_node_id`: `NodeId` of the opaque `impl Trait` type that should be created
     #[instrument(level = "debug", skip(self))]
-    fn lower_async_fn_ret_ty(
+    fn lower_coroutine_fn_ret_ty(
         &mut self,
         output: &FnRetTy,
         fn_def_id: LocalDefId,
-        opaque_ty_node_id: NodeId,
+        transform: FnReturnTransformation,
         fn_kind: FnDeclKind,
         fn_span: Span,
     ) -> hir::FnRetTy<'hir> {
         let span = self.lower_span(fn_span);
         let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::Async, span, None);
 
+        let opaque_ty_node_id = match transform {
+            FnReturnTransformation::Async(opaque_ty_node_id, _)
+            | FnReturnTransformation::Iterator(opaque_ty_node_id, _) => opaque_ty_node_id,
+        };
+
         let captured_lifetimes: Vec<_> = self
             .resolver
             .take_extra_lifetime_params(opaque_ty_node_id)
@@ -1914,8 +1937,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             span,
             opaque_ty_span,
             |this| {
-                let future_bound = this.lower_async_fn_output_type_to_future_bound(
+                let future_bound = this.lower_coroutine_fn_output_type_to_future_bound(
                     output,
+                    transform,
                     span,
                     ImplTraitContext::ReturnPositionOpaqueTy {
                         origin: hir::OpaqueTyOrigin::FnReturn(fn_def_id),
@@ -1931,9 +1955,10 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
     }
 
     /// Transforms `-> T` into `Future<Output = T>`.
-    fn lower_async_fn_output_type_to_future_bound(
+    fn lower_coroutine_fn_output_type_to_future_bound(
         &mut self,
         output: &FnRetTy,
+        transform: FnReturnTransformation,
         span: Span,
         nested_impl_trait_context: ImplTraitContext,
     ) -> hir::GenericBound<'hir> {
@@ -1948,17 +1973,23 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             FnRetTy::Default(ret_ty_span) => self.arena.alloc(self.ty_tup(*ret_ty_span, &[])),
         };
 
-        // "<Output = T>"
+        // "<Output|Item = T>"
+        let (symbol, lang_item) = match transform {
+            FnReturnTransformation::Async(..) => (hir::FN_OUTPUT_NAME, hir::LangItem::Future),
+            FnReturnTransformation::Iterator(..) => {
+                (hir::ITERATOR_ITEM_NAME, hir::LangItem::Iterator)
+            }
+        };
+
         let future_args = self.arena.alloc(hir::GenericArgs {
             args: &[],
-            bindings: arena_vec![self; self.output_ty_binding(span, output_ty)],
+            bindings: arena_vec![self; self.assoc_ty_binding(symbol, span, output_ty)],
             parenthesized: hir::GenericArgsParentheses::No,
             span_ext: DUMMY_SP,
         });
 
         hir::GenericBound::LangItemTrait(
-            // ::std::future::Future<future_params>
-            hir::LangItem::Future,
+            lang_item,
             self.lower_span(span),
             self.next_id(),
             future_args,
diff --git a/compiler/rustc_ast_lowering/src/path.rs b/compiler/rustc_ast_lowering/src/path.rs
index db8ca7c3643..accb74d7a52 100644
--- a/compiler/rustc_ast_lowering/src/path.rs
+++ b/compiler/rustc_ast_lowering/src/path.rs
@@ -389,7 +389,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             FnRetTy::Default(_) => self.arena.alloc(self.ty_tup(*span, &[])),
         };
         let args = smallvec![GenericArg::Type(self.arena.alloc(self.ty_tup(*inputs_span, inputs)))];
-        let binding = self.output_ty_binding(output_ty.span, output_ty);
+        let binding = self.assoc_ty_binding(hir::FN_OUTPUT_NAME, output_ty.span, output_ty);
         (
             GenericArgsCtor {
                 args,
@@ -401,13 +401,14 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         )
     }
 
-    /// An associated type binding `Output = $ty`.
-    pub(crate) fn output_ty_binding(
+    /// An associated type binding `$symbol = $ty`.
+    pub(crate) fn assoc_ty_binding(
         &mut self,
+        symbol: rustc_span::Symbol,
         span: Span,
         ty: &'hir hir::Ty<'hir>,
     ) -> hir::TypeBinding<'hir> {
-        let ident = Ident::with_dummy_span(hir::FN_OUTPUT_NAME);
+        let ident = Ident::with_dummy_span(symbol);
         let kind = hir::TypeBindingKind::Equality { term: ty.into() };
         let args = arena_vec![self;];
         let bindings = arena_vec![self;];
diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs
index 81733d8f64e..0d9e174e37a 100644
--- a/compiler/rustc_hir/src/hir.rs
+++ b/compiler/rustc_hir/src/hir.rs
@@ -2255,6 +2255,8 @@ pub enum ImplItemKind<'hir> {
 
 /// The name of the associated type for `Fn` return types.
 pub const FN_OUTPUT_NAME: Symbol = sym::Output;
+/// The name of the associated type for `Iterator` item types.
+pub const ITERATOR_ITEM_NAME: Symbol = sym::Item;
 
 /// Bind a type to an associated type (i.e., `A = Foo`).
 ///
diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs
index cb36d510149..f0bb18df48c 100644
--- a/compiler/rustc_hir_typeck/src/closure.rs
+++ b/compiler/rustc_hir_typeck/src/closure.rs
@@ -651,9 +651,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                         },
                     )
                 }
-                Some(hir::CoroutineKind::Gen(hir::CoroutineSource::Fn)) => {
-                    todo!("gen closures do not exist yet")
-                }
+                // For a `gen {}` block created as a `gen fn` body, we need the return type to be
+                // ().
+                Some(hir::CoroutineKind::Gen(hir::CoroutineSource::Fn)) => self.tcx.types.unit,
 
                 _ => astconv.ty_infer(None, decl.output.span()),
             },
diff --git a/compiler/rustc_parse/src/parser/item.rs b/compiler/rustc_parse/src/parser/item.rs
index 55f7310681f..4f01ab02c04 100644
--- a/compiler/rustc_parse/src/parser/item.rs
+++ b/compiler/rustc_parse/src/parser/item.rs
@@ -2410,10 +2410,6 @@ impl<'a> Parser<'a> {
             }
         }
 
-        if let Gen::Yes { span, .. } = genness {
-            self.sess.emit_err(errors::GenFn { span });
-        }
-
         if !self.eat_keyword_case(kw::Fn, case) {
             // It is possible for `expect_one_of` to recover given the contents of
             // `self.expected_tokens`, therefore, do not use `self.unexpected()` which doesn't
diff --git a/compiler/rustc_resolve/src/def_collector.rs b/compiler/rustc_resolve/src/def_collector.rs
index 647c92785e1..306492eaa96 100644
--- a/compiler/rustc_resolve/src/def_collector.rs
+++ b/compiler/rustc_resolve/src/def_collector.rs
@@ -156,7 +156,10 @@ impl<'a, 'b, 'tcx> visit::Visitor<'a> for DefCollector<'a, 'b, 'tcx> {
 
     fn visit_fn(&mut self, fn_kind: FnKind<'a>, span: Span, _: NodeId) {
         if let FnKind::Fn(_, _, sig, _, generics, body) = fn_kind {
-            if let Async::Yes { closure_id, .. } = sig.header.asyncness {
+            // FIXME(eholk): handle `async gen fn`
+            if let (Async::Yes { closure_id, .. }, _) | (_, Gen::Yes { closure_id, .. }) =
+                (sig.header.asyncness, sig.header.genness)
+            {
                 self.visit_generics(generics);
 
                 // For async functions, we need to create their inner defs inside of a