about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-01-19 21:28:37 +0000
committerMichael Goulet <michael@errs.io>2024-01-19 21:28:37 +0000
commitf2ef88ba064cd6799922e85ae0748d298ad436d1 (patch)
tree96f354c84666e663688dc78e563a4e7e6ab4c6a9
parent32ec40c68533f325a3c8fe787b77ef5c9e209b23 (diff)
downloadrust-f2ef88ba064cd6799922e85ae0748d298ad436d1.tar.gz
rust-f2ef88ba064cd6799922e85ae0748d298ad436d1.zip
Consolidate logic around resolving built-in coroutine trait impls
-rw-r--r--compiler/rustc_hir/src/lang_items.rs3
-rw-r--r--compiler/rustc_middle/src/ty/instance.rs50
-rw-r--r--compiler/rustc_span/src/symbol.rs1
-rw-r--r--compiler/rustc_ty_utils/src/instance.rs59
-rw-r--r--library/core/src/ops/coroutine.rs1
5 files changed, 56 insertions, 58 deletions
diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs
index 1cc1f11b3c8..85d10872b3d 100644
--- a/compiler/rustc_hir/src/lang_items.rs
+++ b/compiler/rustc_hir/src/lang_items.rs
@@ -213,8 +213,11 @@ language_item_table! {
     Iterator,                sym::iterator,            iterator_trait,             Target::Trait,          GenericRequirement::Exact(0);
     Future,                  sym::future_trait,        future_trait,               Target::Trait,          GenericRequirement::Exact(0);
     AsyncIterator,           sym::async_iterator,      async_iterator_trait,       Target::Trait,          GenericRequirement::Exact(0);
+
     CoroutineState,          sym::coroutine_state,     coroutine_state,            Target::Enum,           GenericRequirement::None;
     Coroutine,               sym::coroutine,           coroutine_trait,            Target::Trait,          GenericRequirement::Minimum(1);
+    CoroutineResume,         sym::coroutine_resume,    coroutine_resume,           Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
+
     Unpin,                   sym::unpin,               unpin_trait,                Target::Trait,          GenericRequirement::None;
     Pin,                     sym::pin,                 pin_type,                   Target::Struct,         GenericRequirement::None;
 
diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs
index dd41cb5a61f..b6c3c34078f 100644
--- a/compiler/rustc_middle/src/ty/instance.rs
+++ b/compiler/rustc_middle/src/ty/instance.rs
@@ -3,6 +3,7 @@ use crate::ty::print::{FmtPrinter, Printer};
 use crate::ty::{self, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable};
 use crate::ty::{EarlyBinder, GenericArgs, GenericArgsRef, TypeVisitableExt};
 use rustc_errors::ErrorGuaranteed;
+use rustc_hir as hir;
 use rustc_hir::def::Namespace;
 use rustc_hir::def_id::{CrateNum, DefId};
 use rustc_hir::lang_items::LangItem;
@@ -11,6 +12,7 @@ use rustc_macros::HashStable;
 use rustc_middle::ty::normalize_erasing_regions::NormalizationError;
 use rustc_span::Symbol;
 
+use std::assert_matches::assert_matches;
 use std::fmt;
 
 /// A monomorphized `InstanceDef`.
@@ -572,6 +574,54 @@ impl<'tcx> Instance<'tcx> {
         Some(Instance { def, args })
     }
 
+    pub fn try_resolve_item_for_coroutine(
+        tcx: TyCtxt<'tcx>,
+        trait_item_id: DefId,
+        trait_id: DefId,
+        rcvr_args: ty::GenericArgsRef<'tcx>,
+    ) -> Option<Instance<'tcx>> {
+        let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else {
+            return None;
+        };
+        let coroutine_kind = tcx.coroutine_kind(coroutine_def_id).unwrap();
+
+        let lang_items = tcx.lang_items();
+        let coroutine_callable_item = if Some(trait_id) == lang_items.future_trait() {
+            assert_matches!(
+                coroutine_kind,
+                hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)
+            );
+            hir::LangItem::FuturePoll
+        } else if Some(trait_id) == lang_items.iterator_trait() {
+            assert_matches!(
+                coroutine_kind,
+                hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)
+            );
+            hir::LangItem::IteratorNext
+        } else if Some(trait_id) == lang_items.async_iterator_trait() {
+            assert_matches!(
+                coroutine_kind,
+                hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)
+            );
+            hir::LangItem::AsyncIteratorPollNext
+        } else if Some(trait_id) == lang_items.coroutine_trait() {
+            assert_matches!(coroutine_kind, hir::CoroutineKind::Coroutine(_));
+            hir::LangItem::CoroutineResume
+        } else {
+            return None;
+        };
+
+        if tcx.lang_items().get(coroutine_callable_item) == Some(trait_item_id) {
+            Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args })
+        } else {
+            // All other methods should be defaulted methods of the built-in trait.
+            // This is important for `Iterator`'s combinators, but also useful for
+            // adding future default methods to `Future`, for instance.
+            debug_assert!(tcx.defaultness(trait_item_id).has_value());
+            Some(Instance::new(trait_item_id, rcvr_args))
+        }
+    }
+
     /// Depending on the kind of `InstanceDef`, the MIR body associated with an
     /// instance is expressed in terms of the generic parameters of `self.def_id()`, and in other
     /// cases the MIR body is expressed in terms of the types found in the substitution array.
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index 44795022cba..bebeecebbb6 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -600,6 +600,7 @@ symbols! {
         core_panic_macro,
         coroutine,
         coroutine_clone,
+        coroutine_resume,
         coroutine_state,
         coroutines,
         cosf32,
diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs
index 81d5304b812..e5e31f7caaa 100644
--- a/compiler/rustc_ty_utils/src/instance.rs
+++ b/compiler/rustc_ty_utils/src/instance.rs
@@ -245,63 +245,6 @@ fn resolve_associated_item<'tcx>(
                         span: tcx.def_span(trait_item_id),
                     })
                 }
