about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2025-05-29 02:29:01 +0000
committerbors <bors@rust-lang.org>2025-05-29 02:29:01 +0000
commit5f025f363df11c65bd31ade9fe6f48fd4f4239af (patch)
tree50463e14f9dddf557e2d34a5738ac3730c0e1f4e
parentebe9b0060240953d721508ceb4d02a745efda88f (diff)
parent0830ce036f92673fa54a06cc4eacb47426850d33 (diff)
downloadrust-5f025f363df11c65bd31ade9fe6f48fd4f4239af.tar.gz
rust-5f025f363df11c65bd31ade9fe6f48fd4f4239af.zip
Auto merge of #141581 - lcnr:fold-clauses, r=compiler-errors
add additional `TypeFlags` fast paths

Some crates, e.g. `diesel`, have items with a lot of where-clauses (more than 150). In these cases checking the `TypeFlags` of the whole `param_env` can be very beneficial.

This adds `fn fold_clauses` to mirror the existing `fn visit_clauses` and then uses this in folders which fold `ParamEnv`s.

Split out from rust-lang/rust#141451, depends on rust-lang/rust#141442.

r? `@compiler-errors`
-rw-r--r--compiler/rustc_infer/src/infer/canonical/canonicalizer.rs4
-rw-r--r--compiler/rustc_infer/src/infer/resolve.rs8
-rw-r--r--compiler/rustc_middle/src/ty/erase_regions.rs8
-rw-r--r--compiler/rustc_middle/src/ty/fold.rs4
-rw-r--r--compiler/rustc_middle/src/ty/predicate.rs2
-rw-r--r--compiler/rustc_middle/src/ty/structural_impls.rs27
-rw-r--r--compiler/rustc_next_trait_solver/src/canonicalizer.rs11
-rw-r--r--compiler/rustc_next_trait_solver/src/resolve.rs20
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs11
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs16
-rw-r--r--compiler/rustc_trait_selection/src/solve/inspect/analyse.rs9
-rw-r--r--compiler/rustc_type_ir/src/binder.rs8
-rw-r--r--compiler/rustc_type_ir/src/fold.rs8
-rw-r--r--compiler/rustc_type_ir/src/inherent.rs12
-rw-r--r--compiler/rustc_type_ir/src/interner.rs4
-rw-r--r--compiler/rustc_type_ir/src/visit.rs4
16 files changed, 138 insertions, 18 deletions
diff --git a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
index 0b543f091f7..060447ba720 100644
--- a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
+++ b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
@@ -497,6 +497,10 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
     fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
         if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p }
     }
+
+    fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
+        if c.flags().intersects(self.needs_canonical_flags) { c.super_fold_with(self) } else { c }
+    }
 }
 
 impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
diff --git a/compiler/rustc_infer/src/infer/resolve.rs b/compiler/rustc_infer/src/infer/resolve.rs
index 4b0ace8c554..a95f24b5b95 100644
--- a/compiler/rustc_infer/src/infer/resolve.rs
+++ b/compiler/rustc_infer/src/infer/resolve.rs
@@ -55,6 +55,14 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for OpportunisticVarResolver<'a, 'tcx> {
             ct.super_fold_with(self)
         }
     }
+
+    fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
+        if !p.has_non_region_infer() { p } else { p.super_fold_with(self) }
+    }
+
+    fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
+        if !c.has_non_region_infer() { c } else { c.super_fold_with(self) }
+    }
 }
 
 /// The opportunistic region resolver opportunistically resolves regions
diff --git a/compiler/rustc_middle/src/ty/erase_regions.rs b/compiler/rustc_middle/src/ty/erase_regions.rs
index 45a0b1288db..f4fead7e952 100644
--- a/compiler/rustc_middle/src/ty/erase_regions.rs
+++ b/compiler/rustc_middle/src/ty/erase_regions.rs
@@ -86,4 +86,12 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for RegionEraserVisitor<'tcx> {
             p
         }
     }
+
+    fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
+        if c.has_type_flags(TypeFlags::HAS_BINDER_VARS | TypeFlags::HAS_FREE_REGIONS) {
+            c.super_fold_with(self)
+        } else {
+            c
+        }
+    }
 }
diff --git a/compiler/rustc_middle/src/ty/fold.rs b/compiler/rustc_middle/src/ty/fold.rs
index 8d6871d2f1f..b2057fa36d7 100644
--- a/compiler/rustc_middle/src/ty/fold.rs
+++ b/compiler/rustc_middle/src/ty/fold.rs
@@ -177,6 +177,10 @@ where
     fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
         if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
     }
