about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_hir_typeck/src/method/suggest.rs3
-rw-r--r--compiler/rustc_infer/src/traits/util.rs63
-rw-r--r--compiler/rustc_lint/src/deref_into_dyn_supertrait.rs3
-rw-r--r--compiler/rustc_middle/src/traits/util.rs7
-rw-r--r--compiler/rustc_middle/src/ty/print/pretty.rs5
5 files changed, 39 insertions, 42 deletions
diff --git a/compiler/rustc_hir_typeck/src/method/suggest.rs b/compiler/rustc_hir_typeck/src/method/suggest.rs
index 6b0dc73d49c..3ed22e095e8 100644
--- a/compiler/rustc_hir_typeck/src/method/suggest.rs
+++ b/compiler/rustc_hir_typeck/src/method/suggest.rs
@@ -26,7 +26,6 @@ use rustc_infer::infer::{
     RegionVariableOrigin,
 };
 use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind};
-use rustc_middle::traits::util::supertraits;
 use rustc_middle::ty::fast_reject::DeepRejectCtxt;
 use rustc_middle::ty::fast_reject::{simplify_type, TreatParams};
 use rustc_middle::ty::print::{with_crate_prefix, with_forced_trimmed_paths};
@@ -40,7 +39,7 @@ use rustc_trait_selection::traits::error_reporting::on_unimplemented::OnUnimplem
 use rustc_trait_selection::traits::error_reporting::on_unimplemented::TypeErrCtxtExt as _;
 use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _;
 use rustc_trait_selection::traits::{
-    FulfillmentError, Obligation, ObligationCause, ObligationCauseCode,
+    supertraits, FulfillmentError, Obligation, ObligationCause, ObligationCauseCode,
 };
 use std::borrow::Cow;
 
diff --git a/compiler/rustc_infer/src/traits/util.rs b/compiler/rustc_infer/src/traits/util.rs
index 93dfbe63bcd..3c566e0dd6d 100644
--- a/compiler/rustc_infer/src/traits/util.rs
+++ b/compiler/rustc_infer/src/traits/util.rs
@@ -2,7 +2,7 @@ use smallvec::smallvec;
 
 use crate::infer::outlives::components::{push_outlives_components, Component};
 use crate::traits::{self, Obligation, PredicateObligation};
-use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
+use rustc_data_structures::fx::FxHashSet;
 use rustc_middle::ty::{self, ToPredicate, Ty, TyCtxt};
 use rustc_span::symbol::Ident;
 use rustc_span::Span;
@@ -76,7 +76,13 @@ impl<'tcx> Extend<ty::Predicate<'tcx>> for PredicateSet<'tcx> {
 pub struct Elaborator<'tcx, O> {
     stack: Vec<O>,
     visited: PredicateSet<'tcx>,
-    only_self: bool,
+    mode: Filter,
+}
+
+enum Filter {
+    All,
+    OnlySelf,
+    OnlySelfThatDefines(Ident),
 }
 
 /// Describes how to elaborate an obligation into a sub-obligation.
@@ -224,7 +230,7 @@ pub fn elaborate<'tcx, O: Elaboratable<'tcx>>(
     obligations: impl IntoIterator<Item = O>,
 ) -> Elaborator<'tcx, O> {
     let mut elaborator =
-        Elaborator { stack: Vec::new(), visited: PredicateSet::new(tcx), only_self: false };
+        Elaborator { stack: Vec::new(), visited: PredicateSet::new(tcx), mode: Filter::All };
     elaborator.extend_deduped(obligations);
     elaborator
 }
@@ -242,7 +248,13 @@ impl<'tcx, O: Elaboratable<'tcx>> Elaborator<'tcx, O> {
     /// Filter to only the supertraits of trait predicates, i.e. only the predicates
     /// that have `Self` as their self type, instead of all implied predicates.
     pub fn filter_only_self(mut self) -> Self {
-        self.only_self = true;
+        self.mode = Filter::OnlySelf;
+        self
+    }
+
+    /// Filter to only the supertraits of trait predicates that define the assoc_ty.
+    pub fn filter_only_self_that_defines(mut self, assoc_ty: Ident) -> Self {
+        self.mode = Filter::OnlySelfThatDefines(assoc_ty);
         self
     }
 
@@ -257,10 +269,12 @@ impl<'tcx, O: Elaboratable<'tcx>> Elaborator<'tcx, O> {
                     return;
                 }
                 // Get predicates implied by the trait, or only super predicates if we only care about self predicates.
-                let predicates = if self.only_self {
-                    tcx.super_predicates_of(data.def_id())
-                } else {
-                    tcx.implied_predicates_of(data.def_id())
+                let predicates = match self.mode {
+                    Filter::All => tcx.implied_predicates_of(data.def_id()),
+                    Filter::OnlySelf => tcx.super_predicates_of(data.def_id()),
+                    Filter::OnlySelfThatDefines(ident) => {
+                        tcx.super_predicates_that_define_assoc_item((data.def_id(), ident))
+                    }
                 };
 
                 let obligations =
@@ -409,14 +423,14 @@ impl<'tcx, O: Elaboratable<'tcx>> Iterator for Elaborator<'tcx, O> {
 pub fn supertraits<'tcx>(
     tcx: TyCtxt<'tcx>,
     trait_ref: ty::PolyTraitRef<'tcx>,
-) -> impl Iterator<Item = ty::PolyTraitRef<'tcx>> {
+) -> FilterToTraits<Elaborator<'tcx, ty::Predicate<'tcx>>> {
     elaborate(tcx, [trait_ref.to_predicate(tcx)]).filter_only_self().filter_to_traits()
 }
 
 pub fn transitive_bounds<'tcx>(
     tcx: TyCtxt<'tcx>,
     trait_refs: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