-            } else if Some(trait_ref.def_id) == lang_items.future_trait() {
-                let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else {
-                    bug!()
-                };
-                if Some(trait_item_id) == tcx.lang_items().future_poll_fn() {
-                    // `Future::poll` is generated by the compiler.
-                    Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args })
-                } else {
-                    // All other methods are default methods of the `Future` trait.
-                    // (this assumes that `ImplSource::Builtin` is only used for methods on `Future`)
-                    debug_assert!(tcx.defaultness(trait_item_id).has_value());
-                    Some(Instance::new(trait_item_id, rcvr_args))
-                }
-            } else if Some(trait_ref.def_id) == lang_items.iterator_trait() {
-                let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else {
-                    bug!()
-                };
-                if Some(trait_item_id) == tcx.lang_items().next_fn() {
-                    // `Iterator::next` is generated by the compiler.
-                    Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
-                } else {
-                    // All other methods are default methods of the `Iterator` trait.
-                    // (this assumes that `ImplSource::Builtin` is only used for methods on `Iterator`)
-                    debug_assert!(tcx.defaultness(trait_item_id).has_value());
-                    Some(Instance::new(trait_item_id, rcvr_args))
-                }
-            } else if Some(trait_ref.def_id) == lang_items.async_iterator_trait() {
-                let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else {
-                    bug!()
-                };
-
-                if cfg!(debug_assertions) && tcx.item_name(trait_item_id) != sym::poll_next {
-                    span_bug!(
-                        tcx.def_span(coroutine_def_id),
-                        "no definition for `{trait_ref}::{}` for built-in coroutine type",
-                        tcx.item_name(trait_item_id)
-                    )
-                }
-
-                // `AsyncIterator::poll_next` is generated by the compiler.
-                Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
-            } else if Some(trait_ref.def_id) == lang_items.coroutine_trait() {
-                let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else {
-                    bug!()
-                };
-                if cfg!(debug_assertions) && tcx.item_name(trait_item_id) != sym::resume {
-                    // For compiler developers who'd like to add new items to `Coroutine`,
-                    // you either need to generate a shim body, or perhaps return
-                    // `InstanceDef::Item` pointing to a trait default method body if
-                    // it is given a default implementation by the trait.
-                    span_bug!(
-                        tcx.def_span(coroutine_def_id),
-                        "no definition for `{trait_ref}::{}` for built-in coroutine type",
-                        tcx.item_name(trait_item_id)
-                    )
-                }
-                Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
             } else if tcx.fn_trait_kind_from_def_id(trait_ref.def_id).is_some() {
                 // FIXME: This doesn't check for malformed libcore that defines, e.g.,
                 // `trait Fn { fn call_once(&self) { .. } }`. This is mostly for extension
@@ -334,7 +277,7 @@ fn resolve_associated_item<'tcx>(
                     ),
                 }
             } else {
-                None
+                Instance::try_resolve_item_for_coroutine(tcx, trait_item_id, trait_id, rcvr_args)
             }
         }
         traits::ImplSource::Param(..)
diff --git a/library/core/src/ops/coroutine.rs b/library/core/src/ops/coroutine.rs
index e58c9068af8..6faded76a4a 100644
--- a/library/core/src/ops/coroutine.rs
+++ b/library/core/src/ops/coroutine.rs
@@ -111,6 +111,7 @@ pub trait Coroutine<R = ()> {
     /// been returned previously. While coroutine literals in the language are
     /// guaranteed to panic on resuming after `Complete`, this is not guaranteed
     /// for all implementations of the `Coroutine` trait.
+    #[cfg_attr(not(bootstrap), lang = "coroutine_resume")]
     fn resume(self: Pin<&mut Self>, arg: R) -> CoroutineState<Self::Yield, Self::Return>;
 }