about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbjorn3 <17426603+bjorn3@users.noreply.github.com>2023-11-23 20:02:45 +0000
committerbjorn3 <17426603+bjorn3@users.noreply.github.com>2023-11-23 20:17:19 +0000
commitb7bc8d5cb7685bd8e35d7b1c9d3011b043abf775 (patch)
tree9eec065d401fd5349616d750fc0a3514f4017280
parent237339fda17395d3e35f3028a0e0aa8278c3a4bf (diff)
downloadrust-b7bc8d5cb7685bd8e35d7b1c9d3011b043abf775.tar.gz
rust-b7bc8d5cb7685bd8e35d7b1c9d3011b043abf775.zip
Fix fn_sig_for_fn_abi and the coroutine transform for generators
There were three issues previously:
* The self argument was pinned, despite Iterator::next taking an
  unpinned mutable reference.
* A resume argument was passed, despite Iterator::next not having one.
* The return value was CoroutineState<Item, ()> rather than Option<Item>

While these things just so happened to work with the LLVM backend,
cg_clif does much stricter checks when trying to assign a value to a
place. In addition it can't handle the mismatch between the amount of
arguments specified by the FnAbi and the FnSig.
-rw-r--r--compiler/rustc_codegen_cranelift/build_system/tests.rs9
-rw-r--r--compiler/rustc_codegen_cranelift/config.txt1
-rw-r--r--compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs36
-rw-r--r--compiler/rustc_codegen_cranelift/rustfmt.toml5
-rw-r--r--compiler/rustc_mir_transform/src/coroutine.rs32
-rw-r--r--compiler/rustc_ty_utils/src/abi.rs54
-rw-r--r--rustfmt.toml1
7 files changed, 129 insertions, 9 deletions
diff --git a/compiler/rustc_codegen_cranelift/build_system/tests.rs b/compiler/rustc_codegen_cranelift/build_system/tests.rs
index ff71a567ed3..aa50dbfdf35 100644
--- a/compiler/rustc_codegen_cranelift/build_system/tests.rs
+++ b/compiler/rustc_codegen_cranelift/build_system/tests.rs
@@ -100,6 +100,15 @@ const BASE_SYSROOT_SUITE: &[TestCase] = &[
     TestCase::build_bin_and_run("aot.issue-72793", "example/issue-72793.rs", &[]),
     TestCase::build_bin("aot.issue-59326", "example/issue-59326.rs"),
     TestCase::build_bin_and_run("aot.neon", "example/neon.rs", &[]),
+    TestCase::custom("aot.gen_block_iterate", &|runner| {
+        runner.run_rustc([
+            "example/gen_block_iterate.rs",
+            "--edition",
+            "2024",
+            "-Zunstable-options",
+        ]);
+        runner.run_out_command("gen_block_iterate", &[]);
+    }),
 ];
 
 pub(crate) static RAND_REPO: GitRepo = GitRepo::github(
diff --git a/compiler/rustc_codegen_cranelift/config.txt b/compiler/rustc_codegen_cranelift/config.txt
index 2ccdc7d7874..3cf295c003e 100644
--- a/compiler/rustc_codegen_cranelift/config.txt
+++ b/compiler/rustc_codegen_cranelift/config.txt
@@ -43,6 +43,7 @@ aot.mod_bench
 aot.issue-72793
 aot.issue-59326
 aot.neon
+aot.gen_block_iterate
 
 testsuite.extended_sysroot
 test.rust-random/rand
diff --git a/compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs b/compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs
new file mode 100644
index 00000000000..14bd23e77ea
--- /dev/null
+++ b/compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs
@@ -0,0 +1,36 @@
+// Copied from https://github.com/rust-lang/rust/blob/46455dc65069387f2dc46612f13fd45452ab301a/tests/ui/coroutine/gen_block_iterate.rs
+// revisions: next old
+//compile-flags: --edition 2024 -Zunstable-options
+//[next] compile-flags: -Ztrait-solver=next
+// run-pass
+#![feature(gen_blocks)]
+
+fn foo() -> impl Iterator<Item = u32> {
+    gen { yield 42; for x in 3..6 { yield x } }
+}
+
+fn moved() -> impl Iterator<Item = u32> {
+    let mut x = "foo".to_string();
+    gen move {
+        yield 42;
+        if x == "foo" { return }
+        x.clear();
+        for x in 3..6 { yield x }
+    }
+}
+
+fn main() {
+    let mut iter = foo();
+    assert_eq!(iter.next(), Some(42));
+    assert_eq!(iter.next(), Some(3));
+    assert_eq!(iter.next(), Some(4));
+    assert_eq!(iter.next(), Some(5));
+    assert_eq!(iter.next(), None);
+    // `gen` blocks are fused
+    assert_eq!(iter.next(), None);
+
+    let mut iter = moved();
+    assert_eq!(iter.next(), Some(42));
+    assert_eq!(iter.next(), None);
+
+}
diff --git a/compiler/rustc_codegen_cranelift/rustfmt.toml b/compiler/rustc_codegen_cranelift/rustfmt.toml
index ebeca8662a5..0f884187add 100644
--- a/compiler/rustc_codegen_cranelift/rustfmt.toml
+++ b/compiler/rustc_codegen_cranelift/rustfmt.toml
@@ -1,4 +1,7 @@
-ignore = ["y.rs"]
+ignore = [
+    "y.rs",
+    "example/gen_block_iterate.rs", # uses edition 2024
+]
 
 # Matches rustfmt.toml of rustc
 version = "Two"
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index aa4d8ddad56..42540911785 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -617,6 +617,22 @@ fn replace_resume_ty_local<'tcx>(
     }
 }
 
