about summary refs log tree commit diff
diff options
context:
space:
mode:
authorOli Scherer <git-spam-no-reply9815368754983@oli-obk.de>2023-10-20 19:21:24 +0000
committerOli Scherer <git-spam-no-reply9815368754983@oli-obk.de>2023-10-26 07:10:25 +0000
commit14423080f1f9e9729d372f9b5da06c0ab4aef6e3 (patch)
treede9f66b453bbae0c74c5f11256d9637bd18463d5
parenta61cf673cd10ed24394c3403e03d8f7cf1f50170 (diff)
downloadrust-14423080f1f9e9729d372f9b5da06c0ab4aef6e3.tar.gz
rust-14423080f1f9e9729d372f9b5da06c0ab4aef6e3.zip
Add hir::GeneratorKind::Gen
-rw-r--r--compiler/rustc_ast_lowering/src/expr.rs7
-rw-r--r--compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs5
-rw-r--r--compiler/rustc_borrowck/src/diagnostics/region_name.rs14
-rw-r--r--compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs3
-rw-r--r--compiler/rustc_hir/src/hir.rs12
-rw-r--r--compiler/rustc_hir_typeck/src/check.rs19
-rw-r--r--compiler/rustc_middle/src/mir/terminator.rs10
-rw-r--r--compiler/rustc_middle/src/ty/util.rs2
-rw-r--r--compiler/rustc_smir/src/rustc_smir/mod.rs28
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs17
-rw-r--r--compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs3
-rw-r--r--compiler/stable_mir/src/mir/body.rs1
12 files changed, 99 insertions, 22 deletions
diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs
index b82f878ea87..ab82bbd46a6 100644
--- a/compiler/rustc_ast_lowering/src/expr.rs
+++ b/compiler/rustc_ast_lowering/src/expr.rs
@@ -712,7 +712,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
         let full_span = expr.span.to(await_kw_span);
         match self.coroutine_kind {
             Some(hir::CoroutineKind::Async(_)) => {}
-            Some(hir::CoroutineKind::Coroutine) | None => {
+            Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) | None => {
                 self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks {
                     await_kw_span,
                     item_span: self.current_item,
@@ -936,8 +936,8 @@ impl<'hir> LoweringContext<'_, 'hir> {
                 }
                 Some(movability)
             }
-            Some(hir::CoroutineKind::Async(_)) => {
-                panic!("non-`async` closure body turned `async` during lowering");
+            Some(hir::CoroutineKind::Gen(_)) | Some(hir::CoroutineKind::Async(_)) => {
+                panic!("non-`async`/`gen` closure body turned `async`/`gen` during lowering");
             }
             None => {
                 if movability == Movability::Static {
@@ -1446,6 +1446,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
     fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> {
         match self.coroutine_kind {
             Some(hir::CoroutineKind::Coroutine) => {}
+            Some(hir::CoroutineKind::Gen(_)) => {}
             Some(hir::CoroutineKind::Async(_)) => {
                 self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span });
             }
diff --git a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
index 928c25d1061..9db7dec1cf8 100644
--- a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
@@ -2505,6 +2505,11 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
         };
         let kind = match use_span.coroutine_kind() {
             Some(coroutine_kind) => match coroutine_kind {
+                CoroutineKind::Gen(kind) => match kind {
+                    CoroutineSource::Block => "gen block",
+                    CoroutineSource::Closure => "gen closure",
+                    _ => bug!("gen block/closure expected, but gen function found."),
+                },
                 CoroutineKind::Async(async_kind) => match async_kind {
                     CoroutineSource::Block => "async block",
                     CoroutineSource::Closure => "async closure",
diff --git a/compiler/rustc_borrowck/src/diagnostics/region_name.rs b/compiler/rustc_borrowck/src/diagnostics/region_name.rs
index e630ac7c4fa..d38cfbc54d7 100644
--- a/compiler/rustc_borrowck/src/diagnostics/region_name.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/region_name.rs
@@ -698,6 +698,20 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
                             " of async function"
                         }
                     },
+                    Some(hir::CoroutineKind::Gen(gen)) => match gen {
+                        hir::CoroutineSource::Block => " of gen block",
+                        hir::CoroutineSource::Closure => " of gen closure",
+                        hir::CoroutineSource::Fn => {
+                            let parent_item =
+                                hir.get_by_def_id(hir.get_parent_item(mir_hir_id).def_id);
+                            let output = &parent_item
+                                .fn_decl()
+                                .expect("coroutine lowered from gen fn should be in fn")
+                                .output;
+                            span = output.span();
+                            " of gen function"
+                        }
+                    },
                     Some(hir::CoroutineKind::Coroutine) => " of coroutine",
                     None => " of closure",
                 };
diff --git a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs
index 5900c764073..1a85eb8dd79 100644
--- a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs
+++ b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs
@@ -560,6 +560,9 @@ pub fn push_item_name(tcx: TyCtxt<'_>, def_id: DefId, qualified: bool, output: &
 
 fn coroutine_kind_label(coroutine_kind: Option<CoroutineKind>) -> &'static str {
     match coroutine_kind {
+        Some(CoroutineKind::Gen(CoroutineSource::Block)) => "gen_block",
+        Some(CoroutineKind::Gen(CoroutineSource::Closure)) => "gen_closure",
+        Some(CoroutineKind::Gen(CoroutineSource::Fn)) => "gen_fn",
         Some(CoroutineKind::Async(CoroutineSource::Block)) => "async_block",
         Some(CoroutineKind::Async(CoroutineSource::Closure)) => "async_closure",
         Some(CoroutineKind::Async(CoroutineSource::Fn)) => "async_fn",
diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs
index 259af4f565b..22ee39634a5 100644
--- a/compiler/rustc_hir/src/hir.rs
+++ b/compiler/rustc_hir/src/hir.rs
@@ -1513,6 +1513,9 @@ pub enum CoroutineKind {
     /// An explicit `async` block or the body of an async function.
     Async(CoroutineSource),
 
+    /// An explicit `gen` block or the body of a `gen` function.
+    Gen(CoroutineSource),
+
     /// A coroutine literal created via a `yield` inside a closure.
     Coroutine,
 }
@@ -1529,6 +1532,14 @@ impl fmt::Display for CoroutineKind {
                 k.fmt(f)
             }
             CoroutineKind::Coroutine => f.write_str("coroutine"),
+            CoroutineKind::Gen(k) => {
+                if f.alternate() {
+                    f.write_str("`gen` ")?;
+                } else {
+                    f.write_str("gen ")?
+                }
+                k.fmt(f)
+            }
         }
     }
 }
@@ -2242,6 +2253,7 @@ impl From<CoroutineKind> for YieldSource {
             // Guess based on the kind of the current coroutine.
             CoroutineKind::Coroutine => Self::Yield,
             CoroutineKind::Async(_) => Self::Await { expr: None },
+            CoroutineKind::Gen(_) => Self::Yield,
         }
     }
 }
