about summary refs log tree commit diff
diff options
context:
space:
mode:
authorlcnr <rust@lcnr.de>2024-09-30 10:18:55 +0200
committerlcnr <rust@lcnr.de>2024-10-01 17:20:31 +0200
commit13881f5404037e25a88d0b79a836e232dc73b1fc (patch)
tree27ede0dfb8082dd78a8ce7cdb10af91f00dddbe2
parent15ac6983930a0d49b921c0330dbbb5a4f8f1d34a (diff)
downloadrust-13881f5404037e25a88d0b79a836e232dc73b1fc.tar.gz
rust-13881f5404037e25a88d0b79a836e232dc73b1fc.zip
add caches to multiple type folders
-rw-r--r--compiler/rustc_infer/src/infer/relate/combine.rs7
-rw-r--r--compiler/rustc_infer/src/infer/relate/type_relating.rs41
-rw-r--r--compiler/rustc_infer/src/infer/resolve.rs14
-rw-r--r--compiler/rustc_middle/src/ty/fold.rs25
-rw-r--r--compiler/rustc_next_trait_solver/src/resolve.rs11
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs51
-rw-r--r--compiler/rustc_type_ir/src/data_structures/delayed_map.rs92
-rw-r--r--compiler/rustc_type_ir/src/data_structures/mod.rs (renamed from compiler/rustc_type_ir/src/data_structures.rs)3
8 files changed, 222 insertions, 22 deletions
diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs
index e75d7b7db14..3b2ef3fe981 100644
--- a/compiler/rustc_infer/src/infer/relate/combine.rs
+++ b/compiler/rustc_infer/src/infer/relate/combine.rs
@@ -36,10 +36,15 @@ use crate::traits::{Obligation, PredicateObligation};
 #[derive(Clone)]
 pub struct CombineFields<'infcx, 'tcx> {
     pub infcx: &'infcx InferCtxt<'tcx>,
+    // Immutable fields
     pub trace: TypeTrace<'tcx>,
     pub param_env: ty::ParamEnv<'tcx>,
-    pub goals: Vec<Goal<'tcx, ty::Predicate<'tcx>>>,
     pub define_opaque_types: DefineOpaqueTypes,
+    // Mutable fields
+    //
+    // Adding any additional field likely requires
+    // changes to the cache of `TypeRelating`.
+    pub goals: Vec<Goal<'tcx, ty::Predicate<'tcx>>>,
 }
 
 impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs
index a402a8009ff..c8ced4a2cda 100644
--- a/compiler/rustc_infer/src/infer/relate/type_relating.rs
+++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs
@@ -4,6 +4,7 @@ use rustc_middle::ty::relate::{
 };
 use rustc_middle::ty::{self, Ty, TyCtxt, TyVar};
 use rustc_span::Span;
+use rustc_type_ir::data_structures::DelayedSet;
 use tracing::{debug, instrument};
 
 use super::combine::CombineFields;
@@ -13,9 +14,36 @@ use crate::infer::{DefineOpaqueTypes, InferCtxt, SubregionOrigin};
 
 /// Enforce that `a` is equal to or a subtype of `b`.
 pub struct TypeRelating<'combine, 'a, 'tcx> {
+    // Partially mutable.
     fields: &'combine mut CombineFields<'a, 'tcx>,
+
+    // Immutable fields.
     structurally_relate_aliases: StructurallyRelateAliases,
     ambient_variance: ty::Variance,
