about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-07-06 12:33:03 -0400
committerMichael Goulet <michael@errs.io>2024-07-07 11:28:01 -0400
commit66eb346770a97fe96c02a79a740cb151e5232010 (patch)
treee64c3e540f40bb32d13dc3724403197a4bebe00c
parent90423a7abbddd98b6fbb22e9780991c736b51ca4 (diff)
downloadrust-66eb346770a97fe96c02a79a740cb151e5232010.tar.gz
rust-66eb346770a97fe96c02a79a740cb151e5232010.zip
Get rid of the redundant elaboration in middle
-rw-r--r--compiler/rustc_middle/src/traits/mod.rs1
-rw-r--r--compiler/rustc_middle/src/traits/util.rs62
-rw-r--r--compiler/rustc_middle/src/ty/context.rs26
-rw-r--r--compiler/rustc_middle/src/ty/print/pretty.rs15
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/trait_goals.rs5
-rw-r--r--compiler/rustc_type_ir/src/elaborate.rs28
-rw-r--r--compiler/rustc_type_ir/src/interner.rs3
7 files changed, 40 insertions, 100 deletions
diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs
index b4e3fae1b43..b74775142e4 100644
--- a/compiler/rustc_middle/src/traits/mod.rs
+++ b/compiler/rustc_middle/src/traits/mod.rs
@@ -7,7 +7,6 @@ pub mod select;
 pub mod solve;
 pub mod specialization_graph;
 mod structural_impls;
-pub mod util;
 
 use crate::mir::ConstraintCategory;
 use crate::ty::abstract_const::NotConstEvaluatable;
diff --git a/compiler/rustc_middle/src/traits/util.rs b/compiler/rustc_middle/src/traits/util.rs
deleted file mode 100644
index 7437be7a74c..00000000000
--- a/compiler/rustc_middle/src/traits/util.rs
+++ /dev/null
@@ -1,62 +0,0 @@
-use rustc_data_structures::fx::FxHashSet;
-
-use crate::ty::{Clause, PolyTraitRef, ToPolyTraitRef, TyCtxt, Upcast};
-
-/// Given a [`PolyTraitRef`], get the [`Clause`]s implied by the trait's definition.
-///
-/// This only exists in `rustc_middle` because the more powerful elaborator depends on
-/// `rustc_infer` for elaborating outlives bounds -- this should only be used for pretty
-/// printing.
-pub fn super_predicates_for_pretty_printing<'tcx>(
-    tcx: TyCtxt<'tcx>,
-    trait_ref: PolyTraitRef<'tcx>,
-) -> impl Iterator<Item = Clause<'tcx>> {
-    let clause = trait_ref.upcast(tcx);
-    Elaborator { tcx, visited: FxHashSet::from_iter([clause]), stack: vec![clause] }
-}
-
-/// Like [`super_predicates_for_pretty_printing`], except it only returns traits and filters out
-/// all other [`Clause`]s.
-pub fn supertraits_for_pretty_printing<'tcx>(
-    tcx: TyCtxt<'tcx>,
-    trait_ref: PolyTraitRef<'tcx>,
-) -> impl Iterator<Item = PolyTraitRef<'tcx>> {
-    super_predicates_for_pretty_printing(tcx, trait_ref).filter_map(|clause| {
-        clause.as_trait_clause().map(|trait_clause| trait_clause.to_poly_trait_ref())
-    })
-}
-
-struct Elaborator<'tcx> {
-    tcx: TyCtxt<'tcx>,
-    visited: FxHashSet<Clause<'tcx>>,
-    stack: Vec<Clause<'tcx>>,
-}
-
-impl<'tcx> Elaborator<'tcx> {
-    fn elaborate(&mut self, trait_ref: PolyTraitRef<'tcx>) {
-        let super_predicates =
-            self.tcx.explicit_super_predicates_of(trait_ref.def_id()).predicates.iter().filter_map(
-                |&(pred, _)| {
-                    let clause = pred.instantiate_supertrait(self.tcx, trait_ref);
-                    self.visited.insert(clause).then_some(clause)
-                },
-            );
-
-        self.stack.extend(super_predicates);
-    }
-}
-
-impl<'tcx> Iterator for Elaborator<'tcx> {
-    type Item = Clause<'tcx>;
-
-    fn next(&mut self) -> Option<Clause<'tcx>> {
-        if let Some(clause) = self.stack.pop() {
-            if let Some(trait_clause) = clause.as_trait_clause() {
-                self.elaborate(trait_clause.to_poly_trait_ref());
-            }
-            Some(clause)
-        } else {
-            None
-        }
-    }
-}
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index 59ab5d79139..c249579c5d7 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -37,7 +37,7 @@ use crate::ty::{GenericArg, GenericArgs, GenericArgsRef};
 use rustc_ast::{self as ast, attr};
 use rustc_data_structures::defer;
 use rustc_data_structures::fingerprint::Fingerprint;