+
+    fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
+        if c.has_vars_bound_at_or_above(self.current_index) { c.super_fold_with(self) } else { c }
+    }
 }
 
 impl<'tcx> TyCtxt<'tcx> {
diff --git a/compiler/rustc_middle/src/ty/predicate.rs b/compiler/rustc_middle/src/ty/predicate.rs
index 551d816941b..bc2ac42b6b1 100644
--- a/compiler/rustc_middle/src/ty/predicate.rs
+++ b/compiler/rustc_middle/src/ty/predicate.rs
@@ -238,6 +238,8 @@ impl<'tcx> Clause<'tcx> {
     }
 }
 
+impl<'tcx> rustc_type_ir::inherent::Clauses<TyCtxt<'tcx>> for ty::Clauses<'tcx> {}
+
 #[extension(pub trait ExistentialPredicateStableCmpExt<'tcx>)]
 impl<'tcx> ExistentialPredicate<'tcx> {
     /// Compares via an ordering that will not change if modules are reordered or other changes are
diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs
index 7a7c33fb34d..000ba7b6fa7 100644
--- a/compiler/rustc_middle/src/ty/structural_impls.rs
+++ b/compiler/rustc_middle/src/ty/structural_impls.rs
@@ -570,6 +570,19 @@ impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clause<'tcx> {
     }
 }
 
+impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
+    fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
+        self,
+        folder: &mut F,
+    ) -> Result<Self, F::Error> {
+        folder.try_fold_clauses(self)
+    }
+
+    fn fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
+        folder.fold_clauses(self)
+    }
+}
+
 impl<'tcx> TypeVisitable<TyCtxt<'tcx>> for ty::Predicate<'tcx> {
     fn visit_with<V: TypeVisitor<TyCtxt<'tcx>>>(&self, visitor: &mut V) -> V::Result {
         visitor.visit_predicate(*self)
@@ -615,6 +628,19 @@ impl<'tcx> TypeSuperVisitable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
     }
 }
 
+impl<'tcx> TypeSuperFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
+    fn try_super_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
+        self,
+        folder: &mut F,
+    ) -> Result<Self, F::Error> {
+        ty::util::try_fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
+    }
+
+    fn super_fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
+        ty::util::fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
+    }
+}
+
 impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Const<'tcx> {
     fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
         self,
@@ -775,7 +801,6 @@ macro_rules! list_fold {
 }
 
 list_fold! {
-    ty::Clauses<'tcx> : mk_clauses,
     &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> : mk_poly_existential_predicates,
     &'tcx ty::List<PlaceElem<'tcx>> : mk_place_elems,
     &'tcx ty::List<ty::Pattern<'tcx>> : mk_patterns,
diff --git a/compiler/rustc_next_trait_solver/src/canonicalizer.rs b/compiler/rustc_next_trait_solver/src/canonicalizer.rs
index addeb3e2b78..e5ca2bda459 100644
--- a/compiler/rustc_next_trait_solver/src/canonicalizer.rs
+++ b/compiler/rustc_next_trait_solver/src/canonicalizer.rs
@@ -572,4 +572,15 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicaliz
     fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
         if p.flags().intersects(NEEDS_CANONICAL) { p.super_fold_with(self) } else { p }
     }
+
+    fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
+        match self.canonicalize_mode {
+            CanonicalizeMode::Input { keep_static: true }
+            | CanonicalizeMode::Response { max_input_universe: _ } => {}
+            CanonicalizeMode::Input { keep_static: false } => {
+                panic!("erasing 'static in env")
+            }
+        }
+        if c.flags().intersects(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c }
+    }
 }
diff --git a/compiler/rustc_next_trait_solver/src/resolve.rs b/compiler/rustc_next_trait_solver/src/resolve.rs
index 39abec2d7d8..c3c57eccd6e 100644
--- a/compiler/rustc_next_trait_solver/src/resolve.rs
+++ b/compiler/rustc_next_trait_solver/src/resolve.rs
@@ -11,7 +11,7 @@ use crate::delegate::SolverDelegate;
 // EAGER RESOLUTION
 
 /// Resolves ty, region, and const vars to their inferred values or their root vars.
-pub struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
+struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
 where
     D: SolverDelegate<Interner = I>,
     I: Interner,
@@ -22,8 +22,20 @@ where
     cache: DelayedMap<I::Ty, I::Ty>,
 }
 
+pub fn eager_resolve_vars<D: SolverDelegate, T: TypeFoldable<D::Interner>>(
+    delegate: &D,
+    value: T,
+) -> T {
+    if value.has_infer() {
+        let mut folder = EagerResolver::new(delegate);
+        value.fold_with(&mut folder)
+    } else {
+        value
+    }
+}
+
 impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
-    pub fn new(delegate: &'a D) -> Self {
+    fn new(delegate: &'a D) -> Self {
         EagerResolver { delegate, cache: Default::default() }
     }
 }