+
+    /// The cache only tracks the `ambient_variance` as it's the
+    /// only field which is mutable and which meaningfully changes
+    /// the result when relating types.
+    ///
+    /// The cache does not track whether the state of the
+    /// `InferCtxt` has been changed or whether we've added any
+    /// obligations to `self.fields.goals`. Whether a goal is added
+    /// once or multiple times is not really meaningful.
+    ///
+    /// Changes in the inference state may delay some type inference to
+    /// the next fulfillment loop. Given that this loop is already
+    /// necessary, this is also not a meaningful change. Consider
+    /// the following three relations:
+    /// ```text
+    /// Vec<?0> sub Vec<?1>
+    /// ?0 eq u32
+    /// Vec<?0> sub Vec<?1>
+    /// ```
+    /// Without a cache, the second `Vec<?0> sub Vec<?1>` would eagerly
+    /// constrain `?1` to `u32`. When using the cache entry from the
+    /// first time we've related these types, this only happens when
+    /// later proving the `Subtype(?0, ?1)` goal from the first relation.
+    cache: DelayedSet<(ty::Variance, Ty<'tcx>, Ty<'tcx>)>,
 }
 
 impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> {
@@ -24,7 +52,12 @@ impl<'combine, 'infcx, 'tcx> TypeRelating<'combine, 'infcx, 'tcx> {
         structurally_relate_aliases: StructurallyRelateAliases,
         ambient_variance: ty::Variance,
     ) -> TypeRelating<'combine, 'infcx, 'tcx> {
-        TypeRelating { fields: f, structurally_relate_aliases, ambient_variance }
+        TypeRelating {
+            fields: f,
+            structurally_relate_aliases,
+            ambient_variance,
+            cache: Default::default(),
+        }
     }
 }
 
@@ -78,6 +111,10 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
         let a = infcx.shallow_resolve(a);
         let b = infcx.shallow_resolve(b);
 
+        if self.cache.contains(&(self.ambient_variance, a, b)) {
+            return Ok(a);
+        }
+
         match (a.kind(), b.kind()) {
             (&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => {
                 match self.ambient_variance {
@@ -160,6 +197,8 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
             }
         }
 
+        assert!(self.cache.insert((self.ambient_variance, a, b)));
+
         Ok(a)
     }
 
diff --git a/compiler/rustc_infer/src/infer/resolve.rs b/compiler/rustc_infer/src/infer/resolve.rs
index 34625ffb778..671a66d504f 100644
--- a/compiler/rustc_infer/src/infer/resolve.rs
+++ b/compiler/rustc_infer/src/infer/resolve.rs
@@ -2,6 +2,7 @@ use rustc_middle::bug;
 use rustc_middle::ty::fold::{FallibleTypeFolder, TypeFolder, TypeSuperFoldable};
 use rustc_middle::ty::visit::TypeVisitableExt;
 use rustc_middle::ty::{self, Const, InferConst, Ty, TyCtxt, TypeFoldable};
+use rustc_type_ir::data_structures::DelayedMap;
 
 use super::{FixupError, FixupResult, InferCtxt};
 
@@ -15,12 +16,15 @@ use super::{FixupError, FixupResult, InferCtxt};
 /// points for correctness.
 pub struct OpportunisticVarResolver<'a, 'tcx> {
     infcx: &'a InferCtxt<'tcx>,
+    /// We're able to use a cache here as the folder does
+    /// not have any mutable state.
+    cache: DelayedMap<Ty<'tcx>, Ty<'tcx>>,
 }
 
 impl<'a, 'tcx> OpportunisticVarResolver<'a, 'tcx> {
     #[inline]
     pub fn new(infcx: &'a InferCtxt<'tcx>) -> Self {
-        OpportunisticVarResolver { infcx }
+        OpportunisticVarResolver { infcx, cache: Default::default() }
     }
 }
 
@@ -33,9 +37,13 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for OpportunisticVarResolver<'a, 'tcx> {
     fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
         if !t.has_non_region_infer() {
             t // micro-optimize -- if there is nothing in this type that this fold affects...
+        } else if let Some(&ty) = self.cache.get(&t) {
+            return ty;
         } else {
-            let t = self.infcx.shallow_resolve(t);
-            t.super_fold_with(self)
+            let shallow = self.infcx.shallow_resolve(t);
+            let res = shallow.super_fold_with(self);
+            assert!(self.cache.insert(t, res));
+            res
         }
     }
 
diff --git a/compiler/rustc_middle/src/ty/fold.rs b/compiler/rustc_middle/src/ty/fold.rs
index 2ee7497497a..e152d3f5fbe 100644
--- a/compiler/rustc_middle/src/ty/fold.rs
+++ b/compiler/rustc_middle/src/ty/fold.rs
@@ -1,5 +1,6 @@
 use rustc_data_structures::fx::FxIndexMap;
 use rustc_hir::def_id::DefId;
+use rustc_type_ir::data_structures::DelayedMap;
 pub use rustc_type_ir::fold::{
     FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable, shift_region, shift_vars,
 };
@@ -131,12 +132,20 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for RegionFolder<'a, 'tcx> {
 ///////////////////////////////////////////////////////////////////////////
 // Bound vars replacer
 
+/// A delegate used when instantiating bound vars.
+///
+/// Any implementation must make sure that each bound variable always
+/// gets mapped to the same result. `BoundVarReplacer` caches by using
+/// a `DelayedMap` which does not cache the first few types it encounters.
 pub trait BoundVarReplacerDelegate<'tcx> {
     fn replace_region(&mut self, br: ty::BoundRegion) -> ty::Region<'tcx>;
     fn replace_ty(&mut self, bt: ty::BoundTy) -> Ty<'tcx>;
     fn replace_const(&mut self, bv: ty::BoundVar) -> ty::Const<'tcx>;
 }
 
+/// A simple delegate taking 3 mutable functions. The used functions must
+/// always return the same result for each bound variable, no matter how
+/// frequently they are called.
 pub struct FnMutDelegate<'a, 'tcx> {
     pub regions: &'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a),
     pub types: &'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a),
@@ -164,11 +173,15 @@ struct BoundVarReplacer<'tcx, D> {
     current_index: ty::DebruijnIndex,
 
     delegate: D,
+
+    /// This cache only tracks the `DebruijnIndex` and assumes that it does not matter
+    /// for the delegate how often its methods get used.
+    cache: DelayedMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>,
 }
 
 impl<'tcx, D: BoundVarReplacerDelegate<'tcx>> BoundVarReplacer<'tcx, D> {
     fn new(tcx: TyCtxt<'tcx>, delegate: D) -> Self {
-        BoundVarReplacer { tcx, current_index: ty::INNERMOST, delegate }
+        BoundVarReplacer { tcx, current_index: ty::INNERMOST, delegate, cache: Default::default() }
     }
 }
 
