about summary refs log tree commit diff
diff options
context:
space:
mode:
authorAli MJ Al-Nasrawy <alimjalnasrawy@gmail.com>2023-12-04 20:42:41 +0000
committerAli MJ Al-Nasrawy <alimjalnasrawy@gmail.com>2023-12-13 14:57:52 +0000
commitf38d1e971dcba3a3e9739d0d5aaf5f14329118bd (patch)
tree152d72c1e0a68eb445c40bc7469fd58c20228714
parent6f40082313d8374bdf962aba943a712d5322fae6 (diff)
downloadrust-f38d1e971dcba3a3e9739d0d5aaf5f14329118bd.tar.gz
rust-f38d1e971dcba3a3e9739d0d5aaf5f14329118bd.zip
global param_env canonicalization cache
-rw-r--r--compiler/rustc_infer/src/infer/canonical/canonicalizer.rs94
-rw-r--r--compiler/rustc_infer/src/infer/combine.rs2
-rw-r--r--compiler/rustc_middle/src/infer/canonical.rs37
-rw-r--r--compiler/rustc_middle/src/query/mod.rs4
-rw-r--r--compiler/rustc_middle/src/ty/context.rs5
-rw-r--r--compiler/rustc_trait_selection/src/traits/misc.rs8
6 files changed, 128 insertions, 22 deletions
diff --git a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
index 473a3965885..d3ab3c0afe2 100644
--- a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
+++ b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
@@ -35,13 +35,13 @@ impl<'tcx> InferCtxt<'tcx> {
     /// [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html#canonicalizing-the-query
     pub fn canonicalize_query<V>(
         &self,
-        value: V,
+        value: ty::ParamEnvAnd<'tcx, V>,
         query_state: &mut OriginalQueryValues<'tcx>,
-    ) -> Canonical<'tcx, V>
+    ) -> Canonical<'tcx, ty::ParamEnvAnd<'tcx, V>>
     where
         V: TypeFoldable<TyCtxt<'tcx>>,
     {
-        Canonicalizer::canonicalize(value, self, self.tcx, &CanonicalizeAllFreeRegions, query_state)
+        self.canonicalize_query_with_mode(value, query_state, &CanonicalizeAllFreeRegions)
     }
 
     /// Like [Self::canonicalize_query], but preserves distinct universes. For
@@ -126,19 +126,52 @@ impl<'tcx> InferCtxt<'tcx> {
     /// handling of `'static` regions (e.g. trait evaluation).
     pub fn canonicalize_query_keep_static<V>(
         &self,
-        value: V,
+        value: ty::ParamEnvAnd<'tcx, V>,
         query_state: &mut OriginalQueryValues<'tcx>,
-    ) -> Canonical<'tcx, V>
+    ) -> Canonical<'tcx, ty::ParamEnvAnd<'tcx, V>>
     where
         V: TypeFoldable<TyCtxt<'tcx>>,
     {
-        Canonicalizer::canonicalize(
+        self.canonicalize_query_with_mode(
+            value,
+            query_state,
+            &CanonicalizeFreeRegionsOtherThanStatic,
+        )
+    }
+
+    fn canonicalize_query_with_mode<V>(
+        &self,
+        value: ty::ParamEnvAnd<'tcx, V>,
+        query_state: &mut OriginalQueryValues<'tcx>,
+        canonicalize_region_mode: &dyn CanonicalizeMode,
+    ) -> Canonical<'tcx, ty::ParamEnvAnd<'tcx, V>>
+    where
+        V: TypeFoldable<TyCtxt<'tcx>>,
+    {
+        let (param_env, value) = value.into_parts();
+        let base = self.tcx.canonical_param_env_cache.get_or_insert(
+            param_env,
+            query_state,
+            |query_state| {
+                Canonicalizer::canonicalize(
+                    param_env,
+                    self,
+                    self.tcx,
+                    &CanonicalizeFreeRegionsOtherThanStatic,
+                    query_state,
+                )
+            },
+        );
+
+        Canonicalizer::canonicalize_with_base(
+            base,
             value,
             self,
             self.tcx,
-            &CanonicalizeFreeRegionsOtherThanStatic,
+            canonicalize_region_mode,
             query_state,
         )
+        .unchecked_map(|(param_env, value)| param_env.and(value))
     }
 }
 
