about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_hir/src/lang_items.rs13
-rw-r--r--compiler/rustc_middle/src/middle/lang_items.rs4
-rw-r--r--compiler/rustc_middle/src/ty/context.rs109
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs117
-rw-r--r--compiler/rustc_type_ir/src/interner.rs6
-rw-r--r--compiler/rustc_type_ir/src/lang_items.rs6
6 files changed, 164 insertions, 91 deletions
diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs
index 3c44acb1657..30c0e40206a 100644
--- a/compiler/rustc_hir/src/lang_items.rs
+++ b/compiler/rustc_hir/src/lang_items.rs
@@ -11,6 +11,7 @@ use crate::def_id::DefId;
 use crate::{MethodKind, Target};
 
 use rustc_ast as ast;
+use rustc_data_structures::fx::FxIndexMap;
 use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
 use rustc_macros::{Decodable, Encodable, HashStable_Generic};
 use rustc_span::symbol::{kw, sym, Symbol};
@@ -23,6 +24,7 @@ pub struct LanguageItems {
     /// Mappings from lang items to their possibly found [`DefId`]s.
     /// The index corresponds to the order in [`LangItem`].
     items: [Option<DefId>; std::mem::variant_count::<LangItem>()],
+    reverse_items: FxIndexMap<DefId, LangItem>,
     /// Lang items that were not found during collection.
     pub missing: Vec<LangItem>,
 }