@@ -90,4 +102,8 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
     fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
         if p.has_infer() { p.super_fold_with(self) } else { p }
     }
+
+    fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
+        if c.has_infer() { c.super_fold_with(self) } else { c }
+    }
 }
diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs
index 66d4cd23112..cb7c498b94e 100644
--- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs
@@ -22,7 +22,7 @@ use tracing::{debug, instrument, trace};
 
 use crate::canonicalizer::Canonicalizer;
 use crate::delegate::SolverDelegate;
-use crate::resolve::EagerResolver;
+use crate::resolve::eager_resolve_vars;
 use crate::solve::eval_ctxt::CurrentGoalKind;
 use crate::solve::{
     CanonicalInput, CanonicalResponse, Certainty, EvalCtxt, ExternalConstraintsData, Goal,
@@ -61,8 +61,7 @@ where
         // so we only canonicalize the lookup table and ignore
         // duplicate entries.
         let opaque_types = self.delegate.clone_opaque_types_lookup_table();
-        let (goal, opaque_types) =
-            (goal, opaque_types).fold_with(&mut EagerResolver::new(self.delegate));
+        let (goal, opaque_types) = eager_resolve_vars(self.delegate, (goal, opaque_types));
 
         let mut orig_values = Default::default();
         let canonical = Canonicalizer::canonicalize_input(
@@ -162,8 +161,8 @@ where
 
         let external_constraints =
             self.compute_external_query_constraints(certainty, normalization_nested_goals);
-        let (var_values, mut external_constraints) = (self.var_values, external_constraints)
-            .fold_with(&mut EagerResolver::new(self.delegate));
+        let (var_values, mut external_constraints) =
+            eager_resolve_vars(self.delegate, (self.var_values, external_constraints));
 
         // Remove any trivial or duplicated region constraints once we've resolved regions
         let mut unique = HashSet::default();
@@ -474,7 +473,7 @@ where
 {
     let var_values = CanonicalVarValues { var_values: delegate.cx().mk_args(var_values) };
     let state = inspect::State { var_values, data };
-    let state = state.fold_with(&mut EagerResolver::new(delegate));
+    let state = eager_resolve_vars(delegate, state);
     Canonicalizer::canonicalize_response(delegate, max_input_universe, &mut vec![], state)
 }
 
diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
index 2924208173d..b1a842e04af 100644
--- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
@@ -925,6 +925,22 @@ where
                     }
                 }
             }
+
+            fn visit_predicate(&mut self, p: I::Predicate) -> Self::Result {
+                if p.has_non_region_infer() || p.has_placeholders() {
+                    p.super_visit_with(self)
+                } else {
+                    ControlFlow::Continue(())
+                }
+            }
+
+            fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result {
+                if c.has_non_region_infer() || c.has_placeholders() {
+                    c.super_visit_with(self)
+                } else {
+                    ControlFlow::Continue(())
+                }
+            }
         }
 
         let mut visitor = ContainsTermOrNotNameable {
diff --git a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
index fda2c97ed56..1193a9059ca 100644
--- a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
+++ b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
@@ -15,9 +15,9 @@ use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk};
 use rustc_macros::extension;
 use rustc_middle::traits::ObligationCause;
 use rustc_middle::traits::solve::{Certainty, Goal, GoalSource, NoSolution, QueryResult};
-use rustc_middle::ty::{TyCtxt, TypeFoldable, VisitorResult, try_visit};
+use rustc_middle::ty::{TyCtxt, VisitorResult, try_visit};
 use rustc_middle::{bug, ty};
-use rustc_next_trait_solver::resolve::EagerResolver;
+use rustc_next_trait_solver::resolve::eager_resolve_vars;
 use rustc_next_trait_solver::solve::inspect::{self, instantiate_canonical_state};
 use rustc_next_trait_solver::solve::{GenerateProofTree, MaybeCause, SolverDelegateEvalExt as _};
 use rustc_span::{DUMMY_SP, Span};
@@ -187,8 +187,7 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
             let _ = term_hack.constrain(infcx, span, param_env);
         }
 
-        let opt_impl_args =
-            opt_impl_args.map(|impl_args| impl_args.fold_with(&mut EagerResolver::new(infcx)));
+        let opt_impl_args = opt_impl_args.map(|impl_args| eager_resolve_vars(infcx, impl_args));
 
         let goals = instantiated_goals
             .into_iter()
@@ -392,7 +391,7 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
             infcx,
             depth,
             orig_values,
-            goal: uncanonicalized_goal.fold_with(&mut EagerResolver::new(infcx)),
+            goal: eager_resolve_vars(infcx, uncanonicalized_goal),
             result,
             evaluation_kind: evaluation.kind,
             normalizes_to_term_hack,
