summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-11-24 00:26:24 +0000
committerbors <bors@rust-lang.org>2023-11-24 00:26:24 +0000
commit1fd418f92ed13db88a21865ba5d909abcf16b6cc (patch)
tree4d790550dec9277300e1ef7d5a676491cfacf847
parente68f935117fef05e7663605a5a6671a6bb4ce719 (diff)
parent543e559c5300e8bd3be6d491897eaac327f9dc71 (diff)
downloadrust-1fd418f92ed13db88a21865ba5d909abcf16b6cc.tar.gz
rust-1fd418f92ed13db88a21865ba5d909abcf16b6cc.zip
Auto merge of #118219 - bjorn3:fix_generator_fn_abi, r=compiler-errors
Fix fn_sig_for_fn_abi and the coroutine transform for generators

There were three issues previously:
* The self argument was pinned, despite Iterator::next taking an unpinned mutable reference.
* A resume argument was passed, despite Iterator::next not having one.
* The return value was CoroutineState<Item, ()> rather than Option<Item>

While these things just so happened to work with the LLVM backend, cg_clif does much stricter checks when trying to assign a value to a place. In addition it can't handle the mismatch between the amount of arguments specified by the FnAbi and the FnSig.
-rw-r--r--compiler/rustc_codegen_cranelift/build_system/tests.rs9
-rw-r--r--compiler/rustc_codegen_cranelift/config.txt1
-rw-r--r--compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs36
-rw-r--r--compiler/rustc_codegen_cranelift/rustfmt.toml5
-rw-r--r--compiler/rustc_mir_transform/src/coroutine.rs32
-rw-r--r--compiler/rustc_ty_utils/src/abi.rs106
-rw-r--r--rustfmt.toml1
7 files changed, 154 insertions, 36 deletions
diff --git a/compiler/rustc_codegen_cranelift/build_system/tests.rs b/compiler/rustc_codegen_cranelift/build_system/tests.rs
index ff71a567ed3..aa50dbfdf35 100644
--- a/compiler/rustc_codegen_cranelift/build_system/tests.rs
+++ b/compiler/rustc_codegen_cranelift/build_system/tests.rs
@@ -100,6 +100,15 @@ const BASE_SYSROOT_SUITE: &[TestCase] = &[
     TestCase::build_bin_and_run("aot.issue-72793", "example/issue-72793.rs", &[]),
     TestCase::build_bin("aot.issue-59326", "example/issue-59326.rs"),
     TestCase::build_bin_and_run("aot.neon", "example/neon.rs", &[]),
+    TestCase::custom("aot.gen_block_iterate", &|runner| {
+        runner.run_rustc([
+            "example/gen_block_iterate.rs",
+            "--edition",
+            "2024",
+            "-Zunstable-options",
+        ]);
+        runner.run_out_command("gen_block_iterate", &[]);
+    }),
 ];
 
 pub(crate) static RAND_REPO: GitRepo = GitRepo::github(
diff --git a/compiler/rustc_codegen_cranelift/config.txt b/compiler/rustc_codegen_cranelift/config.txt
index 2ccdc7d7874..3cf295c003e 100644
--- a/compiler/rustc_codegen_cranelift/config.txt
+++ b/compiler/rustc_codegen_cranelift/config.txt
@@ -43,6 +43,7 @@ aot.mod_bench
 aot.issue-72793
 aot.issue-59326
 aot.neon
+aot.gen_block_iterate
 
 testsuite.extended_sysroot
 test.rust-random/rand
diff --git a/compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs b/compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs
new file mode 100644
index 00000000000..14bd23e77ea
--- /dev/null
+++ b/compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs
@@ -0,0 +1,36 @@
+// Copied from https://github.com/rust-lang/rust/blob/46455dc65069387f2dc46612f13fd45452ab301a/tests/ui/coroutine/gen_block_iterate.rs
+// revisions: next old
+//compile-flags: --edition 2024 -Zunstable-options
+//[next] compile-flags: -Ztrait-solver=next
+// run-pass
+#![feature(gen_blocks)]
+
+fn foo() -> impl Iterator<Item = u32> {
+    gen { yield 42; for x in 3..6 { yield x } }
+}
+
+fn moved() -> impl Iterator<Item = u32> {
+    let mut x = "foo".to_string();
+    gen move {
+        yield 42;
+        if x == "foo" { return }
+        x.clear();
+        for x in 3..6 { yield x }
+    }
+}
+
+fn main() {
+    let mut iter = foo();
+    assert_eq!(iter.next(), Some(42));
+    assert_eq!(iter.next(), Some(3));
+    assert_eq!(iter.next(), Some(4));
+    assert_eq!(iter.next(), Some(5));
+    assert_eq!(iter.next(), None);
+    // `gen` blocks are fused
+    assert_eq!(iter.next(), None);
+
+    let mut iter = moved();
+    assert_eq!(iter.next(), Some(42));
+    assert_eq!(iter.next(), None);
+
+}
diff --git a/compiler/rustc_codegen_cranelift/rustfmt.toml b/compiler/rustc_codegen_cranelift/rustfmt.toml
index ebeca8662a5..0f884187add 100644
--- a/compiler/rustc_codegen_cranelift/rustfmt.toml
+++ b/compiler/rustc_codegen_cranelift/rustfmt.toml
@@ -1,4 +1,7 @@
-ignore = ["y.rs"]
+ignore = [
+    "y.rs",
+    "example/gen_block_iterate.rs", # uses edition 2024
+]
 
 # Matches rustfmt.toml of rustc
 version = "Two"
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index aa4d8ddad56..42540911785 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -617,6 +617,22 @@ fn replace_resume_ty_local<'tcx>(
     }
 }
 
