diff options
| author | lcnr <rust@lcnr.de> | 2024-09-30 10:18:55 +0200 |
|---|---|---|
| committer | lcnr <rust@lcnr.de> | 2024-10-01 17:20:31 +0200 |
| commit | 13881f5404037e25a88d0b79a836e232dc73b1fc (patch) | |
| tree | 27ede0dfb8082dd78a8ce7cdb10af91f00dddbe2 /compiler/rustc_next_trait_solver/src | |
| parent | 15ac6983930a0d49b921c0330dbbb5a4f8f1d34a (diff) | |
| download | rust-13881f5404037e25a88d0b79a836e232dc73b1fc.tar.gz rust-13881f5404037e25a88d0b79a836e232dc73b1fc.zip | |
add caches to multiple type folders
Diffstat (limited to 'compiler/rustc_next_trait_solver/src')
| -rw-r--r-- | compiler/rustc_next_trait_solver/src/resolve.rs | 11 | ||||
| -rw-r--r-- | compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs | 51 |
2 files changed, 47 insertions, 15 deletions
diff --git a/compiler/rustc_next_trait_solver/src/resolve.rs b/compiler/rustc_next_trait_solver/src/resolve.rs index 132b7400300..a37066cec66 100644 --- a/compiler/rustc_next_trait_solver/src/resolve.rs +++ b/compiler/rustc_next_trait_solver/src/resolve.rs @@ -1,3 +1,4 @@ +use rustc_type_ir::data_structures::DelayedMap; use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable}; use rustc_type_ir::inherent::*; use rustc_type_ir::visit::TypeVisitableExt; @@ -15,11 +16,12 @@ where I: Interner, { delegate: &'a D, + cache: DelayedMap<I::Ty, I::Ty>, } impl<'a, D: SolverDelegate> EagerResolver<'a, D> { pub fn new(delegate: &'a D) -> Self { - EagerResolver { delegate } + EagerResolver { delegate, cache: Default::default() } } } @@ -42,7 +44,12 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv ty::Infer(ty::FloatVar(vid)) => self.delegate.opportunistic_resolve_float_var(vid), _ => { if t.has_infer() { - t.super_fold_with(self) + if let Some(&ty) = self.cache.get(&t) { + return ty; + } + let res = t.super_fold_with(self); + assert!(self.cache.insert(t, res)); + res } else { t } diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs index 12ad0724b5c..12b4b3cb3a9 100644 --- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs +++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs @@ -3,7 +3,7 @@ use std::ops::ControlFlow; use derive_where::derive_where; #[cfg(feature = "nightly")] use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable}; -use rustc_type_ir::data_structures::ensure_sufficient_stack; +use rustc_type_ir::data_structures::{HashMap, HashSet, ensure_sufficient_stack}; use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable}; use rustc_type_ir::inherent::*; use rustc_type_ir::relate::Relate; @@ -579,18 +579,16 @@ where #[instrument(level = "trace", skip(self))] pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<I, ty::NormalizesTo<I>>) { - goal.predicate = goal - .predicate - .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env }); + goal.predicate = + goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env)); self.inspect.add_normalizes_to_goal(self.delegate, self.max_input_universe, goal); self.nested_goals.normalizes_to_goals.push(goal); } #[instrument(level = "debug", skip(self))] pub(super) fn add_goal(&mut self, source: GoalSource, mut goal: Goal<I, I::Predicate>) { - goal.predicate = goal - .predicate - .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env }); + goal.predicate = + goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env)); self.inspect.add_goal(self.delegate, self.max_input_universe, source, goal); self.nested_goals.goals.push((source, goal)); } @@ -654,6 +652,7 @@ where term: I::Term, universe_of_term: ty::UniverseIndex, delegate: &'a D, + cache: HashSet<I::Ty>, } impl<D: SolverDelegate<Interner = I>, I: Interner> ContainsTermOrNotNameable<'_, D, I> { @@ -671,6 +670,10 @@ where { type Result = ControlFlow<()>; fn visit_ty(&mut self, t: I::Ty) -> Self::Result { + if self.cache.contains(&t) { + return ControlFlow::Continue(()); + } + match t.kind() { ty::Infer(ty::TyVar(vid)) => { if let ty::TermKind::Ty(term) = self.term.kind() { @@ -683,17 +686,18 @@ where } } - self.check_nameable(self.delegate.universe_of_ty(vid).unwrap()) + self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())?; } - ty::Placeholder(p) => self.check_nameable(p.universe()), + ty::Placeholder(p) => self.check_nameable(p.universe())?, _ => { if t.has_non_region_infer() || t.has_placeholders() { - t.super_visit_with(self) - } else { - ControlFlow::Continue(()) + t.super_visit_with(self)? } } } + + assert!(self.cache.insert(t)); + ControlFlow::Continue(()) } fn visit_const(&mut self, c: I::Const) -> Self::Result { @@ -728,6 +732,7 @@ where delegate: self.delegate, universe_of_term, term: goal.predicate.term, + cache: Default::default(), }; goal.predicate.alias.visit_with(&mut visitor).is_continue() && goal.param_env.visit_with(&mut visitor).is_continue() @@ -1015,6 +1020,17 @@ where { ecx: &'me mut EvalCtxt<'a, D>, param_env: I::ParamEnv, + cache: HashMap<I::Ty, I::Ty>, +} + +impl<'me, 'a, D, I> ReplaceAliasWithInfer<'me, 'a, D, I> +where + D: SolverDelegate<Interner = I>, + I: Interner, +{ + fn new(ecx: &'me mut EvalCtxt<'a, D>, param_env: I::ParamEnv) -> Self { + ReplaceAliasWithInfer { ecx, param_env, cache: Default::default() } + } } impl<D, I> TypeFolder<I> for ReplaceAliasWithInfer<'_, '_, D, I> @@ -1041,7 +1057,16 @@ where ); infer_ty } - _ => ty.super_fold_with(self), + _ if ty.has_aliases() => { + if let Some(&entry) = self.cache.get(&ty) { + return entry; + } + + let res = ty.super_fold_with(self); + assert!(self.cache.insert(ty, res).is_none()); + res + } + _ => ty, } } |
