about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMatthias Krüger <matthias.krueger@famsik.de>2024-04-02 21:22:03 +0200
committerGitHub <noreply@github.com>2024-04-02 21:22:03 +0200
commit93729488895537f7ddc0dccb1ec817eb1d8a28af (patch)
tree5888888ef6c13277c7948884d94575b145b0b3c3
parent5b717684ffc5f908a9a3953a21cd3d3d20020c1e (diff)
parenta333b82d04c6077d639c664d45d051a0bdcbabed (diff)
downloadrust-93729488895537f7ddc0dccb1ec817eb1d8a28af.tar.gz
rust-93729488895537f7ddc0dccb1ec817eb1d8a28af.zip
Rollup merge of #123368 - maurer:cfi-non-general-coroutines, r=compiler-errors
CFI: Support non-general coroutines

Previously, we assumed all `ty::Coroutine` were general coroutines and attempted to generalize them through the `Coroutine` trait. Select appropriate traits for each kind of coroutine.

I have this marked as a draft because it currently only fixes async coroutines, and I think it make sense to try to fix gen/async gen coroutines before this is merged.

If the issue [mentioned](https://github.com/rust-lang/rust/pull/123106#issuecomment-2030794213) in the original PR is actually affecting someone, we can land this as is to remedy it.
-rw-r--r--compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs33
-rw-r--r--tests/ui/sanitizer/cfi-coroutine.rs41
2 files changed, 62 insertions, 12 deletions
diff --git a/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs b/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs
index 5963bd7c5f1..5f5d90f359a 100644
--- a/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs
+++ b/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs
@@ -1218,22 +1218,35 @@ pub fn typeid_for_instance<'tcx>(
                     let trait_id = tcx.fn_trait_kind_to_def_id(closure_args.kind()).unwrap();
                     let tuple_args =
                         tcx.instantiate_bound_regions_with_erased(closure_args.sig()).inputs()[0];
-                    (trait_id, tuple_args)
+                    (trait_id, Some(tuple_args))
                 }
-                ty::Coroutine(..) => (
-                    tcx.require_lang_item(LangItem::Coroutine, None),
-                    instance.args.as_coroutine().resume_ty(),
-                ),
+                ty::Coroutine(..) => match tcx.coroutine_kind(instance.def_id()).unwrap() {
+                    hir::CoroutineKind::Coroutine(..) => (
+                        tcx.require_lang_item(LangItem::Coroutine, None),
+                        Some(instance.args.as_coroutine().resume_ty()),
+                    ),
+                    hir::CoroutineKind::Desugared(desugaring, _) => {
+                        let lang_item = match desugaring {
+                            hir::CoroutineDesugaring::Async => LangItem::Future,
+                            hir::CoroutineDesugaring::AsyncGen => LangItem::AsyncIterator,
+                            hir::CoroutineDesugaring::Gen => LangItem::Iterator,
+                        };
+                        (tcx.require_lang_item(lang_item, None), None)
+                    }
+                },
                 ty::CoroutineClosure(..) => (
                     tcx.require_lang_item(LangItem::FnOnce, None),
-                    tcx.instantiate_bound_regions_with_erased(
-                        instance.args.as_coroutine_closure().coroutine_closure_sig(),
-                    )
-                    .tupled_inputs_ty,
+                    Some(
+                        tcx.instantiate_bound_regions_with_erased(
+                            instance.args.as_coroutine_closure().coroutine_closure_sig(),
+                        )
+                        .tupled_inputs_ty,
+                    ),
                 ),
                 x => bug!("Unexpected type kind for closure-like: {x:?}"),
             };
-            let trait_ref = ty::TraitRef::new(tcx, trait_id, [closure_ty, inputs]);
+            let concrete_args = tcx.mk_args_trait(closure_ty, inputs.map(Into::into));
+            let trait_ref = ty::TraitRef::new(tcx, trait_id, concrete_args);
             let invoke_ty = trait_object_ty(tcx, ty::Binder::dummy(trait_ref));
             let abstract_args = tcx.mk_args_trait(invoke_ty, trait_ref.args.into_iter().skip(1));
             // There should be exactly one method on this trait, and it should be the one we're
diff --git a/tests/ui/sanitizer/cfi-coroutine.rs b/tests/ui/sanitizer/cfi-coroutine.rs
index 24e59cf5b4d..5c6a489a7e8 100644
--- a/tests/ui/sanitizer/cfi-coroutine.rs
+++ b/tests/ui/sanitizer/cfi-coroutine.rs
@@ -3,6 +3,7 @@
 //@ revisions: cfi kcfi
 // FIXME(#122848) Remove only-linux once OSX CFI binaries work
 //@ only-linux
+//@ edition: 2024
 //@ [cfi] needs-sanitizer-cfi
 //@ [kcfi] needs-sanitizer-kcfi
 //@ compile-flags: -C target-feature=-crt-static
@@ -10,16 +11,22 @@
 //@ [cfi] compile-flags: -Z sanitizer=cfi
 //@ [kcfi] compile-flags: -Z sanitizer=kcfi
 //@ [kcfi] compile-flags: -C panic=abort -Z panic-abort-tests -C prefer-dynamic=off
-//@ compile-flags: --test
+//@ compile-flags: --test -Z unstable-options
 //@ run-pass
 
 #![feature(coroutines)]
 #![feature(coroutine_trait)]
+#![feature(noop_waker)]
+#![feature(gen_blocks)]
+#![feature(async_iterator)]
 
 use std::ops::{Coroutine, CoroutineState};
 use std::pin::{pin, Pin};
+use std::task::{Context, Poll, Waker};
+use std::async_iter::AsyncIterator;
 
-fn main() {
+#[test]
+fn general_coroutine() {
     let mut coro = |x: i32| {
         yield x;
         "done"
@@ -28,3 +35,33 @@ fn main() {
     assert_eq!(abstract_coro.as_mut().resume(2), CoroutineState::Yielded(2));
     assert_eq!(abstract_coro.as_mut().resume(0), CoroutineState::Complete("done"));
 }
+
+async fn async_fn() {}
+
+#[test]
+fn async_coroutine() {
+    let f: fn() -> Pin<Box<dyn Future<Output = ()>>> = || Box::pin(async_fn());
+    let _ = async { f().await; };
+    assert_eq!(f().as_mut().poll(&mut Context::from_waker(Waker::noop())), Poll::Ready(()));
+}
+
+async gen fn async_gen_fn() -> u8 {
+    yield 5;
+}
+
+#[test]
+fn async_gen_coroutine() {
+    let f: fn() -> Pin<Box<dyn AsyncIterator<Item = u8>>> = || Box::pin(async_gen_fn());
+    assert_eq!(f().as_mut().poll_next(&mut Context::from_waker(Waker::noop())),
+               Poll::Ready(Some(5)));
+}
+
+gen fn gen_fn() -> u8 {
+    yield 6;
+}
+
+#[test]
+fn gen_coroutine() {
+    let f: fn() -> Box<dyn Iterator<Item = u8>> = || Box::new(gen_fn());
+    assert_eq!(f().next(), Some(6));
+}