+/// Transforms the `body` of the coroutine applying the following transform:
+///
+/// - Remove the `resume` argument.
+///
+/// Ideally the async lowering would not add the `resume` argument.
+///
+/// The async lowering step and the type / lifetime inference / checking are
+/// still using the `resume` argument for the time being. After this transform,
+/// the coroutine body doesn't have the `resume` argument.
+fn transform_gen_context<'tcx>(_tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    // This leaves the local representing the `resume` argument in place,
+    // but turns it into a regular local variable. This is cheaper than
+    // adjusting all local references in the body after removing it.
+    body.arg_count = 1;
+}
+
 struct LivenessInfo {
     /// Which locals are live across any suspension point.
     saved_locals: CoroutineSavedLocals,
@@ -1337,7 +1353,15 @@ fn create_coroutine_resume_function<'tcx>(
     insert_switch(body, cases, &transform, TerminatorKind::Unreachable);
 
     make_coroutine_state_argument_indirect(tcx, body);
-    make_coroutine_state_argument_pinned(tcx, body);
+
+    match coroutine_kind {
+        // Iterator::next doesn't accept a pinned argument,
+        // unlike for all other coroutine kinds.
+        CoroutineKind::Gen(_) => {}
+        _ => {
+            make_coroutine_state_argument_pinned(tcx, body);
+        }
+    }
 
     // Make sure we remove dead blocks to remove
     // unrelated code from the drop part of the function
@@ -1504,6 +1528,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         };
 
         let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