-) -> impl Iterator<Item = ty::PolyTraitRef<'tcx>> {
+) -> FilterToTraits<Elaborator<'tcx, ty::Predicate<'tcx>>> {
     elaborate(tcx, trait_refs.map(|trait_ref| trait_ref.to_predicate(tcx)))
         .filter_only_self()
         .filter_to_traits()
@@ -429,31 +443,12 @@ pub fn transitive_bounds<'tcx>(
 /// `T::Item` and helps to avoid cycle errors (see e.g. #35237).
 pub fn transitive_bounds_that_define_assoc_item<'tcx>(
     tcx: TyCtxt<'tcx>,
-    bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
+    trait_refs: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
     assoc_name: Ident,
-) -> impl Iterator<Item = ty::PolyTraitRef<'tcx>> {
-    let mut stack: Vec<_> = bounds.collect();
-    let mut visited = FxIndexSet::default();
-
-    std::iter::from_fn(move || {
-        while let Some(trait_ref) = stack.pop() {
-            let anon_trait_ref = tcx.anonymize_bound_vars(trait_ref);
-            if visited.insert(anon_trait_ref) {
-                let super_predicates =
-                    tcx.super_predicates_that_define_assoc_item((trait_ref.def_id(), assoc_name));
-                for (super_predicate, _) in super_predicates.predicates {
-                    let subst_predicate = super_predicate.subst_supertrait(tcx, &trait_ref);
-                    if let Some(binder) = subst_predicate.as_trait_clause() {
-                        stack.push(binder.map_bound(|t| t.trait_ref));
-                    }
-                }
-
-                return Some(trait_ref);
-            }
-        }
-
-        return None;
-    })
+) -> FilterToTraits<Elaborator<'tcx, ty::Predicate<'tcx>>> {
+    elaborate(tcx, trait_refs.map(|trait_ref| trait_ref.to_predicate(tcx)))
+        .filter_only_self_that_defines(assoc_name)
+        .filter_to_traits()
 }
 
 ///////////////////////////////////////////////////////////////////////////
diff --git a/compiler/rustc_lint/src/deref_into_dyn_supertrait.rs b/compiler/rustc_lint/src/deref_into_dyn_supertrait.rs
index bc11082d278..d2d99bc0da8 100644
--- a/compiler/rustc_lint/src/deref_into_dyn_supertrait.rs
+++ b/compiler/rustc_lint/src/deref_into_dyn_supertrait.rs
@@ -4,9 +4,10 @@ use crate::{
 };
 
 use rustc_hir as hir;
-use rustc_middle::{traits::util::supertraits, ty};
+use rustc_middle::ty;
 use rustc_session::lint::FutureIncompatibilityReason;
 use rustc_span::sym;
+use rustc_trait_selection::traits::supertraits;
 
 declare_lint! {
     /// The `deref_into_dyn_supertrait` lint is output whenever there is a use of the
diff --git a/compiler/rustc_middle/src/traits/util.rs b/compiler/rustc_middle/src/traits/util.rs
index 05c06efaf16..b4054f8ff5e 100644
--- a/compiler/rustc_middle/src/traits/util.rs
+++ b/compiler/rustc_middle/src/traits/util.rs
@@ -3,9 +3,10 @@ use rustc_data_structures::fx::FxHashSet;
 use crate::ty::{PolyTraitRef, TyCtxt};
 
 /// Given a PolyTraitRef, get the PolyTraitRefs of the trait's (transitive) supertraits.
-///
-/// A simplified version of the same function at `rustc_infer::traits::util::supertraits`.
-pub fn supertraits<'tcx>(
+/// This only exists in `rustc_middle` because the more powerful elaborator depends on
+/// `rustc_infer` for elaborating outlives bounds -- this should only be used for pretty
+/// printing.
+pub fn supertraits_for_pretty_printing<'tcx>(
     tcx: TyCtxt<'tcx>,
     trait_ref: PolyTraitRef<'tcx>,
 ) -> impl Iterator<Item = PolyTraitRef<'tcx>> {
diff --git a/compiler/rustc_middle/src/ty/print/pretty.rs b/compiler/rustc_middle/src/ty/print/pretty.rs
index d8010c71497..117aa69596c 100644
--- a/compiler/rustc_middle/src/ty/print/pretty.rs
+++ b/compiler/rustc_middle/src/ty/print/pretty.rs
@@ -1,6 +1,7 @@
 use crate::mir::interpret::{AllocRange, GlobalAlloc, Pointer, Provenance, Scalar};
 use crate::query::IntoQueryParam;
 use crate::query::Providers;
+use crate::traits::util::supertraits_for_pretty_printing;
 use crate::ty::{
     self, ConstInt, ParamConst, ScalarInt, Term, TermKind, Ty, TyCtxt, TypeFoldable,
     TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitableExt,
@@ -1150,14 +1151,14 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
                 entry.has_fn_once = true;
                 return;
             } else if Some(trait_def_id) == self.tcx().lang_items().fn_mut_trait() {
-                let super_trait_ref = crate::traits::util::supertraits(self.tcx(), trait_ref)
+                let super_trait_ref = supertraits_for_pretty_printing(self.tcx(), trait_ref)
                     .find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
                     .unwrap();
 
                 fn_traits.entry(super_trait_ref).or_default().fn_mut_trait_ref = Some(trait_ref);
                 return;
             } else if Some(trait_def_id) == self.tcx().lang_items().fn_trait() {
-                let super_trait_ref = crate::traits::util::supertraits(self.tcx(), trait_ref)
+                let super_trait_ref = supertraits_for_pretty_printing(self.tcx(), trait_ref)
                     .find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
                     .unwrap();