diff --git a/compiler/rustc_type_ir/src/binder.rs b/compiler/rustc_type_ir/src/binder.rs
index 1e6ca0dcf5d..55c0a3bba9f 100644
--- a/compiler/rustc_type_ir/src/binder.rs
+++ b/compiler/rustc_type_ir/src/binder.rs
@@ -711,6 +711,14 @@ impl<'a, I: Interner> TypeFolder<I> for ArgFolder<'a, I> {
             c.super_fold_with(self)
         }
     }
+
+    fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
+        if p.has_param() { p.super_fold_with(self) } else { p }
+    }
+
+    fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
+        if c.has_param() { c.super_fold_with(self) } else { c }
+    }
 }
 
 impl<'a, I: Interner> ArgFolder<'a, I> {
diff --git a/compiler/rustc_type_ir/src/fold.rs b/compiler/rustc_type_ir/src/fold.rs
index ce1188070ca..a5eb8699e5f 100644
--- a/compiler/rustc_type_ir/src/fold.rs
+++ b/compiler/rustc_type_ir/src/fold.rs
@@ -152,6 +152,10 @@ pub trait TypeFolder<I: Interner>: Sized {
     fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
         p.super_fold_with(self)
     }
+
+    fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
+        c.super_fold_with(self)
+    }
 }
 
 /// This trait is implemented for every folding traversal. There is a fold
@@ -190,6 +194,10 @@ pub trait FallibleTypeFolder<I: Interner>: Sized {
     fn try_fold_predicate(&mut self, p: I::Predicate) -> Result<I::Predicate, Self::Error> {
         p.try_super_fold_with(self)
     }
+
+    fn try_fold_clauses(&mut self, c: I::Clauses) -> Result<I::Clauses, Self::Error> {
+        c.try_super_fold_with(self)
+    }
 }
 
 ///////////////////////////////////////////////////////////////////////////
diff --git a/compiler/rustc_type_ir/src/inherent.rs b/compiler/rustc_type_ir/src/inherent.rs
index dde55effc3d..fa88bcb891a 100644
--- a/compiler/rustc_type_ir/src/inherent.rs
+++ b/compiler/rustc_type_ir/src/inherent.rs
@@ -511,6 +511,18 @@ pub trait Clause<I: Interner<Clause = Self>>:
     fn instantiate_supertrait(self, cx: I, trait_ref: ty::Binder<I, ty::TraitRef<I>>) -> Self;
 }
 
+pub trait Clauses<I: Interner<Clauses = Self>>:
+    Copy
+    + Debug
+    + Hash
+    + Eq
+    + TypeSuperVisitable<I>
+    + TypeSuperFoldable<I>
+    + Flags
+    + SliceLike<Item = I::Clause>
+{
+}
+
 /// Common capabilities of placeholder kinds
 pub trait PlaceholderLike: Copy + Debug + Hash + Eq {
     fn universe(self) -> ty::UniverseIndex;
diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs
index 7e88114df46..a9917192144 100644
--- a/compiler/rustc_type_ir/src/interner.rs
+++ b/compiler/rustc_type_ir/src/interner.rs
@@ -12,7 +12,7 @@ use crate::ir_print::IrPrint;
 use crate::lang_items::TraitSolverLangItem;
 use crate::relate::Relate;
 use crate::solve::{CanonicalInput, ExternalConstraintsData, PredefinedOpaquesData, QueryResult};
-use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable};
+use crate::visit::{Flags, TypeVisitable};
 use crate::{self as ty, search_graph};
 
 #[cfg_attr(feature = "nightly", rustc_diagnostic_item = "type_ir_interner")]
@@ -146,7 +146,7 @@ pub trait Interner:
     type ParamEnv: ParamEnv<Self>;
     type Predicate: Predicate<Self>;
     type Clause: Clause<Self>;
-    type Clauses: Copy + Debug + Hash + Eq + TypeSuperVisitable<Self> + Flags;
+    type Clauses: Clauses<Self>;
 
     fn with_global_cache<R>(self, f: impl FnOnce(&mut search_graph::GlobalCache<Self>) -> R) -> R;
 
diff --git a/compiler/rustc_type_ir/src/visit.rs b/compiler/rustc_type_ir/src/visit.rs
index ccb84e25911..fc3864dd5ae 100644
--- a/compiler/rustc_type_ir/src/visit.rs
+++ b/compiler/rustc_type_ir/src/visit.rs
@@ -120,8 +120,8 @@ pub trait TypeVisitor<I: Interner>: Sized {
         p.super_visit_with(self)
     }
 
-    fn visit_clauses(&mut self, p: I::Clauses) -> Self::Result {
-        p.super_visit_with(self)
+    fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result {
+        c.super_visit_with(self)
     }
 
     fn visit_error(&mut self, _guar: I::ErrorGuaranteed) -> Self::Result {