-use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_data_structures::fx::FxHashMap;
 use rustc_data_structures::intern::Interned;
 use rustc_data_structures::profiling::SelfProfilerRef;
 use rustc_data_structures::sharded::{IntoPointer, ShardedHashMap};
@@ -532,10 +532,6 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
         self.trait_def(trait_def_id).implement_via_object
     }
 
-    fn supertrait_def_ids(self, trait_def_id: DefId) -> impl IntoIterator<Item = DefId> {
-        self.supertrait_def_ids(trait_def_id)
-    }
-
     fn delay_bug(self, msg: impl ToString) -> ErrorGuaranteed {
         self.dcx().span_delayed_bug(DUMMY_SP, msg.to_string())
     }
@@ -2495,25 +2491,7 @@ impl<'tcx> TyCtxt<'tcx> {
     /// to identify which traits may define a given associated type to help avoid cycle errors,
     /// and to make size estimates for vtable layout computation.
     pub fn supertrait_def_ids(self, trait_def_id: DefId) -> impl Iterator<Item = DefId> + 'tcx {
-        let mut set = FxHashSet::default();
-        let mut stack = vec![trait_def_id];
-
-        set.insert(trait_def_id);
-
-        iter::from_fn(move || -> Option<DefId> {
-            let trait_did = stack.pop()?;
-            let generic_predicates = self.explicit_super_predicates_of(trait_did);
-
-            for (predicate, _) in generic_predicates.predicates {
-                if let ty::ClauseKind::Trait(data) = predicate.kind().skip_binder() {
-                    if set.insert(data.def_id()) {
-                        stack.push(data.def_id());
-                    }
-                }
-            }
-
-            Some(trait_did)
-        })
+        rustc_type_ir::elaborate::supertrait_def_ids(self, trait_def_id)
     }
 
     /// Given a closure signature, returns an equivalent fn signature. Detuples
diff --git a/compiler/rustc_middle/src/ty/print/pretty.rs b/compiler/rustc_middle/src/ty/print/pretty.rs
index 7e64c507406..df080b2887b 100644
--- a/compiler/rustc_middle/src/ty/print/pretty.rs
+++ b/compiler/rustc_middle/src/ty/print/pretty.rs
@@ -1,7 +1,6 @@
 use crate::mir::interpret::{AllocRange, GlobalAlloc, Pointer, Provenance, Scalar};
 use crate::query::IntoQueryParam;
 use crate::query::Providers;