+        let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
         let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
             CoroutineKind::Async(_) => {
                 // Compute Poll<return_ty>
@@ -1609,6 +1634,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         body.arg_count = 2; // self, resume arg
         body.spread_arg = None;
 
+        // Remove the context argument within generator bodies.
+        if is_gen_kind {
+            transform_gen_context(tcx, body);
+        }
+
         // The original arguments to the function are no longer arguments, mark them as such.
         // Otherwise they'll conflict with our new arguments, which although they don't have
         // argument_index set, will get emitted as unnamed arguments.
diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs
index 737acfbc600..8ea78b9b532 100644
--- a/compiler/rustc_ty_utils/src/abi.rs
+++ b/compiler/rustc_ty_utils/src/abi.rs
@@ -112,7 +112,13 @@ fn fn_sig_for_fn_abi<'tcx>(
             let pin_did = tcx.require_lang_item(LangItem::Pin, None);
             let pin_adt_ref = tcx.adt_def(pin_did);
             let pin_args = tcx.mk_args(&[env_ty.into()]);
-            let env_ty = Ty::new_adt(tcx, pin_adt_ref, pin_args);
+            let env_ty = if tcx.coroutine_is_gen(did) {
+                // Iterator::next doesn't accept a pinned argument,
+                // unlike for all other coroutine kinds.
+                env_ty
+            } else {
+                Ty::new_adt(tcx, pin_adt_ref, pin_args)
+            };
 
             let sig = sig.skip_binder();
             // The `FnSig` and the `ret_ty` here is for a coroutines main
@@ -121,6 +127,8 @@ fn fn_sig_for_fn_abi<'tcx>(
             // function in case this is a special coroutine backing an async construct.
             let (resume_ty, ret_ty) = if tcx.coroutine_is_async(did) {
                 // The signature should be `Future::poll(_, &mut Context<'_>) -> Poll<Output>`
+                assert_eq!(sig.yield_ty, tcx.types.unit);
+
                 let poll_did = tcx.require_lang_item(LangItem::Poll, None);
                 let poll_adt_ref = tcx.adt_def(poll_did);
                 let poll_args = tcx.mk_args(&[sig.return_ty.into()]);
@@ -140,7 +148,30 @@ fn fn_sig_for_fn_abi<'tcx>(
                 }
                 let context_mut_ref = Ty::new_task_context(tcx);
 
-                (context_mut_ref, ret_ty)
+                (Some(context_mut_ref), ret_ty)
+            } else if tcx.coroutine_is_gen(did) {
+                // The signature should be `Iterator::next(_) -> Option<Yield>`
+                let option_did = tcx.require_lang_item(LangItem::Option, None);
+                let option_adt_ref = tcx.adt_def(option_did);
+                let option_args = tcx.mk_args(&[sig.yield_ty.into()]);
+                let ret_ty = Ty::new_adt(tcx, option_adt_ref, option_args);
+
+                assert_eq!(sig.return_ty, tcx.types.unit);
+
+                // We have to replace the `ResumeTy` that is used for type and borrow checking
+                // with `()` which is used in codegen.
+                #[cfg(debug_assertions)]
+                {
+                    if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() {
+                        let expected_adt =
+                            tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
+                        assert_eq!(*resume_ty_adt, expected_adt);
+                    } else {
+                        panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty);
+                    };
+                }
+
+                (None, ret_ty)
             } else {
                 // The signature should be `Coroutine::resume(_, Resume) -> CoroutineState<Yield, Return>`
                 let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
@@ -148,19 +179,28 @@ fn fn_sig_for_fn_abi<'tcx>(
                 let state_args = tcx.mk_args(&[sig.yield_ty.into(), sig.return_ty.into()]);
                 let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
 
-                (sig.resume_ty, ret_ty)
+                (Some(sig.resume_ty), ret_ty)
             };
 
-            ty::Binder::bind_with_vars(
+            let fn_sig = if let Some(resume_ty) = resume_ty {
                 tcx.mk_fn_sig(
                     [env_ty, resume_ty],
                     ret_ty,
                     false,
                     hir::Unsafety::Normal,
                     rustc_target::spec::abi::Abi::Rust,
-                ),
-                bound_vars,
-            )
+                )
+            } else {
+                // `Iterator::next` doesn't have a `resume` argument.
+                tcx.mk_fn_sig(
+                    [env_ty],
+                    ret_ty,
+                    false,
+                    hir::Unsafety::Normal,
+                    rustc_target::spec::abi::Abi::Rust,
+                )
+            };
+            ty::Binder::bind_with_vars(fn_sig, bound_vars)
         }
         _ => bug!("unexpected type {:?} in Instance::fn_sig", ty),
     }
diff --git a/rustfmt.toml b/rustfmt.toml
index 88700779e87..e292a310742 100644
--- a/rustfmt.toml
+++ b/rustfmt.toml
@@ -39,4 +39,5 @@ ignore = [
     # these are ignored by a standard cargo fmt run
     "compiler/rustc_codegen_cranelift/y.rs", # running rustfmt breaks this file
     "compiler/rustc_codegen_cranelift/scripts",
+    "compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs", # uses edition 2024
 ]