about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2025-06-20 12:57:46 +0000
committerbors <bors@rust-lang.org>2025-06-20 12:57:46 +0000
commit9c4ff566babe632af5e30281a822d1ae9972873b (patch)
tree134aa9a785af9aa2ec6a2a8de171ac6894306b39
parent3b97f1308ff72016a4aaa93fbe6d09d4d6427815 (diff)
parent04a2eec304c687e8098f4d1d0e10a57b988924f2 (diff)
downloadrust-9c4ff566babe632af5e30281a822d1ae9972873b.tar.gz
rust-9c4ff566babe632af5e30281a822d1ae9972873b.zip
Auto merge of #142316 - compiler-errors:cache-param-env, r=lcnr
[perf] Cache the canonical *instantiation* of param-envs

r? lcnr
-rw-r--r--compiler/rustc_infer/src/infer/canonical/instantiate.rs189
-rw-r--r--compiler/rustc_middle/src/ty/context.rs8
2 files changed, 177 insertions, 20 deletions
diff --git a/compiler/rustc_infer/src/infer/canonical/instantiate.rs b/compiler/rustc_infer/src/infer/canonical/instantiate.rs
index 67f13192b52..2385c68ef6b 100644
--- a/compiler/rustc_infer/src/infer/canonical/instantiate.rs
+++ b/compiler/rustc_infer/src/infer/canonical/instantiate.rs
@@ -7,8 +7,11 @@
 //! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html
 
 use rustc_macros::extension;
-use rustc_middle::bug;
-use rustc_middle::ty::{self, FnMutDelegate, GenericArgKind, TyCtxt, TypeFoldable};
+use rustc_middle::ty::{
+    self, DelayedMap, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeSuperVisitable,
+    TypeVisitableExt, TypeVisitor,
+};
+use rustc_type_ir::TypeVisitable;
 
 use crate::infer::canonical::{Canonical, CanonicalVarValues};
 