@@ -197,7 +210,15 @@ where
                 debug_assert!(!ty.has_vars_bound_above(ty::INNERMOST));
                 ty::fold::shift_vars(self.tcx, ty, self.current_index.as_u32())
             }
-            _ if t.has_vars_bound_at_or_above(self.current_index) => t.super_fold_with(self),
+            _ if t.has_vars_bound_at_or_above(self.current_index) => {
+                if let Some(&ty) = self.cache.get(&(self.current_index, t)) {
+                    return ty;
+                }
+
+                let res = t.super_fold_with(self);
+                assert!(self.cache.insert((self.current_index, t), res));
+                res
+            }
             _ => t,
         }
     }
diff --git a/compiler/rustc_next_trait_solver/src/resolve.rs b/compiler/rustc_next_trait_solver/src/resolve.rs
index 132b7400300..a37066cec66 100644
--- a/compiler/rustc_next_trait_solver/src/resolve.rs
+++ b/compiler/rustc_next_trait_solver/src/resolve.rs
@@ -1,3 +1,4 @@
+use rustc_type_ir::data_structures::DelayedMap;
 use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
 use rustc_type_ir::inherent::*;
 use rustc_type_ir::visit::TypeVisitableExt;
@@ -15,11 +16,12 @@ where
     I: Interner,
 {
     delegate: &'a D,
+    cache: DelayedMap<I::Ty, I::Ty>,
 }
 
 impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
     pub fn new(delegate: &'a D) -> Self {
-        EagerResolver { delegate }
+        EagerResolver { delegate, cache: Default::default() }
     }
 }
 
@@ -42,7 +44,12 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
             ty::Infer(ty::FloatVar(vid)) => self.delegate.opportunistic_resolve_float_var(vid),
             _ => {
                 if t.has_infer() {
-                    t.super_fold_with(self)
+                    if let Some(&ty) = self.cache.get(&t) {
+                        return ty;
+                    }
+                    let res = t.super_fold_with(self);
+                    assert!(self.cache.insert(t, res));
+                    res
                 } else {
                     t
                 }
diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
index 12ad0724b5c..12b4b3cb3a9 100644
--- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
@@ -3,7 +3,7 @@ use std::ops::ControlFlow;
 use derive_where::derive_where;
 #[cfg(feature = "nightly")]
 use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable};