@@ -570,6 +603,33 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
     where
         V: TypeFoldable<TyCtxt<'tcx>>,
     {
+        let base = Canonical {
+            max_universe: ty::UniverseIndex::ROOT,
+            variables: List::empty(),
+            value: (),
+        };
+        Canonicalizer::canonicalize_with_base(
+            base,
+            value,
+            infcx,
+            tcx,
+            canonicalize_region_mode,
+            query_state,
+        )
+        .unchecked_map(|((), val)| val)
+    }
+
+    fn canonicalize_with_base<U, V>(
+        base: Canonical<'tcx, U>,
+        value: V,
+        infcx: &InferCtxt<'tcx>,
+        tcx: TyCtxt<'tcx>,
+        canonicalize_region_mode: &dyn CanonicalizeMode,
+        query_state: &mut OriginalQueryValues<'tcx>,
+    ) -> Canonical<'tcx, (U, V)>
+    where
+        V: TypeFoldable<TyCtxt<'tcx>>,
+    {
         let needs_canonical_flags = if canonicalize_region_mode.any() {
             TypeFlags::HAS_INFER | TypeFlags::HAS_PLACEHOLDER | TypeFlags::HAS_FREE_REGIONS
         } else {
@@ -578,12 +638,7 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
 
         // Fast path: nothing that needs to be canonicalized.
         if !value.has_type_flags(needs_canonical_flags) {
-            let canon_value = Canonical {
-                max_universe: ty::UniverseIndex::ROOT,
-                variables: List::empty(),
-                value,
-            };
-            return canon_value;
+            return base.unchecked_map(|b| (b, value));
         }
 
         let mut canonicalizer = Canonicalizer {
@@ -591,11 +646,20 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
             tcx,
             canonicalize_mode: canonicalize_region_mode,
             needs_canonical_flags,
-            variables: SmallVec::new(),
+            variables: SmallVec::from_slice(base.variables),
             query_state,
             indices: FxHashMap::default(),
             binder_index: ty::INNERMOST,
         };
+        if canonicalizer.query_state.var_values.spilled() {
+            canonicalizer.indices = canonicalizer
+                .query_state
+                .var_values
+                .iter()
+                .enumerate()
+                .map(|(i, &kind)| (kind, BoundVar::new(i)))
+                .collect();
+        }
         let out_value = value.fold_with(&mut canonicalizer);
 
         // Once we have canonicalized `out_value`, it should not
@@ -612,7 +676,7 @@ impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
             .max()
             .unwrap_or(ty::UniverseIndex::ROOT);
 
-        Canonical { max_universe, variables: canonical_variables, value: out_value }
+        Canonical { max_universe, variables: canonical_variables, value: (base.value, out_value) }
     }
 
     /// Creates a canonical variable replacing `kind` from the input,
diff --git a/compiler/rustc_infer/src/infer/combine.rs b/compiler/rustc_infer/src/infer/combine.rs
index bab21bc237a..6608fdab9d0 100644
--- a/compiler/rustc_infer/src/infer/combine.rs
+++ b/compiler/rustc_infer/src/infer/combine.rs
@@ -172,7 +172,7 @@ impl<'tcx> InferCtxt<'tcx> {
             // two const param's types are able to be equal has to go through a canonical query with the actual logic
             // in `rustc_trait_selection`.
             let canonical = self.canonicalize_query(
-                (relation.param_env(), a.ty(), b.ty()),
+                relation.param_env().and((a.ty(), b.ty())),
                 &mut OriginalQueryValues::default(),
             );
             self.tcx.check_tys_might_be_eq(canonical).map_err(|_| {
diff --git a/compiler/rustc_middle/src/infer/canonical.rs b/compiler/rustc_middle/src/infer/canonical.rs
index ef5a1caadb7..9208cd5febb 100644
--- a/compiler/rustc_middle/src/infer/canonical.rs
+++ b/compiler/rustc_middle/src/infer/canonical.rs
@@ -21,11 +21,14 @@
 //!
 //! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html
 
+use rustc_data_structures::fx::FxHashMap;
+use rustc_data_structures::sync::Lock;
 use rustc_macros::HashStable;
 use rustc_type_ir::Canonical as IrCanonical;
 use rustc_type_ir::CanonicalVarInfo as IrCanonicalVarInfo;
 pub use rustc_type_ir::{CanonicalTyVarKind, CanonicalVarKind};
 use smallvec::SmallVec;
+use std::collections::hash_map::Entry;
 use std::ops::Index;
 
 use crate::infer::MemberConstraint;
@@ -291,3 +294,37 @@ impl<'tcx> Index<BoundVar> for CanonicalVarValues<'tcx> {
         &self.var_values[value.as_usize()]
     }
 }
+
+#[derive(Default)]
+pub struct CanonicalParamEnvCache<'tcx> {
+    map: Lock<
+        FxHashMap<
+            ty::ParamEnv<'tcx>,
+            (Canonical<'tcx, ty::ParamEnv<'tcx>>, OriginalQueryValues<'tcx>),
+        >,
+    >,
+}
+
+impl<'tcx> CanonicalParamEnvCache<'tcx> {
+    pub fn get_or_insert(
+        &self,
+        key: ty::ParamEnv<'tcx>,
+        state: &mut OriginalQueryValues<'tcx>,
+        canonicalize_op: impl FnOnce(
+            &mut OriginalQueryValues<'tcx>,
+        ) -> Canonical<'tcx, ty::ParamEnv<'tcx>>,
+    ) -> Canonical<'tcx, ty::ParamEnv<'tcx>> {
+        match self.map.borrow().entry(key) {
+            Entry::Occupied(e) => {
+                let (canonical, state_cached) = e.get();
+                state.clone_from(state_cached);
+                canonical.clone()
+            }
+            Entry::Vacant(e) => {
+                let canonical = canonicalize_op(state);
+                e.insert((canonical.clone(), state.clone()));
+                canonical
+            }
+        }
+    }
+}
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index 03f3ceb8d17..a69bff6ed8c 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -2177,7 +2177,9 @@ rustc_queries! {
     /// Used in `super_combine_consts` to ICE if the type of the two consts are definitely not going to end up being
     /// equal to eachother. This might return `Ok` even if the types are not equal, but will never return `Err` if
     /// the types might be equal.
-    query check_tys_might_be_eq(arg: Canonical<'tcx, (ty::ParamEnv<'tcx>, Ty<'tcx>, Ty<'tcx>)>) -> Result<(), NoSolution> {
+    query check_tys_might_be_eq(
+        arg: Canonical<'tcx, ty::ParamEnvAnd<'tcx, (Ty<'tcx>, Ty<'tcx>)>>
+    ) -> Result<(), NoSolution> {
         desc { "check whether two const param are definitely not equal to eachother"}
     }
 
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index eb6fde83fcc..da2bdf4d93e 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -6,7 +6,7 @@ pub mod tls;
 
 use crate::arena::Arena;
 use crate::dep_graph::{DepGraph, DepKindStruct};
-use crate::infer::canonical::{CanonicalVarInfo, CanonicalVarInfos};
+use crate::infer::canonical::{CanonicalParamEnvCache, CanonicalVarInfo, CanonicalVarInfos};
 use crate::lint::struct_lint_level;
 use crate::metadata::ModChild;
 use crate::middle::codegen_fn_attrs::CodegenFnAttrs;
@@ -653,6 +653,8 @@ pub struct GlobalCtxt<'tcx> {
     pub new_solver_evaluation_cache: solve::EvaluationCache<'tcx>,
     pub new_solver_coherence_evaluation_cache: solve::EvaluationCache<'tcx>,
 
+    pub canonical_param_env_cache: CanonicalParamEnvCache<'tcx>,
+
     /// Data layout specification for the current target.
     pub data_layout: TargetDataLayout,
 
@@ -817,6 +819,7 @@ impl<'tcx> TyCtxt<'tcx> {
             evaluation_cache: Default::default(),
             new_solver_evaluation_cache: Default::default(),
             new_solver_coherence_evaluation_cache: Default::default(),
+            canonical_param_env_cache: Default::default(),
             data_layout,
             alloc_map: Lock::new(interpret::AllocMap::new()),
         }
diff --git a/compiler/rustc_trait_selection/src/traits/misc.rs b/compiler/rustc_trait_selection/src/traits/misc.rs
index 2f2411310a9..cf4fa233768 100644
--- a/compiler/rustc_trait_selection/src/traits/misc.rs
+++ b/compiler/rustc_trait_selection/src/traits/misc.rs
@@ -9,7 +9,7 @@ use rustc_infer::infer::canonical::Canonical;
 use rustc_infer::infer::{RegionResolutionError, TyCtxtInferExt};
 use rustc_infer::traits::query::NoSolution;
 use rustc_infer::{infer::outlives::env::OutlivesEnvironment, traits::FulfillmentError};
-use rustc_middle::ty::{self, AdtDef, GenericArg, List, ParamEnv, Ty, TyCtxt, TypeVisitableExt};
+use rustc_middle::ty::{self, AdtDef, GenericArg, List, Ty, TyCtxt, TypeVisitableExt};
 use rustc_span::DUMMY_SP;
 
 use super::outlives_bounds::InferCtxtExt;
@@ -209,10 +209,10 @@ pub fn all_fields_implement_trait<'tcx>(
 
 pub fn check_tys_might_be_eq<'tcx>(
     tcx: TyCtxt<'tcx>,
-    canonical: Canonical<'tcx, (ParamEnv<'tcx>, Ty<'tcx>, Ty<'tcx>)>,
+    canonical: Canonical<'tcx, ty::ParamEnvAnd<'tcx, (Ty<'tcx>, Ty<'tcx>)>>,
 ) -> Result<(), NoSolution> {
-    let (infcx, (param_env, ty_a, ty_b), _) =
-        tcx.infer_ctxt().build_with_canonical(DUMMY_SP, &canonical);
+    let (infcx, key, _) = tcx.infer_ctxt().build_with_canonical(DUMMY_SP, &canonical);
+    let (param_env, (ty_a, ty_b)) = key.into_parts();
     let ocx = ObligationCtxt::new(&infcx);
 
     let result = ocx.eq(&ObligationCause::dummy(), param_env, ty_a, ty_b);