@@ -58,23 +61,169 @@ where
     T: TypeFoldable<TyCtxt<'tcx>>,
 {
     if var_values.var_values.is_empty() {
-        value
-    } else {
-        let delegate = FnMutDelegate {
-            regions: &mut |br: ty::BoundRegion| match var_values[br.var].kind() {
-                GenericArgKind::Lifetime(l) => l,
-                r => bug!("{:?} is a region but value is {:?}", br, r),
-            },
-            types: &mut |bound_ty: ty::BoundTy| match var_values[bound_ty.var].kind() {
-                GenericArgKind::Type(ty) => ty,
-                r => bug!("{:?} is a type but value is {:?}", bound_ty, r),
-            },
-            consts: &mut |bound_ct: ty::BoundVar| match var_values[bound_ct].kind() {
-                GenericArgKind::Const(ct) => ct,
-                c => bug!("{:?} is a const but value is {:?}", bound_ct, c),
-            },
-        };
-
-        tcx.replace_escaping_bound_vars_uncached(value, delegate)
+        return value;
     }
+
+    value.fold_with(&mut CanonicalInstantiator {
+        tcx,
+        current_index: ty::INNERMOST,
+        var_values: var_values.var_values,
+        cache: Default::default(),
+    })
+}
+
+/// Replaces the bound vars in a canonical binder with var values.
+struct CanonicalInstantiator<'tcx> {
+    tcx: TyCtxt<'tcx>,
+
+    // The values that the bound vars are are being instantiated with.
+    var_values: ty::GenericArgsRef<'tcx>,
+
+    /// As with `BoundVarReplacer`, represents the index of a binder *just outside*
+    /// the ones we have visited.
+    current_index: ty::DebruijnIndex,
+
+    // Instantiation is a pure function of `DebruijnIndex` and `Ty`.
+    cache: DelayedMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>,
+}
+
+impl<'tcx> TypeFolder<TyCtxt<'tcx>> for CanonicalInstantiator<'tcx> {
+    fn cx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn fold_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
+        &mut self,
+        t: ty::Binder<'tcx, T>,
+    ) -> ty::Binder<'tcx, T> {
+        self.current_index.shift_in(1);
+        let t = t.super_fold_with(self);
+        self.current_index.shift_out(1);
+        t
+    }
+
+    fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
+        match *t.kind() {
+            ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => {
+                self.var_values[bound_ty.var.as_usize()].expect_ty()
+            }
+            _ => {
+                if !t.has_vars_bound_at_or_above(self.current_index) {
+                    t
+                } else if let Some(&t) = self.cache.get(&(self.current_index, t)) {
+                    t
+                } else {
+                    let res = t.super_fold_with(self);
+                    assert!(self.cache.insert((self.current_index, t), res));
+                    res
+                }
+            }
+        }
+    }
+
+    fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
+        match r.kind() {
+            ty::ReBound(debruijn, br) if debruijn == self.current_index => {
+                self.var_values[br.var.as_usize()].expect_region()
+            }
+            _ => r,
+        }
+    }
+
+    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
+        match ct.kind() {
+            ty::ConstKind::Bound(debruijn, bound_const) if debruijn == self.current_index => {
+                self.var_values[bound_const.as_usize()].expect_const()
+            }
+            _ => ct.super_fold_with(self),
+        }
+    }
+
+    fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
+        if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
+    }
+
+    fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
+        if !c.has_vars_bound_at_or_above(self.current_index) {
+            return c;
+        }
+
+        // Since instantiation is a function of `DebruijnIndex`, we don't want
+        // to have to cache more copies of clauses when we're inside of binders.
+        // Since we currently expect to only have clauses in the outermost
+        // debruijn index, we just fold if we're inside of a binder.
+        if self.current_index > ty::INNERMOST {
+            return c.super_fold_with(self);
+        }
+
+        // Our cache key is `(clauses, var_values)`, but we also don't care about
+        // var values that aren't named in the clauses, since they can change without
+        // affecting the output. Since `ParamEnv`s are cached first, we compute the
+        // last var value that is mentioned in the clauses, and cut off the list so
+        // that we have more hits in the cache.
+
+        // We also cache the computation of "highest var named by clauses" since that
+        // is both expensive (depending on the size of the clauses) and a pure function.
+        let index = *self
+            .tcx
+            .highest_var_in_clauses_cache
+            .lock()
+            .entry(c)
+            .or_insert_with(|| highest_var_in_clauses(c));
+        let c_args = &self.var_values[..=index];
+
+        if let Some(c) = self.tcx.clauses_cache.lock().get(&(c, c_args)) {
+            c
+        } else {
+            let folded = c.super_fold_with(self);
+            self.tcx.clauses_cache.lock().insert((c, c_args), folded);
+            folded
+        }
+    }
+}
+
+fn highest_var_in_clauses<'tcx>(c: ty::Clauses<'tcx>) -> usize {
+    struct HighestVarInClauses {
+        max_var: usize,
+        current_index: ty::DebruijnIndex,
+    }
+    impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for HighestVarInClauses {
+        fn visit_binder<T: TypeVisitable<TyCtxt<'tcx>>>(
+            &mut self,
+            t: &ty::Binder<'tcx, T>,
+        ) -> Self::Result {
+            self.current_index.shift_in(1);
+            let t = t.super_visit_with(self);
+            self.current_index.shift_out(1);
+            t
+        }
+        fn visit_ty(&mut self, t: Ty<'tcx>) {
+            if let ty::Bound(debruijn, bound_ty) = *t.kind()
+                && debruijn == self.current_index
+            {
+                self.max_var = self.max_var.max(bound_ty.var.as_usize());
+            } else if t.has_vars_bound_at_or_above(self.current_index) {
+                t.super_visit_with(self);
+            }
+        }
+        fn visit_region(&mut self, r: ty::Region<'tcx>) {
+            if let ty::ReBound(debruijn, bound_region) = r.kind()
+                && debruijn == self.current_index
+            {
+                self.max_var = self.max_var.max(bound_region.var.as_usize());
+            }
+        }
+        fn visit_const(&mut self, ct: ty::Const<'tcx>) {
+            if let ty::ConstKind::Bound(debruijn, bound_const) = ct.kind()
+                && debruijn == self.current_index
+            {
+                self.max_var = self.max_var.max(bound_const.as_usize());
+            } else if ct.has_vars_bound_at_or_above(self.current_index) {
+                ct.super_visit_with(self);
+            }
+        }
+    }
+    let mut visitor = HighestVarInClauses { max_var: 0, current_index: ty::INNERMOST };
+    c.visit_with(&mut visitor);
+    visitor.max_var
 }
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index 82825b9b38e..4f8cfd86597 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -1479,6 +1479,12 @@ pub struct GlobalCtxt<'tcx> {
 
     pub canonical_param_env_cache: CanonicalParamEnvCache<'tcx>,
 
+    /// Caches the index of the highest bound var in clauses in a canonical binder.
+    pub highest_var_in_clauses_cache: Lock<FxHashMap<ty::Clauses<'tcx>, usize>>,
+    /// Caches the instantiation of a canonical binder given a set of args.
+    pub clauses_cache:
+        Lock<FxHashMap<(ty::Clauses<'tcx>, &'tcx [ty::GenericArg<'tcx>]), ty::Clauses<'tcx>>>,
+
     /// Data layout specification for the current target.
     pub data_layout: TargetDataLayout,
 
@@ -1727,6 +1733,8 @@ impl<'tcx> TyCtxt<'tcx> {
             new_solver_evaluation_cache: Default::default(),
             new_solver_canonical_param_env_cache: Default::default(),
             canonical_param_env_cache: Default::default(),
+            highest_var_in_clauses_cache: Default::default(),
+            clauses_cache: Default::default(),
             data_layout,
             alloc_map: interpret::AllocMap::new(),
             current_gcx,