about summary refs log tree commit diff
path: root/compiler/rustc_next_trait_solver/src
diff options
context:
space:
mode:
authorlcnr <rust@lcnr.de>2024-09-30 10:18:55 +0200
committerlcnr <rust@lcnr.de>2024-10-01 17:20:31 +0200
commit13881f5404037e25a88d0b79a836e232dc73b1fc (patch)
tree27ede0dfb8082dd78a8ce7cdb10af91f00dddbe2 /compiler/rustc_next_trait_solver/src
parent15ac6983930a0d49b921c0330dbbb5a4f8f1d34a (diff)
downloadrust-13881f5404037e25a88d0b79a836e232dc73b1fc.tar.gz
rust-13881f5404037e25a88d0b79a836e232dc73b1fc.zip
add caches to multiple type folders
Diffstat (limited to 'compiler/rustc_next_trait_solver/src')
-rw-r--r--compiler/rustc_next_trait_solver/src/resolve.rs11
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs51
2 files changed, 47 insertions, 15 deletions
diff --git a/compiler/rustc_next_trait_solver/src/resolve.rs b/compiler/rustc_next_trait_solver/src/resolve.rs
index 132b7400300..a37066cec66 100644
--- a/compiler/rustc_next_trait_solver/src/resolve.rs
+++ b/compiler/rustc_next_trait_solver/src/resolve.rs
@@ -1,3 +1,4 @@
+use rustc_type_ir::data_structures::DelayedMap;
 use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
 use rustc_type_ir::inherent::*;
 use rustc_type_ir::visit::TypeVisitableExt;
@@ -15,11 +16,12 @@ where
     I: Interner,
 {
     delegate: &'a D,
+    cache: DelayedMap<I::Ty, I::Ty>,
 }
 
 impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
     pub fn new(delegate: &'a D) -> Self {
-        EagerResolver { delegate }
+        EagerResolver { delegate, cache: Default::default() }
     }
 }
 
@@ -42,7 +44,12 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
             ty::Infer(ty::FloatVar(vid)) => self.delegate.opportunistic_resolve_float_var(vid),
             _ => {
                 if t.has_infer() {
-                    t.super_fold_with(self)
+                    if let Some(&ty) = self.cache.get(&t) {
+                        return ty;
+                    }
+                    let res = t.super_fold_with(self);
+                    assert!(self.cache.insert(t, res));
+                    res
                 } else {
                     t
                 }
diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
index 12ad0724b5c..12b4b3cb3a9 100644
--- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
@@ -3,7 +3,7 @@ use std::ops::ControlFlow;
 use derive_where::derive_where;
 #[cfg(feature = "nightly")]
 use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable};
-use rustc_type_ir::data_structures::ensure_sufficient_stack;
+use rustc_type_ir::data_structures::{HashMap, HashSet, ensure_sufficient_stack};
 use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
 use rustc_type_ir::inherent::*;
 use rustc_type_ir::relate::Relate;
@@ -579,18 +579,16 @@ where
 
     #[instrument(level = "trace", skip(self))]
     pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<I, ty::NormalizesTo<I>>) {
-        goal.predicate = goal
-            .predicate
-            .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
+        goal.predicate =
+            goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env));
         self.inspect.add_normalizes_to_goal(self.delegate, self.max_input_universe, goal);
         self.nested_goals.normalizes_to_goals.push(goal);
     }
 
     #[instrument(level = "debug", skip(self))]
     pub(super) fn add_goal(&mut self, source: GoalSource, mut goal: Goal<I, I::Predicate>) {
-        goal.predicate = goal
-            .predicate
-            .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
+        goal.predicate =
+            goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env));
         self.inspect.add_goal(self.delegate, self.max_input_universe, source, goal);
         self.nested_goals.goals.push((source, goal));
     }
@@ -654,6 +652,7 @@ where
             term: I::Term,
             universe_of_term: ty::UniverseIndex,
             delegate: &'a D,
+            cache: HashSet<I::Ty>,
         }
 
         impl<D: SolverDelegate<Interner = I>, I: Interner> ContainsTermOrNotNameable<'_, D, I> {
@@ -671,6 +670,10 @@ where
         {
             type Result = ControlFlow<()>;
             fn visit_ty(&mut self, t: I::Ty) -> Self::Result {
+                if self.cache.contains(&t) {
+                    return ControlFlow::Continue(());
+                }
+
                 match t.kind() {
                     ty::Infer(ty::TyVar(vid)) => {
                         if let ty::TermKind::Ty(term) = self.term.kind() {
@@ -683,17 +686,18 @@ where
                             }
                         }
 
-                        self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())
+                        self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())?;
                     }
-                    ty::Placeholder(p) => self.check_nameable(p.universe()),
+                    ty::Placeholder(p) => self.check_nameable(p.universe())?,
                     _ => {
                         if t.has_non_region_infer() || t.has_placeholders() {
-                            t.super_visit_with(self)
-                        } else {
-                            ControlFlow::Continue(())
+                            t.super_visit_with(self)?
                         }
                     }
                 }
+
+                assert!(self.cache.insert(t));
+                ControlFlow::Continue(())
             }
 
             fn visit_const(&mut self, c: I::Const) -> Self::Result {
@@ -728,6 +732,7 @@ where
             delegate: self.delegate,
             universe_of_term,
             term: goal.predicate.term,
+            cache: Default::default(),
         };
         goal.predicate.alias.visit_with(&mut visitor).is_continue()
             && goal.param_env.visit_with(&mut visitor).is_continue()
@@ -1015,6 +1020,17 @@ where
 {
     ecx: &'me mut EvalCtxt<'a, D>,
     param_env: I::ParamEnv,
+    cache: HashMap<I::Ty, I::Ty>,
+}
+
+impl<'me, 'a, D, I> ReplaceAliasWithInfer<'me, 'a, D, I>
+where
+    D: SolverDelegate<Interner = I>,
+    I: Interner,
+{
+    fn new(ecx: &'me mut EvalCtxt<'a, D>, param_env: I::ParamEnv) -> Self {
+        ReplaceAliasWithInfer { ecx, param_env, cache: Default::default() }
+    }
 }
 
 impl<D, I> TypeFolder<I> for ReplaceAliasWithInfer<'_, '_, D, I>
@@ -1041,7 +1057,16 @@ where
                 );
                 infer_ty
             }
-            _ => ty.super_fold_with(self),
+            _ if ty.has_aliases() => {
+                if let Some(&entry) = self.cache.get(&ty) {
+                    return entry;
+                }
+
+                let res = ty.super_fold_with(self);
+                assert!(self.cache.insert(ty, res).is_none());
+                res
+            }
+            _ => ty,
         }
     }