+/// Transforms the `body` of the coroutine applying the following transform:
+///
+/// - Remove the `resume` argument.
+///
+/// Ideally the async lowering would not add the `resume` argument.
+///
+/// The async lowering step and the type / lifetime inference / checking are
+/// still using the `resume` argument for the time being. After this transform,
+/// the coroutine body doesn't have the `resume` argument.
+fn transform_gen_context<'tcx>(_tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    // This leaves the local representing the `resume` argument in place,
+    // but turns it into a regular local variable. This is cheaper than
+    // adjusting all local references in the body after removing it.
+    body.arg_count = 1;
+}
+
 struct LivenessInfo {
     /// Which locals are live across any suspension point.
     saved_locals: CoroutineSavedLocals,
@@ -1337,7 +1353,15 @@ fn create_coroutine_resume_function<'tcx>(
     insert_switch(body, cases, &transform, TerminatorKind::Unreachable);
 
     make_coroutine_state_argument_indirect(tcx, body);
-    make_coroutine_state_argument_pinned(tcx, body);
+
+    match coroutine_kind {
+        // Iterator::next doesn't accept a pinned argument,
+        // unlike for all other coroutine kinds.
+        CoroutineKind::Gen(_) => {}
+        _ => {
+            make_coroutine_state_argument_pinned(tcx, body);
+        }
+    }
 
     // Make sure we remove dead blocks to remove
     // unrelated code from the drop part of the function
@@ -1504,6 +1528,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         };
 
         let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