-use rustc_type_ir::data_structures::ensure_sufficient_stack;
+use rustc_type_ir::data_structures::{HashMap, HashSet, ensure_sufficient_stack};
 use rustc_type_ir::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
 use rustc_type_ir::inherent::*;
 use rustc_type_ir::relate::Relate;
@@ -579,18 +579,16 @@ where
 
     #[instrument(level = "trace", skip(self))]
     pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<I, ty::NormalizesTo<I>>) {
-        goal.predicate = goal
-            .predicate
-            .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
+        goal.predicate =
+            goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env));
         self.inspect.add_normalizes_to_goal(self.delegate, self.max_input_universe, goal);
         self.nested_goals.normalizes_to_goals.push(goal);
     }
 
     #[instrument(level = "debug", skip(self))]
     pub(super) fn add_goal(&mut self, source: GoalSource, mut goal: Goal<I, I::Predicate>) {
-        goal.predicate = goal
-            .predicate
-            .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
+        goal.predicate =
+            goal.predicate.fold_with(&mut ReplaceAliasWithInfer::new(self, goal.param_env));
         self.inspect.add_goal(self.delegate, self.max_input_universe, source, goal);
         self.nested_goals.goals.push((source, goal));
     }
@@ -654,6 +652,7 @@ where
             term: I::Term,
             universe_of_term: ty::UniverseIndex,
             delegate: &'a D,
+            cache: HashSet<I::Ty>,
         }
 
         impl<D: SolverDelegate<Interner = I>, I: Interner> ContainsTermOrNotNameable<'_, D, I> {
@@ -671,6 +670,10 @@ where
         {
             type Result = ControlFlow<()>;
             fn visit_ty(&mut self, t: I::Ty) -> Self::Result {
+                if self.cache.contains(&t) {
+                    return ControlFlow::Continue(());
+                }
+
                 match t.kind() {
                     ty::Infer(ty::TyVar(vid)) => {
                         if let ty::TermKind::Ty(term) = self.term.kind() {
@@ -683,17 +686,18 @@ where
                             }
                         }
 
-                        self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())
+                        self.check_nameable(self.delegate.universe_of_ty(vid).unwrap())?;
                     }
-                    ty::Placeholder(p) => self.check_nameable(p.universe()),
+                    ty::Placeholder(p) => self.check_nameable(p.universe())?,
                     _ => {
                         if t.has_non_region_infer() || t.has_placeholders() {
-                            t.super_visit_with(self)
-                        } else {
-                            ControlFlow::Continue(())
+                            t.super_visit_with(self)?
                         }
                     }
                 }
+
+                assert!(self.cache.insert(t));
+                ControlFlow::Continue(())
             }
 
             fn visit_const(&mut self, c: I::Const) -> Self::Result {
@@ -728,6 +732,7 @@ where
             delegate: self.delegate,
             universe_of_term,
             term: goal.predicate.term,
+            cache: Default::default(),
         };
         goal.predicate.alias.visit_with(&mut visitor).is_continue()
             && goal.param_env.visit_with(&mut visitor).is_continue()
@@ -1015,6 +1020,17 @@ where
 {
     ecx: &'me mut EvalCtxt<'a, D>,
     param_env: I::ParamEnv,
+    cache: HashMap<I::Ty, I::Ty>,
+}
+
+impl<'me, 'a, D, I> ReplaceAliasWithInfer<'me, 'a, D, I>
+where
+    D: SolverDelegate<Interner = I>,
+    I: Interner,
+{
+    fn new(ecx: &'me mut EvalCtxt<'a, D>, param_env: I::ParamEnv) -> Self {
+        ReplaceAliasWithInfer { ecx, param_env, cache: Default::default() }
+    }
 }
 
 impl<D, I> TypeFolder<I> for ReplaceAliasWithInfer<'_, '_, D, I>
@@ -1041,7 +1057,16 @@ where
                 );
                 infer_ty
             }
-            _ => ty.super_fold_with(self),
+            _ if ty.has_aliases() => {
+                if let Some(&entry) = self.cache.get(&ty) {
+                    return entry;
+                }
+
+                let res = ty.super_fold_with(self);
+                assert!(self.cache.insert(ty, res).is_none());
+                res
+            }
+            _ => ty,
         }
     }
 
