about summary refs log tree commit diff
diff options
context:
space:
mode:
authorlcnr <rust@lcnr.de>2024-05-13 15:50:59 +0000
committerlcnr <rust@lcnr.de>2024-05-20 20:40:02 +0000
commitee0f20bb97ee834d79f69506a959ba9a78f16683 (patch)
treeac95f34f136d4e04b924270ce54a2a46dc137d80
parent82df0c3540dd6325ee4401ffd9f4ad8f13120a6a (diff)
downloadrust-ee0f20bb97ee834d79f69506a959ba9a78f16683.tar.gz
rust-ee0f20bb97ee834d79f69506a959ba9a78f16683.zip
move global cache lookup into fn
-rw-r--r--compiler/rustc_middle/src/traits/solve/cache.rs34
-rw-r--r--compiler/rustc_trait_selection/src/solve/search_graph.rs86
2 files changed, 62 insertions, 58 deletions
diff --git a/compiler/rustc_middle/src/traits/solve/cache.rs b/compiler/rustc_middle/src/traits/solve/cache.rs
index 03ce7cf98cf..2ff4ade21d0 100644
--- a/compiler/rustc_middle/src/traits/solve/cache.rs
+++ b/compiler/rustc_middle/src/traits/solve/cache.rs
@@ -14,11 +14,11 @@ pub struct EvaluationCache<'tcx> {
     map: Lock<FxHashMap<CanonicalInput<'tcx>, CacheEntry<'tcx>>>,
 }
 
-#[derive(PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq)]
 pub struct CacheData<'tcx> {
     pub result: QueryResult<'tcx>,
     pub proof_tree: Option<&'tcx [inspect::GoalEvaluationStep<TyCtxt<'tcx>>]>,
-    pub reached_depth: usize,
+    pub additional_depth: usize,
     pub encountered_overflow: bool,
 }
 
@@ -29,7 +29,7 @@ impl<'tcx> EvaluationCache<'tcx> {
         tcx: TyCtxt<'tcx>,
         key: CanonicalInput<'tcx>,
         proof_tree: Option<&'tcx [inspect::GoalEvaluationStep<TyCtxt<'tcx>>]>,
-        reached_depth: usize,
+        additional_depth: usize,
         encountered_overflow: bool,
         cycle_participants: FxHashSet<CanonicalInput<'tcx>>,
         dep_node: DepNodeIndex,