@@ -30,7 +32,11 @@ pub struct LanguageItems {
 impl LanguageItems {
     /// Construct an empty collection of lang items and no missing ones.
     pub fn new() -> Self {
-        Self { items: [None; std::mem::variant_count::<LangItem>()], missing: Vec::new() }
+        Self {
+            items: [None; std::mem::variant_count::<LangItem>()],
+            reverse_items: FxIndexMap::default(),
+            missing: Vec::new(),
+        }
     }
 
     pub fn get(&self, item: LangItem) -> Option<DefId> {
@@ -39,6 +45,11 @@ impl LanguageItems {
 
     pub fn set(&mut self, item: LangItem, def_id: DefId) {
         self.items[item as usize] = Some(def_id);
+        self.reverse_items.insert(def_id, item);
+    }
+
+    pub fn from_def_id(&self, def_id: DefId) -> Option<LangItem> {
+        self.reverse_items.get(&def_id).copied()
     }
 
     pub fn iter(&self) -> impl Iterator<Item = (LangItem, DefId)> + '_ {
diff --git a/compiler/rustc_middle/src/middle/lang_items.rs b/compiler/rustc_middle/src/middle/lang_items.rs
index e76d7af6e4a..a0c9af436e2 100644
--- a/compiler/rustc_middle/src/middle/lang_items.rs
+++ b/compiler/rustc_middle/src/middle/lang_items.rs
@@ -27,6 +27,10 @@ impl<'tcx> TyCtxt<'tcx> {
         self.lang_items().get(lang_item) == Some(def_id)
     }
 
+    pub fn as_lang_item(self, def_id: DefId) -> Option<LangItem> {
+        self.lang_items().from_def_id(def_id)
+    }
+
     /// Given a [`DefId`] of one of the [`Fn`], [`FnMut`] or [`FnOnce`] traits,
     /// returns a corresponding [`ty::ClosureKind`].
     /// For any other [`DefId`] return `None`.
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index 0771b0aa725..055749ba3a3 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -366,6 +366,10 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
         self.is_lang_item(def_id, trait_lang_item_to_lang_item(lang_item))
     }
 
+    fn as_lang_item(self, def_id: DefId) -> Option<TraitSolverLangItem> {
+        lang_item_to_trait_lang_item(self.lang_items().from_def_id(def_id)?)
+    }
+
     fn associated_type_def_ids(self, def_id: DefId) -> impl IntoIterator<Item = DefId> {
         self.associated_items(def_id)
             .in_definition_order()
@@ -522,14 +526,6 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
         self.trait_def(trait_def_id).implement_via_object
     }
 
-    fn fn_trait_kind_from_def_id(self, trait_def_id: DefId) -> Option<ty::ClosureKind> {
-        self.fn_trait_kind_from_def_id(trait_def_id)
-    }
-
-    fn async_fn_trait_kind_from_def_id(self, trait_def_id: DefId) -> Option<ty::ClosureKind> {
-        self.async_fn_trait_kind_from_def_id(trait_def_id)
-    }
-
     fn supertrait_def_ids(self, trait_def_id: DefId) -> impl IntoIterator<Item = DefId> {
         self.supertrait_def_ids(trait_def_id)
     }
@@ -573,46 +569,69 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
     }
 }
 
-fn trait_lang_item_to_lang_item(lang_item: TraitSolverLangItem) -> LangItem {
-    match lang_item {
-        TraitSolverLangItem::AsyncDestruct => LangItem::AsyncDestruct,
-        TraitSolverLangItem::AsyncFnKindHelper => LangItem::AsyncFnKindHelper,
-        TraitSolverLangItem::AsyncFnKindUpvars => LangItem::AsyncFnKindUpvars,
-        TraitSolverLangItem::AsyncFnOnceOutput => LangItem::AsyncFnOnceOutput,
-        TraitSolverLangItem::AsyncIterator => LangItem::AsyncIterator,
-        TraitSolverLangItem::CallOnceFuture => LangItem::CallOnceFuture,
-        TraitSolverLangItem::CallRefFuture => LangItem::CallRefFuture,
-        TraitSolverLangItem::Clone => LangItem::Clone,
-        TraitSolverLangItem::Copy => LangItem::Copy,
-        TraitSolverLangItem::Coroutine => LangItem::Coroutine,
-        TraitSolverLangItem::CoroutineReturn => LangItem::CoroutineReturn,
-        TraitSolverLangItem::CoroutineYield => LangItem::CoroutineYield,
-        TraitSolverLangItem::Destruct => LangItem::Destruct,
-        TraitSolverLangItem::DiscriminantKind => LangItem::DiscriminantKind,
-        TraitSolverLangItem::DynMetadata => LangItem::DynMetadata,
-        TraitSolverLangItem::EffectsMaybe => LangItem::EffectsMaybe,
-        TraitSolverLangItem::EffectsIntersection => LangItem::EffectsIntersection,
-        TraitSolverLangItem::EffectsIntersectionOutput => LangItem::EffectsIntersectionOutput,
-        TraitSolverLangItem::EffectsNoRuntime => LangItem::EffectsNoRuntime,
-        TraitSolverLangItem::EffectsRuntime => LangItem::EffectsRuntime,
-        TraitSolverLangItem::FnPtrTrait => LangItem::FnPtrTrait,
-        TraitSolverLangItem::FusedIterator => LangItem::FusedIterator,
-        TraitSolverLangItem::Future => LangItem::Future,
-        TraitSolverLangItem::FutureOutput => LangItem::FutureOutput,
-        TraitSolverLangItem::Iterator => LangItem::Iterator,
-        TraitSolverLangItem::Metadata => LangItem::Metadata,
-        TraitSolverLangItem::Option => LangItem::Option,
-        TraitSolverLangItem::PointeeTrait => LangItem::PointeeTrait,
-        TraitSolverLangItem::PointerLike => LangItem::PointerLike,
-        TraitSolverLangItem::Poll => LangItem::Poll,
-        TraitSolverLangItem::Sized => LangItem::Sized,
-        TraitSolverLangItem::TransmuteTrait => LangItem::TransmuteTrait,
-        TraitSolverLangItem::Tuple => LangItem::Tuple,
-        TraitSolverLangItem::Unpin => LangItem::Unpin,
-        TraitSolverLangItem::Unsize => LangItem::Unsize,
+macro_rules! bidirectional_lang_item_map {
+    ($($name:ident),+ $(,)?) => {
+        fn trait_lang_item_to_lang_item(lang_item: TraitSolverLangItem) -> LangItem {
+            match lang_item {
+                $(TraitSolverLangItem::$name => LangItem::$name,)+
+            }
+        }
+
+        fn lang_item_to_trait_lang_item(lang_item: LangItem) -> Option<TraitSolverLangItem> {
+            Some(match lang_item {
+                $(LangItem::$name => TraitSolverLangItem::$name,)+
+                _ => return None,
+            })
+        }
     }
 }
 
+bidirectional_lang_item_map! {
+// tidy-alphabetical-start
+    AsyncDestruct,
+    AsyncFn,
+    AsyncFnKindHelper,
+    AsyncFnKindUpvars,
+    AsyncFnMut,
+    AsyncFnOnce,
+    AsyncFnOnceOutput,
+    AsyncIterator,
+    CallOnceFuture,
+    CallRefFuture,
+    Clone,
+    Copy,
+    Coroutine,
+    CoroutineReturn,
+    CoroutineYield,
+    Destruct,
+    DiscriminantKind,
+    DynMetadata,
+    EffectsIntersection,
+    EffectsIntersectionOutput,
+    EffectsMaybe,
+    EffectsNoRuntime,
+    EffectsRuntime,
+    Fn,
+    FnMut,
+    FnOnce,
+    FnPtrTrait,
+    FusedIterator,
+    Future,
+    FutureOutput,
+    Iterator,
+    Metadata,
+    Option,
+    PointeeTrait,
+    PointerLike,
+    Poll,
+    Sized,
+    TransmuteTrait,
+    Tuple,
+    Unpin,
+    Unsize,
+// tidy-alphabetical-end
+}
+
 impl<'tcx> rustc_type_ir::inherent::DefId<TyCtxt<'tcx>> for DefId {
     fn as_local(self) -> Option<LocalDefId> {
         self.as_local()
diff --git a/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs b/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs
index 6ee684605ac..38a4f7dfe25 100644
--- a/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs
@@ -387,48 +387,83 @@ where
             G::consider_auto_trait_candidate(self, goal)
         } else if cx.trait_is_alias(trait_def_id) {
             G::consider_trait_alias_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Sized) {
-            G::consider_builtin_sized_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Copy)
-            || cx.is_lang_item(trait_def_id, TraitSolverLangItem::Clone)
-        {
-            G::consider_builtin_copy_clone_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::PointerLike) {
-            G::consider_builtin_pointer_like_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::FnPtrTrait) {
-            G::consider_builtin_fn_ptr_trait_candidate(self, goal)
-        } else if let Some(kind) = self.cx().fn_trait_kind_from_def_id(trait_def_id) {
-            G::consider_builtin_fn_trait_candidates(self, goal, kind)
-        } else if let Some(kind) = self.cx().async_fn_trait_kind_from_def_id(trait_def_id) {
-            G::consider_builtin_async_fn_trait_candidates(self, goal, kind)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::AsyncFnKindHelper) {
-            G::consider_builtin_async_fn_kind_helper_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Tuple) {
-            G::consider_builtin_tuple_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::PointeeTrait) {
-            G::consider_builtin_pointee_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Future) {
-            G::consider_builtin_future_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Iterator) {
-            G::consider_builtin_iterator_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::FusedIterator) {
-            G::consider_builtin_fused_iterator_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::AsyncIterator) {
-            G::consider_builtin_async_iterator_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Coroutine) {
-            G::consider_builtin_coroutine_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::DiscriminantKind) {
-            G::consider_builtin_discriminant_kind_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::AsyncDestruct) {
-            G::consider_builtin_async_destruct_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Destruct) {
-            G::consider_builtin_destruct_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::TransmuteTrait) {
-            G::consider_builtin_transmute_candidate(self, goal)
-        } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::EffectsIntersection) {
-            G::consider_builtin_effects_intersection_candidate(self, goal)
         } else {
-            Err(NoSolution)
+            match cx.as_lang_item(trait_def_id) {
+                Some(TraitSolverLangItem::Sized) => G::consider_builtin_sized_candidate(self, goal),
+                Some(TraitSolverLangItem::Copy | TraitSolverLangItem::Clone) => {
+                    G::consider_builtin_copy_clone_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::Fn) => {
+                    G::consider_builtin_fn_trait_candidates(self, goal, ty::ClosureKind::Fn)
+                }
+                Some(TraitSolverLangItem::FnMut) => {
+                    G::consider_builtin_fn_trait_candidates(self, goal, ty::ClosureKind::FnMut)
+                }
+                Some(TraitSolverLangItem::FnOnce) => {
+                    G::consider_builtin_fn_trait_candidates(self, goal, ty::ClosureKind::FnOnce)
+                }
+                Some(TraitSolverLangItem::AsyncFn) => {
+                    G::consider_builtin_async_fn_trait_candidates(self, goal, ty::ClosureKind::Fn)
+                }
+                Some(TraitSolverLangItem::AsyncFnMut) => {
+                    G::consider_builtin_async_fn_trait_candidates(
+                        self,
+                        goal,
+                        ty::ClosureKind::FnMut,
+                    )
+                }
+                Some(TraitSolverLangItem::AsyncFnOnce) => {
+                    G::consider_builtin_async_fn_trait_candidates(
+                        self,
+                        goal,
+                        ty::ClosureKind::FnOnce,
+                    )
+                }
+                Some(TraitSolverLangItem::PointerLike) => {
+                    G::consider_builtin_pointer_like_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::FnPtrTrait) => {
+                    G::consider_builtin_fn_ptr_trait_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::AsyncFnKindHelper) => {
+                    G::consider_builtin_async_fn_kind_helper_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::Tuple) => G::consider_builtin_tuple_candidate(self, goal),
+                Some(TraitSolverLangItem::PointeeTrait) => {
+                    G::consider_builtin_pointee_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::Future) => {
+                    G::consider_builtin_future_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::Iterator) => {
+                    G::consider_builtin_iterator_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::FusedIterator) => {
+                    G::consider_builtin_fused_iterator_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::AsyncIterator) => {
+                    G::consider_builtin_async_iterator_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::Coroutine) => {
+                    G::consider_builtin_coroutine_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::DiscriminantKind) => {
+                    G::consider_builtin_discriminant_kind_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::AsyncDestruct) => {
+                    G::consider_builtin_async_destruct_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::Destruct) => {
+                    G::consider_builtin_destruct_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::TransmuteTrait) => {
+                    G::consider_builtin_transmute_candidate(self, goal)
+                }
+                Some(TraitSolverLangItem::EffectsIntersection) => {
+                    G::consider_builtin_effects_intersection_candidate(self, goal)
+                }
+                _ => Err(NoSolution),
+            }
         };
 
         candidates.extend(result);
diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs
index eaa3ab7ce43..db97bdca382 100644
--- a/compiler/rustc_type_ir/src/interner.rs
+++ b/compiler/rustc_type_ir/src/interner.rs
@@ -220,6 +220,8 @@ pub trait Interner:
 
     fn is_lang_item(self, def_id: Self::DefId, lang_item: TraitSolverLangItem) -> bool;
 
+    fn as_lang_item(self, def_id: Self::DefId) -> Option<TraitSolverLangItem>;
+
     fn associated_type_def_ids(self, def_id: Self::DefId) -> impl IntoIterator<Item = Self::DefId>;
 
     fn for_each_relevant_impl(
@@ -245,10 +247,6 @@ pub trait Interner:
 
     fn trait_may_be_implemented_via_object(self, trait_def_id: Self::DefId) -> bool;
 
-    fn fn_trait_kind_from_def_id(self, trait_def_id: Self::DefId) -> Option<ty::ClosureKind>;
-
-    fn async_fn_trait_kind_from_def_id(self, trait_def_id: Self::DefId) -> Option<ty::ClosureKind>;
-
     fn supertrait_def_ids(self, trait_def_id: Self::DefId)
     -> impl IntoIterator<Item = Self::DefId>;
 
diff --git a/compiler/rustc_type_ir/src/lang_items.rs b/compiler/rustc_type_ir/src/lang_items.rs
index cf00c37caa2..265a4118827 100644
--- a/compiler/rustc_type_ir/src/lang_items.rs
+++ b/compiler/rustc_type_ir/src/lang_items.rs
@@ -3,8 +3,11 @@
 pub enum TraitSolverLangItem {
     // tidy-alphabetical-start
     AsyncDestruct,
+    AsyncFn,
     AsyncFnKindHelper,
     AsyncFnKindUpvars,
+    AsyncFnMut,
+    AsyncFnOnce,
     AsyncFnOnceOutput,
     AsyncIterator,
     CallOnceFuture,
@@ -22,6 +25,9 @@ pub enum TraitSolverLangItem {
     EffectsMaybe,
     EffectsNoRuntime,
     EffectsRuntime,
+    Fn,
+    FnMut,
+    FnOnce,
     FnPtrTrait,
     FusedIterator,
     Future,