diff --git a/compiler/rustc_type_ir/src/data_structures/delayed_map.rs b/compiler/rustc_type_ir/src/data_structures/delayed_map.rs
new file mode 100644
index 00000000000..7e7406e3706
--- /dev/null
+++ b/compiler/rustc_type_ir/src/data_structures/delayed_map.rs
@@ -0,0 +1,92 @@
+use std::hash::Hash;
+
+use crate::data_structures::{HashMap, HashSet};
+
+const CACHE_CUTOFF: u32 = 32;
+
+/// A hashmap which only starts hashing after ignoring the first `CACHE_CUTOFF` insertions.
+///
+/// This is used in type folders because caching is not worth it in nearly all
+/// cases, as nearly all folded types are tiny. However, there are very rare,
+/// incredibly large types for which caching is necessary to avoid hangs.
+#[derive(Debug)]
+pub struct DelayedMap<K, V> {
+    cache: HashMap<K, V>,
+    count: u32,
+}
+
+impl<K, V> Default for DelayedMap<K, V> {
+    fn default() -> Self {
+        DelayedMap { cache: Default::default(), count: 0 }
+    }
+}
+
+impl<K: Hash + Eq, V> DelayedMap<K, V> {
+    #[inline(always)]
+    pub fn insert(&mut self, key: K, value: V) -> bool {
+        if self.count >= CACHE_CUTOFF {
+            self.cold_insert(key, value)
+        } else {
+            self.count += 1;
+            true
+        }
+    }
+
+    #[cold]
+    #[inline(never)]
+    fn cold_insert(&mut self, key: K, value: V) -> bool {
+        self.cache.insert(key, value).is_none()
+    }
+
+    #[inline(always)]
+    pub fn get(&self, key: &K) -> Option<&V> {
+        if self.cache.is_empty() { None } else { self.cold_get(key) }
+    }
+
+    #[cold]
+    #[inline(never)]
+    fn cold_get(&self, key: &K) -> Option<&V> {
+        self.cache.get(key)
+    }
+}
+
+#[derive(Debug)]
+pub struct DelayedSet<T> {
+    cache: HashSet<T>,
+    count: u32,
+}
+
+impl<T> Default for DelayedSet<T> {
+    fn default() -> Self {
+        DelayedSet { cache: Default::default(), count: 0 }
+    }
+}
+
+impl<T: Hash + Eq> DelayedSet<T> {
+    #[inline(always)]
+    pub fn insert(&mut self, value: T) -> bool {
+        if self.count >= CACHE_CUTOFF {
+            self.cold_insert(value)
+        } else {
+            self.count += 1;
+            true
+        }
+    }
+
+    #[cold]
+    #[inline(never)]
+    fn cold_insert(&mut self, value: T) -> bool {
+        self.cache.insert(value)
+    }
+
+    #[inline(always)]
+    pub fn contains(&self, value: &T) -> bool {
+        !self.cache.is_empty() && self.cold_contains(value)
+    }
+
+    #[cold]
+    #[inline(never)]
+    fn cold_contains(&self, value: &T) -> bool {
+        self.cache.contains(value)
+    }
+}
diff --git a/compiler/rustc_type_ir/src/data_structures.rs b/compiler/rustc_type_ir/src/data_structures/mod.rs
index 96036e53b0a..d9766d91845 100644
--- a/compiler/rustc_type_ir/src/data_structures.rs
+++ b/compiler/rustc_type_ir/src/data_structures/mod.rs
@@ -6,6 +6,8 @@ pub use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
 pub type IndexMap<K, V> = indexmap::IndexMap<K, V, BuildHasherDefault<FxHasher>>;
 pub type IndexSet<V> = indexmap::IndexSet<V, BuildHasherDefault<FxHasher>>;
 
+mod delayed_map;
+
 #[cfg(feature = "nightly")]
 mod impl_ {
     pub use rustc_data_structures::sso::{SsoHashMap, SsoHashSet};
@@ -24,4 +26,5 @@ mod impl_ {
     }
 }
 
+pub use delayed_map::{DelayedMap, DelayedSet};
 pub use impl_::*;