@@ -40,17 +40,17 @@ impl<'tcx> EvaluationCache<'tcx> {
         let data = WithDepNode::new(dep_node, QueryData { result, proof_tree });
         entry.cycle_participants.extend(cycle_participants);
         if encountered_overflow {
-            entry.with_overflow.insert(reached_depth, data);
+            entry.with_overflow.insert(additional_depth, data);
         } else {
-            entry.success = Some(Success { data, reached_depth });
+            entry.success = Some(Success { data, additional_depth });
         }
 
         if cfg!(debug_assertions) {
             drop(map);
-            if Some(CacheData { result, proof_tree, reached_depth, encountered_overflow })
-                != self.get(tcx, key, |_| false, Limit(reached_depth))
-            {
-                bug!("unable to retrieve inserted element from cache: {key:?}");
+            let expected = CacheData { result, proof_tree, additional_depth, encountered_overflow };
+            let actual = self.get(tcx, key, [], Limit(additional_depth));
+            if !actual.as_ref().is_some_and(|actual| expected == *actual) {
+                bug!("failed to lookup inserted element for {key:?}: {expected:?} != {actual:?}");
             }
         }
     }
@@ -63,23 +63,25 @@ impl<'tcx> EvaluationCache<'tcx> {
         &self,
         tcx: TyCtxt<'tcx>,
         key: CanonicalInput<'tcx>,
-        cycle_participant_in_stack: impl FnOnce(&FxHashSet<CanonicalInput<'tcx>>) -> bool,
+        stack_entries: impl IntoIterator<Item = CanonicalInput<'tcx>>,
         available_depth: Limit,
     ) -> Option<CacheData<'tcx>> {
         let map = self.map.borrow();
         let entry = map.get(&key)?;
 
-        if cycle_participant_in_stack(&entry.cycle_participants) {
-            return None;
+        for stack_entry in stack_entries {
+            if entry.cycle_participants.contains(&stack_entry) {
+                return None;
+            }
         }
 
         if let Some(ref success) = entry.success {
-            if available_depth.value_within_limit(success.reached_depth) {
+            if available_depth.value_within_limit(success.additional_depth) {
                 let QueryData { result, proof_tree } = success.data.get(tcx);
                 return Some(CacheData {
                     result,
                     proof_tree,
-                    reached_depth: success.reached_depth,
+                    additional_depth: success.additional_depth,
                     encountered_overflow: false,
                 });
             }
@@ -90,7 +92,7 @@ impl<'tcx> EvaluationCache<'tcx> {
             CacheData {
                 result,
                 proof_tree,
-                reached_depth: available_depth.0,
+                additional_depth: available_depth.0,
                 encountered_overflow: true,
             }
         })
@@ -99,7 +101,7 @@ impl<'tcx> EvaluationCache<'tcx> {
 
 struct Success<'tcx> {
     data: WithDepNode<QueryData<'tcx>>,
-    reached_depth: usize,
+    additional_depth: usize,
 }
 
 #[derive(Clone, Copy)]
diff --git a/compiler/rustc_trait_selection/src/solve/search_graph.rs b/compiler/rustc_trait_selection/src/solve/search_graph.rs
index 87e4fd9ae73..6cc674dcfed 100644
--- a/compiler/rustc_trait_selection/src/solve/search_graph.rs
+++ b/compiler/rustc_trait_selection/src/solve/search_graph.rs
@@ -134,16 +134,6 @@ impl<I: Interner> SearchGraph<I> {
         self.mode
     }
 
-    /// Update the stack and reached depths on cache hits.
-    #[instrument(level = "trace", skip(self))]
-    fn on_cache_hit(&mut self, additional_depth: usize, encountered_overflow: bool) {
-        let reached_depth = self.stack.next_index().plus(additional_depth);
-        if let Some(last) = self.stack.raw.last_mut() {
-            last.reached_depth = last.reached_depth.max(reached_depth);
-            last.encountered_overflow |= encountered_overflow;
-        }
-    }
-
     /// Pops the highest goal from the stack, lazily updating the
     /// the next goal in the stack.
     ///
@@ -276,37 +266,7 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
             return Self::response_no_constraints(tcx, input, Certainty::overflow(true));
         };
 
-        // Try to fetch the goal from the global cache.
-        'global: {
-            let Some(CacheData { result, proof_tree, reached_depth, encountered_overflow }) =
-                self.global_cache(tcx).get(
-                    tcx,
-                    input,
-                    |cycle_participants| {
-                        self.stack.iter().any(|entry| cycle_participants.contains(&entry.input))
-                    },
-                    available_depth,
-                )
-            else {
-                break 'global;
-            };
-
-            // If we're building a proof tree and the current cache entry does not
-            // contain a proof tree, we do not use the entry but instead recompute
-            // the goal. We simply overwrite the existing entry once we're done,
-            // caching the proof tree.
-            if !inspect.is_noop() {
-                if let Some(revisions) = proof_tree {
-                    inspect.goal_evaluation_kind(
-                        inspect::WipCanonicalGoalEvaluationKind::Interned { revisions },
-                    );
-                } else {
-                    break 'global;
-                }
-            }
-
-            self.on_cache_hit(reached_depth, encountered_overflow);
-            debug!("global cache hit");
+        if let Some(result) = self.lookup_global_cache(tcx, input, available_depth, inspect) {
             return result;
         }
 
@@ -388,7 +348,10 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
 
         // This is for global caching, so we properly track query dependencies.
         // Everything that affects the `result` should be performed within this
-        // `with_anon_task` closure.
+        // `with_anon_task` closure. If computing this goal depends on something
+        // not tracked by the cache key and from outside of this anon task, it
+        // must not be added to the global cache. Notably, this is the case for
+        // trait solver cycles participants.
         let ((final_entry, result), dep_node) =
             tcx.dep_graph.with_anon_task(tcx, dep_kinds::TraitSelect, || {
                 for _ in 0..FIXPOINT_STEP_LIMIT {
@@ -446,6 +409,45 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
 
         result
     }
+
+    /// Try to fetch a previously computed result from the global cache,
+    /// making sure to only do so if it would match the result of reevaluating
+    /// this goal.
+    fn lookup_global_cache(
+        &mut self,
+        tcx: TyCtxt<'tcx>,
+        input: CanonicalInput<'tcx>,
+        available_depth: Limit,
+        inspect: &mut ProofTreeBuilder<TyCtxt<'tcx>>,
+    ) -> Option<QueryResult<'tcx>> {
+        let CacheData { result, proof_tree, additional_depth, encountered_overflow } = self
+            .global_cache(tcx)
+            .get(tcx, input, self.stack.iter().map(|e| e.input), available_depth)?;
+
+        // If we're building a proof tree and the current cache entry does not
+        // contain a proof tree, we do not use the entry but instead recompute
+        // the goal. We simply overwrite the existing entry once we're done,
+        // caching the proof tree.
+        if !inspect.is_noop() {
+            if let Some(revisions) = proof_tree {
+                let kind = inspect::WipCanonicalGoalEvaluationKind::Interned { revisions };
+                inspect.goal_evaluation_kind(kind);
+            } else {
+                return None;
+            }
+        }
+
+        // Update the reached depth of the current goal to make sure
+        // its state is the same regardless of whether we've used the
+        // global cache or not.
+        let reached_depth = self.stack.next_index().plus(additional_depth);
+        if let Some(last) = self.stack.raw.last_mut() {
+            last.reached_depth = last.reached_depth.max(reached_depth);
+            last.encountered_overflow |= encountered_overflow;
+        }
+
+        Some(result)
+    }
 }
 
 enum StepResult<'tcx> {