diff options
| author | lcnr <rust@lcnr.de> | 2024-09-30 10:18:55 +0200 |
|---|---|---|
| committer | lcnr <rust@lcnr.de> | 2024-10-01 17:20:31 +0200 |
| commit | 13881f5404037e25a88d0b79a836e232dc73b1fc (patch) | |
| tree | 27ede0dfb8082dd78a8ce7cdb10af91f00dddbe2 | |
| parent | 15ac6983930a0d49b921c0330dbbb5a4f8f1d34a (diff) | |
| download | rust-13881f5404037e25a88d0b79a836e232dc73b1fc.tar.gz rust-13881f5404037e25a88d0b79a836e232dc73b1fc.zip | |
add caches to multiple type folders
| -rw-r--r-- | compiler/rustc_infer/src/infer/relate/combine.rs | 7 | ||||
| -rw-r--r-- | compiler/rustc_infer/src/infer/relate/type_relating.rs | 41 | ||||
| -rw-r--r-- | compiler/rustc_infer/src/infer/resolve.rs | 14 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/ty/fold.rs | 25 | ||||
| -rw-r--r-- | compiler/rustc_next_trait_solver/src/resolve.rs | 11 | ||||
| -rw-r--r-- | compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs | 51 | ||||
| -rw-r--r-- | compiler/rustc_type_ir/src/data_structures/delayed_map.rs | 92 | ||||
| -rw-r--r-- | compiler/rustc_type_ir/src/data_structures/mod.rs (renamed from compiler/rustc_type_ir/src/data_structures.rs) | 3 |
8 files changed, 222 insertions, 22 deletions
diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs index e75d7b7db14..3b2ef3fe981 100644 --- a/compiler/rustc_infer/src/infer/relate/combine.rs +++ b/compiler/rustc_infer/src/infer/relate/combine.rs @@ -36,10 +36,15 @@ use crate::traits::{Obligation, PredicateObligation}; #[derive(Clone)] pub struct CombineFields<'infcx, 'tcx> { pub infcx: &'infcx InferCtxt<'tcx>, + // Immutable fields pub trace: TypeTrace<'tcx>, pub param_env: ty::ParamEnv<'tcx>, - pub goals: Vec<Goal<'tcx, ty::Predicate<'tcx>>>, pub define_opaque_types: DefineOpaqueTypes, + // Mutable fields + // + // Adding any additional field likely requires + // changes to the cache of `TypeRelating`. + pub goals: Vec<Goal<'tcx, ty::Predicate<'tcx>>>, } impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> { diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs index a402a8009ff..c8ced4a2cda 100644 --- a/compiler/rustc_infer/src/infer/relate/type_relating.rs +++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs @@ -4,6 +4,7 @@ use rustc_middle::ty::relate::{ }; use rustc_middle::ty::{self, Ty, TyCtxt, TyVar}; use rustc_span::Span; +use rustc_type_ir::data_structures::DelayedSet; use tracing::{debug, instrument}; use super::combine::CombineFields; @@ -13,9 +14,36 @@ use crate::infer::{DefineOpaqueTypes, InferCtxt, SubregionOrigin}; /// Enforce that `a` is equal to or a subtype of `b`. pub struct TypeRelating<'combine, 'a, 'tcx> { + // Partially mutable. fields: &'combine mut CombineFields<'a, 'tcx>, + + // Immutable fields. structurally_relate_aliases: StructurallyRelateAliases, ambient_variance: ty::Variance, + + /// The cache has only tracks the `ambient_variance` as its the + /// only field which is mutable and which meaningfully changes + /// the result when relating types. + /// + /// The cache does not track whether the state of the + /// `InferCtxt` has been changed or whether we've added any + /// obligations to `self.fields.goals`. Whether a goal is added + /// once or multiple times is not really meaningful. + /// + /// Changes in the inference state may delay some type inference to + /// the next fulfillment loop. Given that this loop is already + /// necessary, this is also not a meaningful change. Consider + /// the following three relations: + /// ```text + /// Vec<?0> sub Vec<?1> + /// ?0 eq u32 + /// Vec<?0> sub Vec<?1> + /// ``` + /// Without a cache, the second `Vec<?0> sub Vec<?1>` would eagerly + /// constrain `?1` to `u32`. When using the cache entry from the + /// first time we've related these types, this only happens when + /// later proving the `Subtype(?0, ?1)` goal from the first relation. + cache: DelayedSet<(ty::Variance, Ty<'tcx>, Ty<'tcx>)>, } impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> { @@ -24,7 +52,12 @@ impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> { structurally_relate_aliases: StructurallyRelateAliases, ambient_variance: ty::Variance, ) -> TypeRelating<'combine, 'infcx, 'tcx> { - TypeRelating { fields: f, structurally_relate_aliases, ambient_variance } + TypeRelating { + fields: f, + structurally_relate_aliases, + ambient_variance, + cache: Default::default(), + } } } @@ -78,6 +111,10 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> { let a = infcx.shallow_resolve(a); let b = infcx.shallow_resolve(b); + if self.cache.contains(&(self.ambient_variance, a, b)) { + return Ok(a); + } + match (a.kind(), b.kind()) { (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => { match self.ambient_variance { @@ -160,6 +197,8 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> { } } + assert!(self.cache.insert((self.ambient_variance, a, b))); + Ok(a) } diff --git a/compiler/rustc_infer/src/infer/resolve.rs b/compiler/rustc_infer/src/infer/resolve.rs index 34625ffb778..671a66d504f 100644 --- a/compiler/rustc_infer/src/infer/resolve.rs +++ b/compiler/rustc_infer/src/infer/resolve.rs @@ -2,6 +2,7 @@ use rustc_middle::bug; use rustc_middle::ty::fold::{FallibleTypeFolder, TypeFolder, TypeSuperFoldable}; use rustc_middle::ty::visit::TypeVisitableExt; use rustc_middle::ty::{self, Const, InferConst, Ty, TyCtxt, TypeFoldable}; +use rustc_type_ir::data_structures::DelayedMap; use super::{FixupError, FixupResult, InferCtxt}; @@ -15,12 +16,15 @@ use super::{FixupError, FixupResult, InferCtxt}; /// points for correctness. pub struct OpportunisticVarResolver<'a, 'tcx> { infcx: &'a InferCtxt<'tcx>, + /// We're able to use a cache here as the folder does + /// not have any mutable state. + cache: DelayedMap<Ty<'tcx>, Ty<'tcx>>, } impl<'a, 'tcx> OpportunisticVarResolver<'a, 'tcx> { #[inline] pub fn new(infcx: &'a InferCtxt<'tcx>) -> Self { - OpportunisticVarResolver { infcx } + OpportunisticVarResolver { infcx, cache: Default::default() } } } @@ -33,9 +37,13 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for OpportunisticVarResolver<'a, 'tcx> { fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> { if !t.has_non_region_infer() { t // micro-optimize -- if there is nothing in this type that this fold affects... + } else if let Some(&ty) = self.cache.get(&t) { + return ty; } else { - let t = self.infcx.shallow_resolve(t); - t.super_fold_with(self) + let shallow = self.infcx.shallow_resolve(t); + let res = shallow.super_fold_with(self); + assert!(self.cache.insert(t, res)); + res } } diff --git a/compiler/rustc_middle/src/ty/fold.rs b/compiler/rustc_middle/src/ty/fold.rs index 2ee7497497a..e152d3f5fbe 100644 --- a/compiler/rustc_middle/src/ty/fold.rs +++ b/compiler/rustc_middle/src/ty/fold.rs @@ -1,5 +1,6 @@ use rustc_data_structures::fx::FxIndexMap; use rustc_hir::def_id::DefId; +use rustc_type_ir::data_structures::DelayedMap; pub use rustc_type_ir::fold::{ FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable, shift_region, shift_vars, }; @@ -131,12 +132,20 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for RegionFolder<'a, 'tcx> { /////////////////////////////////////////////////////////////////////////// // Bound vars replacer +/// A delegate used when instantiating bound vars. +/// +/// Any implementation must make sure that each bound variable always +/// gets mapped to the same result. `BoundVarReplacer` caches by using +/// a `DelayedMap` which does not cache the first few types it encounters. pub trait BoundVarReplacerDelegate<'tcx> { fn replace_region(&mut self, br: ty::BoundRegion) -> ty::Region<'tcx>; fn replace_ty(&mut self, bt: ty::BoundTy) -> Ty<'tcx>; fn replace_const(&mut self, bv: ty::BoundVar) -> ty::Const<'tcx>; } +/// A simple delegate taking 3 mutable functions. The used functions must +/// always return the same result for each bound variable, no matter how +/// frequently they are called. pub struct FnMutDelegate<'a, 'tcx> { pub regions: &'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a), pub types: &'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a), @@ -164,11 +173,15 @@ struct BoundVarReplacer<'tcx, D> { current_index: ty::DebruijnIndex, delegate: D, + + /// This cache only tracks the `DebruijnIndex` and assumes that it does not matter + /// for the delegate how often its methods get used. + cache: DelayedMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>, } impl<'tcx, D: BoundVarReplacerDelegate<'tcx>> BoundVarReplacer<'tcx, D> { fn new(tcx: TyCtxt<'tcx>, delegate: D) -> Self { - BoundVarReplacer { tcx, current_index: ty::INNERMOST, delegate } + BoundVarReplacer { tcx, current_index: ty::INNERMOST, delegate, cache: Default::default() } } } @@ -197,7 +210,15 @@ where debug_assert!(!ty.has_vars_bound_above(ty::INNERMOST)); ty::fold::shift_vars(self.tcx, ty, self.current_index.as_u32()) } - _ if t.has_vars_bound_at_or_above(self.current_index) => t.super_fold_with(self), + _ if t.has_vars_bound_at_or_above(self.current_index) => { + if let Some(&ty) = self.cache.get(&(self.current_index, t)) { + return ty; + } + + let res = t.super_fold_with(self); + assert!(self.cache.insert((self.current_index, t), res)); + res + } _ => t, } } diff --git a/compiler/rustc_next_trait_solver/src/resolve.rs b/compiler/rustc_next_trait_solver/src/resolve.rs index 132b7400300..a37066cec66 100644 --- a/compiler/rustc_next_trait_solver/src/resolve.rs +++ b/compiler/rustc_next_trait_solver/src/resolve.rs @@ -1,3 +1,4 @@ +use rustc_type_ir::data_structures::DelayedMap; use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable}; use rustc_type_ir::inherent::*; use rustc_type_ir::visit::TypeVisitableExt; @@ -15,11 +16,12 @@ where I: Interner, { delegate: &'a D, + cache: DelayedMap<I::Ty, I::Ty>, } impl<'a, D: SolverDelegate> EagerResolver<'a, D> { pub fn new(delegate: &'a D) -> Self { - EagerResolver { delegate } + EagerResolver { delegate, cache: Default::default() } } } @@ -42,7 +44,12 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv ty::Infer(ty::FloatVar(vid)) => self.delegate.opportunistic_resolve_float_var(vid), _ => { if t.has_infer() { - t.super_fold_with(self) + if let Some(&ty) = self.cache.get(&t) { + return ty; + } + let res = t.super_fold_with(self); + assert!(self.cache.insert(t, res)); + res } else { t } diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs index 12ad0724b5c..12b4b3cb3a9 100644 --- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs +++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs @@ -3,7 +3,7 @@ use std::ops::ControlFlow; use derive_where::derive_where; #[cfg(feature = "nightly")] use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable}; -use rustc_type_ir::data_structures::ensure_sufficient_stack; +use rustc_type_ir::data_structures::{HashMap, HashSet, ensure_sufficient_stack}; use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable}; use rustc_type_ir::inherent::*; use rustc_type_ir::relate::Relate; @@ -579,18 +579,16 @@ where #[instrument(level = "trace", skip(self))] pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<I, ty::NormalizesTo<I>>) { - goal.predicate = goal - .predicate - .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env }); + goal.predicate = + goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env)); self.inspect.add_normalizes_to_goal(self.delegate, self.max_input_universe, goal); self.nested_goals.normalizes_to_goals.push(goal); } #[instrument(level = "debug", skip(self))] pub(super) fn add_goal(&mut self, source: GoalSource, mut goal: Goal<I, I::Predicate>) { - goal.predicate = goal - .predicate - .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env }); + goal.predicate = + goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env)); self.inspect.add_goal(self.delegate, self.max_input_universe, source, goal); self.nested_goals.goals.push((source, goal)); } @@ -654,6 +652,7 @@ where term: I::Term, universe_of_term: ty::UniverseIndex, delegate: &'a D, + cache: HashSet<I::Ty>, } impl<D: SolverDelegate<Interner = I>, I: Interner> ContainsTermOrNotNameable<'_, D, I> { @@ -671,6 +670,10 @@ where { type Result = ControlFlow<()>; fn visit_ty(&mut self, t: I::Ty) -> Self::Result { + if self.cache.contains(&t) { + return ControlFlow::Continue(()); + } + match t.kind() { ty::Infer(ty::TyVar(vid)) => { if let ty::TermKind::Ty(term) = self.term.kind() { @@ -683,17 +686,18 @@ where } } - self.check_nameable(self.delegate.universe_of_ty(vid).unwrap()) + self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())?; } - ty::Placeholder(p) => self.check_nameable(p.universe()), + ty::Placeholder(p) => self.check_nameable(p.universe())?, _ => { if t.has_non_region_infer() || t.has_placeholders() { - t.super_visit_with(self) - } else { - ControlFlow::Continue(()) + t.super_visit_with(self)? } } } + + assert!(self.cache.insert(t)); + ControlFlow::Continue(()) } fn visit_const(&mut self, c: I::Const) -> Self::Result { @@ -728,6 +732,7 @@ where delegate: self.delegate, universe_of_term, term: goal.predicate.term, + cache: Default::default(), }; goal.predicate.alias.visit_with(&mut visitor).is_continue() && goal.param_env.visit_with(&mut visitor).is_continue() @@ -1015,6 +1020,17 @@ where { ecx: &'me mut EvalCtxt<'a, D>, param_env: I::ParamEnv, + cache: HashMap<I::Ty, I::Ty>, +} + +impl<'me, 'a, D, I> ReplaceAliasWithInfer<'me, 'a, D, I> +where + D: SolverDelegate<Interner = I>, + I: Interner, +{ + fn new(ecx: &'me mut EvalCtxt<'a, D>, param_env: I::ParamEnv) -> Self { + ReplaceAliasWithInfer { ecx, param_env, cache: Default::default() } + } } impl<D, I> TypeFolder<I> for ReplaceAliasWithInfer<'_, '_, D, I> @@ -1041,7 +1057,16 @@ where ); infer_ty } - _ => ty.super_fold_with(self), + _ if ty.has_aliases() => { + if let Some(&entry) = self.cache.get(&ty) { + return entry; + } + + let res = ty.super_fold_with(self); + assert!(self.cache.insert(ty, res).is_none()); + res + } + _ => ty, } } diff --git a/compiler/rustc_type_ir/src/data_structures/delayed_map.rs b/compiler/rustc_type_ir/src/data_structures/delayed_map.rs new file mode 100644 index 00000000000..7e7406e3706 --- /dev/null +++ b/compiler/rustc_type_ir/src/data_structures/delayed_map.rs @@ -0,0 +1,92 @@ +use std::hash::Hash; + +use crate::data_structures::{HashMap, HashSet}; + +const CACHE_CUTOFF: u32 = 32; + +/// A hashmap which only starts hashing after ignoring the first few inputs. +/// +/// This is used in type folders as in nearly all cases caching is not worth it +/// as nearly all folded types are tiny. However, there are very rare incredibly +/// large types for which caching is necessary to avoid hangs. +#[derive(Debug)] +pub struct DelayedMap<K, V> { + cache: HashMap<K, V>, + count: u32, +} + +impl<K, V> Default for DelayedMap<K, V> { + fn default() -> Self { + DelayedMap { cache: Default::default(), count: 0 } + } +} + +impl<K: Hash + Eq, V> DelayedMap<K, V> { + #[inline(always)] + pub fn insert(&mut self, key: K, value: V) -> bool { + if self.count >= CACHE_CUTOFF { + self.cold_insert(key, value) + } else { + self.count += 1; + true + } + } + + #[cold] + #[inline(never)] + fn cold_insert(&mut self, key: K, value: V) -> bool { + self.cache.insert(key, value).is_none() + } + + #[inline(always)] + pub fn get(&self, key: &K) -> Option<&V> { + if self.cache.is_empty() { None } else { self.cold_get(key) } + } + + #[cold] + #[inline(never)] + fn cold_get(&self, key: &K) -> Option<&V> { + self.cache.get(key) + } +} + +#[derive(Debug)] +pub struct DelayedSet<T> { + cache: HashSet<T>, + count: u32, +} + +impl<T> Default for DelayedSet<T> { + fn default() -> Self { + DelayedSet { cache: Default::default(), count: 0 } + } +} + +impl<T: Hash + Eq> DelayedSet<T> { + #[inline(always)] + pub fn insert(&mut self, value: T) -> bool { + if self.count >= CACHE_CUTOFF { + self.cold_insert(value) + } else { + self.count += 1; + true + } + } + + #[cold] + #[inline(never)] + fn cold_insert(&mut self, value: T) -> bool { + self.cache.insert(value) + } + + #[inline(always)] + pub fn contains(&self, value: &T) -> bool { + !self.cache.is_empty() && self.cold_contains(value) + } + + #[cold] + #[inline(never)] + fn cold_contains(&self, value: &T) -> bool { + self.cache.contains(value) + } +} diff --git a/compiler/rustc_type_ir/src/data_structures.rs b/compiler/rustc_type_ir/src/data_structures/mod.rs index 96036e53b0a..d9766d91845 100644 --- a/compiler/rustc_type_ir/src/data_structures.rs +++ b/compiler/rustc_type_ir/src/data_structures/mod.rs @@ -6,6 +6,8 @@ pub use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; pub type IndexMap<K, V> = indexmap::IndexMap<K, V, BuildHasherDefault<FxHasher>>; pub type IndexSet<V> = indexmap::IndexSet<V, BuildHasherDefault<FxHasher>>; +mod delayed_map; + #[cfg(feature = "nightly")] mod impl_ { pub use rustc_data_structures::sso::{SsoHashMap, SsoHashSet}; @@ -24,4 +26,5 @@ mod impl_ { } } +pub use delayed_map::{DelayedMap, DelayedSet}; pub use impl_::*; |
