about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Holk <ericholk@microsoft.com>2023-12-04 13:43:38 -0800
committerEric Holk <ericholk@microsoft.com>2023-12-04 14:33:46 -0800
commit50ef8006eb68682471894c99b49eb4e39b48c745 (patch)
treee64b78caa9226fb5267e38afffebf0813f228a33
parent26f9954971a2895580e02578fe18bc6f9adea3c9 (diff)
downloadrust-50ef8006eb68682471894c99b49eb4e39b48c745.tar.gz
rust-50ef8006eb68682471894c99b49eb4e39b48c745.zip
Address code review feedback
-rw-r--r--compiler/rustc_ast/src/mut_visit.rs5
-rw-r--r--compiler/rustc_ast_lowering/src/lib.rs14
-rw-r--r--compiler/rustc_ast_lowering/src/path.rs6
-rw-r--r--compiler/rustc_ast_passes/src/ast_validation.rs2
-rw-r--r--compiler/rustc_builtin_macros/src/test.rs4
-rw-r--r--compiler/rustc_parse/src/parser/item.rs2
-rw-r--r--compiler/rustc_parse/src/parser/ty.rs2
-rw-r--r--compiler/rustc_resolve/src/late.rs14
-rw-r--r--src/tools/rustfmt/src/closures.rs21
-rw-r--r--tests/ui/coroutine/gen_fn_lifetime_capture.rs19
10 files changed, 53 insertions, 36 deletions
diff --git a/compiler/rustc_ast/src/mut_visit.rs b/compiler/rustc_ast/src/mut_visit.rs
index c6aa7a6ae37..c6a31fbdbc3 100644
--- a/compiler/rustc_ast/src/mut_visit.rs
+++ b/compiler/rustc_ast/src/mut_visit.rs
@@ -873,8 +873,9 @@ pub fn noop_visit_closure_binder<T: MutVisitor>(binder: &mut ClosureBinder, vis:
 
 pub fn noop_visit_coro_kind<T: MutVisitor>(coro_kind: &mut CoroutineKind, vis: &mut T) {
     match coro_kind {
-        CoroutineKind::Async { span: _, closure_id, return_impl_trait_id }
-        | CoroutineKind::Gen { span: _, closure_id, return_impl_trait_id } => {
+        CoroutineKind::Async { span, closure_id, return_impl_trait_id }
+        | CoroutineKind::Gen { span, closure_id, return_impl_trait_id } => {
+            vis.visit_span(span);
             vis.visit_id(closure_id);
             vis.visit_id(return_impl_trait_id);
         }
diff --git a/compiler/rustc_ast_lowering/src/lib.rs b/compiler/rustc_ast_lowering/src/lib.rs
index 21a33d137b8..d435082e121 100644
--- a/compiler/rustc_ast_lowering/src/lib.rs
+++ b/compiler/rustc_ast_lowering/src/lib.rs
@@ -1922,7 +1922,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             span,
             opaque_ty_span,
             |this| {
-                let future_bound = this.lower_coroutine_fn_output_type_to_future_bound(
+                let bound = this.lower_coroutine_fn_output_type_to_bound(
                     output,
                     coro,
                     span,
@@ -1931,7 +1931,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
                         fn_kind,
                     },
                 );
-                arena_vec![this; future_bound]
+                arena_vec![this; bound]
             },
         );
 
@@ -1940,7 +1940,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
     }
 
     /// Transforms `-> T` into `Future<Output = T>`.
-    fn lower_coroutine_fn_output_type_to_future_bound(
+    fn lower_coroutine_fn_output_type_to_bound(
         &mut self,
         output: &FnRetTy,
         coro: CoroutineKind,
@@ -1958,21 +1958,21 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             FnRetTy::Default(ret_ty_span) => self.arena.alloc(self.ty_tup(*ret_ty_span, &[])),
         };
 
-        // "<Output|Item = T>"
-        let (symbol, lang_item) = match coro {
+        // "<$assoc_ty_name = T>"
+        let (assoc_ty_name, trait_lang_item) = match coro {
             CoroutineKind::Async { .. } => (hir::FN_OUTPUT_NAME, hir::LangItem::Future),
             CoroutineKind::Gen { .. } => (hir::ITERATOR_ITEM_NAME, hir::LangItem::Iterator),
         };
 
         let future_args = self.arena.alloc(hir::GenericArgs {
             args: &[],
-            bindings: arena_vec![self; self.assoc_ty_binding(symbol, span, output_ty)],
+            bindings: arena_vec![self; self.assoc_ty_binding(assoc_ty_name, span, output_ty)],
             parenthesized: hir::GenericArgsParentheses::No,
             span_ext: DUMMY_SP,
         });
 
         hir::GenericBound::LangItemTrait(
-            lang_item,
+            trait_lang_item,
             self.lower_span(span),
             self.next_id(),
             future_args,
diff --git a/compiler/rustc_ast_lowering/src/path.rs b/compiler/rustc_ast_lowering/src/path.rs
index accb74d7a52..7ab0805d086 100644
--- a/compiler/rustc_ast_lowering/src/path.rs
+++ b/compiler/rustc_ast_lowering/src/path.rs
@@ -401,14 +401,14 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         )
     }
 
-    /// An associated type binding `$symbol = $ty`.
+    /// An associated type binding `$assoc_ty_name = $ty`.
     pub(crate) fn assoc_ty_binding(
         &mut self,
-        symbol: rustc_span::Symbol,
+        assoc_ty_name: rustc_span::Symbol,
         span: Span,
         ty: &'hir hir::Ty<'hir>,
     ) -> hir::TypeBinding<'hir> {
-        let ident = Ident::with_dummy_span(symbol);
+        let ident = Ident::with_dummy_span(assoc_ty_name);
         let kind = hir::TypeBindingKind::Equality { term: ty.into() };
         let args = arena_vec![self;];
         let bindings = arena_vec![self;];
diff --git a/compiler/rustc_ast_passes/src/ast_validation.rs b/compiler/rustc_ast_passes/src/ast_validation.rs
index 311ab96aba0..554ed36b814 100644
--- a/compiler/rustc_ast_passes/src/ast_validation.rs
+++ b/compiler/rustc_ast_passes/src/ast_validation.rs
@@ -1279,7 +1279,7 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
             ..
         }) = fk.header()
         {
-            // FIXME(eholk): Report a different error for `const gen`
+            // FIXME(gen_blocks): Report a different error for `const gen`
             self.err_handler().emit_err(errors::ConstAndAsync {
                 spans: vec![cspan, aspan],
                 cspan,
diff --git a/compiler/rustc_builtin_macros/src/test.rs b/compiler/rustc_builtin_macros/src/test.rs
index 38fdddf5834..81433155ecf 100644
--- a/compiler/rustc_builtin_macros/src/test.rs
+++ b/compiler/rustc_builtin_macros/src/test.rs
@@ -545,6 +545,10 @@ fn check_test_signature(
         return Err(sd.emit_err(errors::TestBadFn { span: i.span, cause: span, kind: "async" }));
     }
 
+    if let Some(ast::CoroutineKind::Gen { span, .. }) = f.sig.header.coro_kind {
+        return Err(sd.emit_err(errors::TestBadFn { span: i.span, cause: span, kind: "gen" }));
+    }
+
     // If the termination trait is active, the compiler will check that the output
     // type implements the `Termination` trait as `libtest` enforces that.
     let has_output = match &f.sig.decl.output {
diff --git a/compiler/rustc_parse/src/parser/item.rs b/compiler/rustc_parse/src/parser/item.rs
index 589fc46b722..8a987767dc4 100644
--- a/compiler/rustc_parse/src/parser/item.rs
+++ b/compiler/rustc_parse/src/parser/item.rs
@@ -2544,7 +2544,7 @@ impl<'a> Parser<'a> {
                         }
                     }
 
-                    // FIXME(eholk): add keyword recovery logic for genness
+                    // FIXME(gen_blocks): add keyword recovery logic for genness
 
                     if wrong_kw.is_some()
                         && self.may_recover()
diff --git a/compiler/rustc_parse/src/parser/ty.rs b/compiler/rustc_parse/src/parser/ty.rs
index 73487f4af0e..068a99db4ae 100644
--- a/compiler/rustc_parse/src/parser/ty.rs
+++ b/compiler/rustc_parse/src/parser/ty.rs
@@ -612,7 +612,7 @@ impl<'a> Parser<'a> {
         if let Some(ast::CoroutineKind::Async { span, .. }) = coro_kind {
             self.sess.emit_err(FnPointerCannotBeAsync { span: whole_span, qualifier: span });
         }
-        // FIXME(eholk): emit a similar error for `gen fn()`
+        // FIXME(gen_blocks): emit a similar error for `gen fn()`
         let decl_span = span_start.to(self.token.span);
         Ok(TyKind::BareFn(P(BareFnTy { ext, unsafety, generic_params: params, decl, decl_span })))
     }
diff --git a/compiler/rustc_resolve/src/late.rs b/compiler/rustc_resolve/src/late.rs
index c5d6574af60..ad14f5e5225 100644
--- a/compiler/rustc_resolve/src/late.rs
+++ b/compiler/rustc_resolve/src/late.rs
@@ -916,10 +916,10 @@ impl<'a: 'ast, 'ast, 'tcx> Visitor<'ast> for LateResolutionVisitor<'a, '_, 'ast,
                             &sig.decl.output,
                         );
 
-                        if let Some((async_node_id, _)) =
+                        if let Some((coro_node_id, _)) =
                             sig.header.coro_kind.map(|coro_kind| coro_kind.return_id())
                         {
-                            this.record_lifetime_params_for_impl_trait(async_node_id);
+                            this.record_lifetime_params_for_impl_trait(coro_node_id);
                         }
                     },
                 );
@@ -942,13 +942,13 @@ impl<'a: 'ast, 'ast, 'tcx> Visitor<'ast> for LateResolutionVisitor<'a, '_, 'ast,
                         this.visit_generics(generics);
 
                         let declaration = &sig.decl;
-                        let async_node_id =
+                        let coro_node_id =
                             sig.header.coro_kind.map(|coro_kind| coro_kind.return_id());
 
                         this.with_lifetime_rib(
                             LifetimeRibKind::AnonymousCreateParameter {
                                 binder: fn_id,
-                                report_in_path: async_node_id.is_some(),
+                                report_in_path: coro_node_id.is_some(),
                             },
                             |this| {
                                 this.resolve_fn_signature(
@@ -961,7 +961,7 @@ impl<'a: 'ast, 'ast, 'tcx> Visitor<'ast> for LateResolutionVisitor<'a, '_, 'ast,
                                     &declaration.output,
                                 );
 
-                                if let Some((async_node_id, _)) = async_node_id {
+                                if let Some((async_node_id, _)) = coro_node_id {
                                     this.record_lifetime_params_for_impl_trait(async_node_id);
                                 }
                             },
@@ -4291,8 +4291,10 @@ impl<'a: 'ast, 'b, 'ast, 'tcx> LateResolutionVisitor<'a, 'b, 'ast, 'tcx> {
             // `async |x| ...` gets desugared to `|x| async {...}`, so we need to
             // resolve the arguments within the proper scopes so that usages of them inside the
             // closure are detected as upvars rather than normal closure arg usages.
+            //
+            // Similarly, `gen |x| ...` gets desugared to `|x| gen {...}`, so we handle that too.
             ExprKind::Closure(box ast::Closure {
-                coro_kind: Some(CoroutineKind::Async { .. }),
+                coro_kind: Some(_),
                 ref fn_decl,
                 ref body,
                 ..
diff --git a/src/tools/rustfmt/src/closures.rs b/src/tools/rustfmt/src/closures.rs
index d79218e78ee..c1ce87eadcb 100644
--- a/src/tools/rustfmt/src/closures.rs
+++ b/src/tools/rustfmt/src/closures.rs
@@ -263,12 +263,10 @@ fn rewrite_closure_fn_decl(
     } else {
         ""
     };
-    let (is_async, is_gen) = if let Some(coro_kind) = coro_kind {
-        let is_async = if coro_kind.is_async() { "async " } else { "" };
-        let is_gen = if coro_kind.is_gen() { "gen " } else { "" };
-        (is_async, is_gen)
-    } else {
-        ("", "")
+    let coro = match coro_kind {
+        Some(ast::CoroutineKind::Async { .. }) => "async ",
+        Some(ast::CoroutineKind::Gen { .. }) => "gen ",
+        None => "",
     };
     let mover = if matches!(capture, ast::CaptureBy::Value { .. }) {
         "move "
@@ -278,14 +276,7 @@ fn rewrite_closure_fn_decl(
     // 4 = "|| {".len(), which is overconservative when the closure consists of
     // a single expression.
     let nested_shape = shape
-        .shrink_left(
-            binder.len()
-                + const_.len()
-                + immovable.len()
-                + is_async.len()
-                + is_gen.len()
-                + mover.len(),
-        )?
+        .shrink_left(binder.len() + const_.len() + immovable.len() + coro.len() + mover.len())?
         .sub_width(4)?;
 
     // 1 = |
@@ -323,7 +314,7 @@ fn rewrite_closure_fn_decl(
         .tactic(tactic)
         .preserve_newline(true);
     let list_str = write_list(&item_vec, &fmt)?;
-    let mut prefix = format!("{binder}{const_}{immovable}{is_async}{is_gen}{mover}|{list_str}|");
+    let mut prefix = format!("{binder}{const_}{immovable}{coro}{mover}|{list_str}|");
 
     if !ret_str.is_empty() {
         if prefix.contains('\n') {
diff --git a/tests/ui/coroutine/gen_fn_lifetime_capture.rs b/tests/ui/coroutine/gen_fn_lifetime_capture.rs
new file mode 100644
index 00000000000..b6a4d71e6cc
--- /dev/null
+++ b/tests/ui/coroutine/gen_fn_lifetime_capture.rs
@@ -0,0 +1,19 @@
+// edition: 2024
+// compile-flags: -Zunstable-options
+// check-pass
+#![feature(gen_blocks)]
+
+// make sure gen fn captures lifetimes in its signature
+
+gen fn foo<'a, 'b>(x: &'a i32, y: &'b i32, z: &'b i32) -> &'b i32 {
+    yield y;
+    yield z;
+}
+
+fn main() {
+    let z = 3;
+    let mut iter = foo(&1, &2, &z);
+    assert_eq!(iter.next(), Some(&2));
+    assert_eq!(iter.next(), Some(&3));
+    assert_eq!(iter.next(), None);
+}