+        let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
         let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
             CoroutineKind::Async(_) => {
                 // Compute Poll<return_ty>
@@ -1609,6 +1634,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         body.arg_count = 2; // self, resume arg
         body.spread_arg = None;
 
+        // Remove the context argument within generator bodies.
+        if is_gen_kind {
+            transform_gen_context(tcx, body);
+        }
+
         // The original arguments to the function are no longer arguments, mark them as such.
         // Otherwise they'll conflict with our new arguments, which although they don't have
         // argument_index set, will get emitted as unnamed arguments.
diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs
index 737acfbc600..85e137d29ac 100644
--- a/compiler/rustc_ty_utils/src/abi.rs
+++ b/compiler/rustc_ty_utils/src/abi.rs
@@ -98,6 +98,7 @@ fn fn_sig_for_fn_abi<'tcx>(
             )
         }
         ty::Coroutine(did, args, _) => {
+            let coroutine_kind = tcx.coroutine_kind(did).unwrap();
             let sig = args.as_coroutine().poly_sig();
 
             let bound_vars = tcx.mk_bound_variable_kinds_from_iter(
@@ -112,55 +113,92 @@ fn fn_sig_for_fn_abi<'tcx>(
             let pin_did = tcx.require_lang_item(LangItem::Pin, None);
             let pin_adt_ref = tcx.adt_def(pin_did);
             let pin_args = tcx.mk_args(&[env_ty.into()]);
-            let env_ty = Ty::new_adt(tcx, pin_adt_ref, pin_args);
+            let env_ty = match coroutine_kind {
+                hir::CoroutineKind::Gen(_) => {
+                    // Iterator::next doesn't accept a pinned argument,
+                    // unlike for all other coroutine kinds.
+                    env_ty
+                }
+                hir::CoroutineKind::Async(_) | hir::CoroutineKind::Coroutine => {
+                    Ty::new_adt(tcx, pin_adt_ref, pin_args)
+                }
+            };
 
             let sig = sig.skip_binder();
             // The `FnSig` and the `ret_ty` here is for a coroutines main
             // `Coroutine::resume(...) -> CoroutineState` function in case we
-            // have an ordinary coroutine, or the `Future::poll(...) -> Poll`
-            // function in case this is a special coroutine backing an async construct.
-            let (resume_ty, ret_ty) = if tcx.coroutine_is_async(did) {
-                // The signature should be `Future::poll(_, &mut Context<'_>) -> Poll<Output>`
-                let poll_did = tcx.require_lang_item(LangItem::Poll, None);
-                let poll_adt_ref = tcx.adt_def(poll_did);
-                let poll_args = tcx.mk_args(&[sig.return_ty.into()]);
-                let ret_ty = Ty::new_adt(tcx, poll_adt_ref, poll_args);
-
-                // We have to replace the `ResumeTy` that is used for type and borrow checking
-                // with `&mut Context<'_>` which is used in codegen.
-                #[cfg(debug_assertions)]
-                {
-                    if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() {
-                        let expected_adt =
-                            tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
-                        assert_eq!(*resume_ty_adt, expected_adt);
-                    } else {
-                        panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty);
-                    };
+            // have an ordinary coroutine, the `Future::poll(...) -> Poll`
+            // function in case this is a special coroutine backing an async construct
+            // or the `Iterator::next(...) -> Option` function in case this is a
+            // special coroutine backing a gen construct.
+            let (resume_ty, ret_ty) = match coroutine_kind {
+                hir::CoroutineKind::Async(_) => {
+                    // The signature should be `Future::poll(_, &mut Context<'_>) -> Poll<Output>`
+                    assert_eq!(sig.yield_ty, tcx.types.unit);
+
+                    let poll_did = tcx.require_lang_item(LangItem::Poll, None);
+                    let poll_adt_ref = tcx.adt_def(poll_did);
+                    let poll_args = tcx.mk_args(&[sig.return_ty.into()]);
+                    let ret_ty = Ty::new_adt(tcx, poll_adt_ref, poll_args);
+
+                    // We have to replace the `ResumeTy` that is used for type and borrow checking
+                    // with `&mut Context<'_>` which is used in codegen.
+                    #[cfg(debug_assertions)]
+                    {
+                        if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() {
+                            let expected_adt =
+                                tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
+                            assert_eq!(*resume_ty_adt, expected_adt);
+                        } else {
+                            panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty);
+                        };
+                    }
+                    let context_mut_ref = Ty::new_task_context(tcx);
+
+                    (Some(context_mut_ref), ret_ty)
                 }
-                let context_mut_ref = Ty::new_task_context(tcx);
+                hir::CoroutineKind::Gen(_) => {
+                    // The signature should be `Iterator::next(_) -> Option<Yield>`
+                    let option_did = tcx.require_lang_item(LangItem::Option, None);
+                    let option_adt_ref = tcx.adt_def(option_did);
+                    let option_args = tcx.mk_args(&[sig.yield_ty.into()]);
+                    let ret_ty = Ty::new_adt(tcx, option_adt_ref, option_args);
 
-                (context_mut_ref, ret_ty)
-            } else {
-                // The signature should be `Coroutine::resume(_, Resume) -> CoroutineState<Yield, Return>`
-                let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
-                let state_adt_ref = tcx.adt_def(state_did);
-                let state_args = tcx.mk_args(&[sig.yield_ty.into(), sig.return_ty.into()]);
-                let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
+                    assert_eq!(sig.return_ty, tcx.types.unit);
+                    assert_eq!(sig.resume_ty, tcx.types.unit);
 
-                (sig.resume_ty, ret_ty)
+                    (None, ret_ty)
+                }
+                hir::CoroutineKind::Coroutine => {
+                    // The signature should be `Coroutine::resume(_, Resume) -> CoroutineState<Yield, Return>`
+                    let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
+                    let state_adt_ref = tcx.adt_def(state_did);
+                    let state_args = tcx.mk_args(&[sig.yield_ty.into(), sig.return_ty.into()]);
+                    let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
+
+                    (Some(sig.resume_ty), ret_ty)
+                }
             };
 
-            ty::Binder::bind_with_vars(
+            let fn_sig = if let Some(resume_ty) = resume_ty {
                 tcx.mk_fn_sig(
                     [env_ty, resume_ty],
                     ret_ty,
                     false,
                     hir::Unsafety::Normal,
                     rustc_target::spec::abi::Abi::Rust,
-                ),
-                bound_vars,
-            )
+                )
+            } else {
+                // `Iterator::next` doesn't have a `resume` argument.
+                tcx.mk_fn_sig(
+                    [env_ty],
+                    ret_ty,
+                    false,
+                    hir::Unsafety::Normal,
+                    rustc_target::spec::abi::Abi::Rust,
+                )
+            };
+            ty::Binder::bind_with_vars(fn_sig, bound_vars)
         }
         _ => bug!("unexpected type {:?} in Instance::fn_sig", ty),
     }
diff --git a/rustfmt.toml b/rustfmt.toml
index 88700779e87..e292a310742 100644
--- a/rustfmt.toml
+++ b/rustfmt.toml
@@ -39,4 +39,5 @@ ignore = [
     # these are ignored by a standard cargo fmt run
     "compiler/rustc_codegen_cranelift/y.rs", # running rustfmt breaks this file
     "compiler/rustc_codegen_cranelift/scripts",
+    "compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs", # uses edition 2024
 ]