-use crate::traits::util::{super_predicates_for_pretty_printing, supertraits_for_pretty_printing};
 use crate::ty::GenericArgKind;
 use crate::ty::{
     ConstInt, Expr, ParamConst, ScalarInt, Term, TermKind, TypeFoldable, TypeSuperFoldable,
@@ -23,6 +22,7 @@ use rustc_span::symbol::{kw, Ident, Symbol};
 use rustc_span::FileNameDisplayPreference;
 use rustc_target::abi::Size;
 use rustc_target::spec::abi::Abi;
+use rustc_type_ir::{elaborate, Upcast as _};
 use smallvec::SmallVec;
 
 use std::cell::Cell;
@@ -1255,14 +1255,14 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
                 entry.has_fn_once = true;
                 return;
             } else if self.tcx().is_lang_item(trait_def_id, LangItem::FnMut) {
-                let super_trait_ref = supertraits_for_pretty_printing(self.tcx(), trait_ref)
+                let super_trait_ref = elaborate::supertraits(self.tcx(), trait_ref)
                     .find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
                     .unwrap();
 
                 fn_traits.entry(super_trait_ref).or_default().fn_mut_trait_ref = Some(trait_ref);
                 return;
             } else if self.tcx().is_lang_item(trait_def_id, LangItem::Fn) {
-                let super_trait_ref = supertraits_for_pretty_printing(self.tcx(), trait_ref)
+                let super_trait_ref = elaborate::supertraits(self.tcx(), trait_ref)
                     .find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
                     .unwrap();
 
@@ -1343,10 +1343,11 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
                     let bound_principal_with_self = bound_principal
                         .with_self_ty(cx.tcx(), cx.tcx().types.trait_object_dummy_self);
 
-                    let super_projections: Vec<_> =
-                        super_predicates_for_pretty_printing(cx.tcx(), bound_principal_with_self)
-                            .filter_map(|clause| clause.as_projection_clause())
-                            .collect();
+                    let clause: ty::Clause<'tcx> = bound_principal_with_self.upcast(cx.tcx());
+                    let super_projections: Vec<_> = elaborate::elaborate(cx.tcx(), [clause])
+                        .filter_only_self()
+                        .filter_map(|clause| clause.as_projection_clause())
+                        .collect();
 
                     let mut projections: Vec<_> = predicates
                         .projection_bounds()
diff --git a/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs b/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
index b2a59eece0d..f1128f70aaa 100644
--- a/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
@@ -6,7 +6,7 @@ use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
 use rustc_type_ir::inherent::*;
 use rustc_type_ir::lang_items::TraitSolverLangItem;
 use rustc_type_ir::visit::TypeVisitableExt as _;
-use rustc_type_ir::{self as ty, Interner, TraitPredicate, Upcast as _};
+use rustc_type_ir::{self as ty, elaborate, Interner, TraitPredicate, Upcast as _};
 use tracing::{instrument, trace};
 
 use crate::delegate::SolverDelegate;
@@ -862,8 +862,7 @@ where
             .auto_traits()
             .into_iter()
             .chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
-                self.cx()
-                    .supertrait_def_ids(principal_def_id)
+                elaborate::supertrait_def_ids(self.cx(), principal_def_id)
                     .into_iter()
                     .filter(|def_id| self.cx().trait_is_auto(*def_id))
             }))
diff --git a/compiler/rustc_type_ir/src/elaborate.rs b/compiler/rustc_type_ir/src/elaborate.rs
index 04cdf29a777..0def7d12f74 100644
--- a/compiler/rustc_type_ir/src/elaborate.rs
+++ b/compiler/rustc_type_ir/src/elaborate.rs
@@ -229,6 +229,34 @@ impl<I: Interner, O: Elaboratable<I>> Iterator for Elaborator<I, O> {
 // Supertrait iterator
 ///////////////////////////////////////////////////////////////////////////
 
+/// Computes the def-ids of the transitive supertraits of `trait_def_id`. This (intentionally)
+/// does not compute the full elaborated super-predicates but just the set of def-ids. It is used
+/// to identify which traits may define a given associated type to help avoid cycle errors,
+/// and to make size estimates for vtable layout computation.
+pub fn supertrait_def_ids<I: Interner>(
+    cx: I,
+    trait_def_id: I::DefId,
+) -> impl Iterator<Item = I::DefId> {
+    let mut set = HashSet::default();
+    let mut stack = vec![trait_def_id];
+
+    set.insert(trait_def_id);
+
+    std::iter::from_fn(move || {
+        let trait_def_id = stack.pop()?;
+
+        for (predicate, _) in cx.explicit_super_predicates_of(trait_def_id).iter_identity() {
+            if let ty::ClauseKind::Trait(data) = predicate.kind().skip_binder() {
+                if set.insert(data.def_id()) {
+                    stack.push(data.def_id());
+                }
+            }
+        }
+
+        Some(trait_def_id)
+    })
+}
+
 pub fn supertraits<I: Interner>(
     tcx: I,
     trait_ref: ty::Binder<I, ty::TraitRef<I>>,
diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs
index 446a4060e62..6a291a02d61 100644
--- a/compiler/rustc_type_ir/src/interner.rs
+++ b/compiler/rustc_type_ir/src/interner.rs
@@ -253,9 +253,6 @@ pub trait Interner:
 
     fn trait_may_be_implemented_via_object(self, trait_def_id: Self::DefId) -> bool;
 
-    fn supertrait_def_ids(self, trait_def_id: Self::DefId)
-    -> impl IntoIterator<Item = Self::DefId>;
-
     fn delay_bug(self, msg: impl ToString) -> Self::ErrorGuaranteed;
 
     fn is_general_coroutine(self, coroutine_def_id: Self::DefId) -> bool;