diff --git a/compiler/rustc_hir_typeck/src/check.rs b/compiler/rustc_hir_typeck/src/check.rs
index e7060dac844..14cd3881a06 100644
--- a/compiler/rustc_hir_typeck/src/check.rs
+++ b/compiler/rustc_hir_typeck/src/check.rs
@@ -58,15 +58,16 @@ pub(super) fn check_fn<'a, 'tcx>(
     if let Some(kind) = body.coroutine_kind
         && can_be_coroutine.is_some()
     {
-        let yield_ty = if kind == hir::CoroutineKind::Coroutine {
-            let yield_ty = fcx.next_ty_var(TypeVariableOrigin {
-                kind: TypeVariableOriginKind::TypeInference,
-                span,
-            });
-            fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType);
-            yield_ty
-        } else {
-            Ty::new_unit(tcx)
+        let yield_ty = match kind {
+            hir::CoroutineKind::Gen(..) | hir::CoroutineKind::Coroutine => {
+                let yield_ty = fcx.next_ty_var(TypeVariableOrigin {
+                    kind: TypeVariableOriginKind::TypeInference,
+                    span,
+                });
+                fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType);
+                yield_ty
+            }
+            hir::CoroutineKind::Async(..) => Ty::new_unit(tcx),
         };
 
         // Resume type defaults to `()` if the coroutine has no argument.
diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs
index affa83fa348..ef1d0c867ea 100644
--- a/compiler/rustc_middle/src/mir/terminator.rs
+++ b/compiler/rustc_middle/src/mir/terminator.rs
@@ -148,8 +148,14 @@ impl<O> AssertKind<O> {
             RemainderByZero(_) => "attempt to calculate the remainder with a divisor of zero",
             ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion",
             ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion",
+            ResumedAfterReturn(CoroutineKind::Gen(_)) => {
+                bug!("`gen fn` should just keep returning `None` after the first time")
+            }
             ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking",
             ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking",
+            ResumedAfterPanic(CoroutineKind::Gen(_)) => {
+                bug!("`gen fn` should just keep returning `None` after panicking")
+            }
             BoundsCheck { .. } | MisalignedPointerDereference { .. } => {
                 bug!("Unexpected AssertKind")
             }
@@ -236,10 +242,14 @@ impl<O> AssertKind<O> {
             DivisionByZero(_) => middle_assert_divide_by_zero,
             RemainderByZero(_) => middle_assert_remainder_by_zero,
             ResumedAfterReturn(CoroutineKind::Async(_)) => middle_assert_async_resume_after_return,
+            // FIXME(gen_blocks): custom error message for `gen` blocks
+            ResumedAfterReturn(CoroutineKind::Gen(_)) => middle_assert_async_resume_after_return,
             ResumedAfterReturn(CoroutineKind::Coroutine) => {
                 middle_assert_coroutine_resume_after_return
             }
             ResumedAfterPanic(CoroutineKind::Async(_)) => middle_assert_async_resume_after_panic,
+            // FIXME(gen_blocks): custom error message for `gen` blocks
+            ResumedAfterPanic(CoroutineKind::Gen(_)) => middle_assert_async_resume_after_panic,
             ResumedAfterPanic(CoroutineKind::Coroutine) => {
                 middle_assert_coroutine_resume_after_panic
             }
diff --git a/compiler/rustc_middle/src/ty/util.rs b/compiler/rustc_middle/src/ty/util.rs
index be48c0e8926..d29d1902404 100644
--- a/compiler/rustc_middle/src/ty/util.rs
+++ b/compiler/rustc_middle/src/ty/util.rs
@@ -749,6 +749,7 @@ impl<'tcx> TyCtxt<'tcx> {
             DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() {
                 rustc_hir::CoroutineKind::Async(..) => "async closure",
                 rustc_hir::CoroutineKind::Coroutine => "coroutine",
+                rustc_hir::CoroutineKind::Gen(..) => "gen closure",
             },
             _ => def_kind.descr(def_id),
         }
@@ -766,6 +767,7 @@ impl<'tcx> TyCtxt<'tcx> {
             DefKind::Coroutine => match self.coroutine_kind(def_id).unwrap() {
                 rustc_hir::CoroutineKind::Async(..) => "an",
                 rustc_hir::CoroutineKind::Coroutine => "a",
+                rustc_hir::CoroutineKind::Gen(..) => "a",
             },
             _ => def_kind.article(),
         }
diff --git a/compiler/rustc_smir/src/rustc_smir/mod.rs b/compiler/rustc_smir/src/rustc_smir/mod.rs
index 928536721c9..219f6d284dd 100644
--- a/compiler/rustc_smir/src/rustc_smir/mod.rs
+++ b/compiler/rustc_smir/src/rustc_smir/mod.rs
@@ -880,18 +880,28 @@ impl<'tcx> Stable<'tcx> for mir::AggregateKind<'tcx> {
     }
 }
 
+impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineSource {
+    type T = stable_mir::mir::CoroutineSource;
+    fn stable(&self, _: &mut Tables<'tcx>) -> Self::T {
+        use rustc_hir::CoroutineSource;
+        match self {
+            CoroutineSource::Block => stable_mir::mir::CoroutineSource::Block,
+            CoroutineSource::Closure => stable_mir::mir::CoroutineSource::Closure,
+            CoroutineSource::Fn => stable_mir::mir::CoroutineSource::Fn,
+        }
+    }
+}
+
 impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineKind {
     type T = stable_mir::mir::CoroutineKind;
-    fn stable(&self, _: &mut Tables<'tcx>) -> Self::T {
-        use rustc_hir::{CoroutineKind, CoroutineSource};
+    fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T {
+        use rustc_hir::CoroutineKind;
         match self {
-            CoroutineKind::Async(async_gen) => {
-                let async_gen = match async_gen {
-                    CoroutineSource::Block => stable_mir::mir::CoroutineSource::Block,
-                    CoroutineSource::Closure => stable_mir::mir::CoroutineSource::Closure,
-                    CoroutineSource::Fn => stable_mir::mir::CoroutineSource::Fn,
-                };
-                stable_mir::mir::CoroutineKind::Async(async_gen)
+            CoroutineKind::Async(source) => {
+                stable_mir::mir::CoroutineKind::Async(source.stable(tables))
+            }
+            CoroutineKind::Gen(source) => {
+                stable_mir::mir::CoroutineKind::Gen(source.stable(tables))
             }
             CoroutineKind::Coroutine => stable_mir::mir::CoroutineKind::Coroutine,
         }
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
index b9a08056ad1..3087b734cf2 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
@@ -2425,6 +2425,21 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
                         CoroutineKind::Async(CoroutineSource::Closure) => {
                             format!("future created by async closure is not {trait_name}")
                         }
+                        CoroutineKind::Gen(CoroutineSource::Fn) => self
+                            .tcx
+                            .parent(coroutine_did)
+                            .as_local()
+                            .map(|parent_did| hir.local_def_id_to_hir_id(parent_did))
+                            .and_then(|parent_hir_id| hir.opt_name(parent_hir_id))
+                            .map(|name| {
+                                format!("iterator returned by `{name}` is not {trait_name}")
+                            })?,
+                        CoroutineKind::Gen(CoroutineSource::Block) => {
+                            format!("iterator created by gen block is not {trait_name}")
+                        }
+                        CoroutineKind::Gen(CoroutineSource::Closure) => {
+                            format!("iterator created by gen closure is not {trait_name}")
+                        }
                     })
                 })
                 .unwrap_or_else(|| format!("{future_or_coroutine} is not {trait_name}"));
@@ -2905,7 +2920,7 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
             }
             ObligationCauseCode::SizedCoroutineInterior(coroutine_def_id) => {
                 let what = match self.tcx.coroutine_kind(coroutine_def_id) {
-                    None | Some(hir::CoroutineKind::Coroutine) => "yield",
+                    None | Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) => "yield",
                     Some(hir::CoroutineKind::Async(..)) => "await",
                 };
                 err.note(format!(
diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs
index 353a69ef6d3..030813a4ac3 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs
@@ -1614,6 +1614,9 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
             hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "an async block",
             hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "an async function",
             hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "an async closure",
+            hir::CoroutineKind::Gen(hir::CoroutineSource::Block) => "a gen block",
+            hir::CoroutineKind::Gen(hir::CoroutineSource::Fn) => "a gen function",
+            hir::CoroutineKind::Gen(hir::CoroutineSource::Closure) => "a gen closure",
         })
     }
 
diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs
index c2e79eaaf3f..b686004b2fe 100644
--- a/compiler/stable_mir/src/mir/body.rs
+++ b/compiler/stable_mir/src/mir/body.rs
@@ -137,6 +137,7 @@ pub enum UnOp {
 pub enum CoroutineKind {
     Async(CoroutineSource),
     Coroutine,
+    Gen(CoroutineSource),
 }
 
 #[derive(Clone, Debug)]