about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_type_ir/src/search_graph/global_cache.rs17
-rw-r--r--compiler/rustc_type_ir/src/search_graph/mod.rs117
-rw-r--r--compiler/rustc_type_ir/src/search_graph/stack.rs8
3 files changed, 97 insertions, 45 deletions
diff --git a/compiler/rustc_type_ir/src/search_graph/global_cache.rs b/compiler/rustc_type_ir/src/search_graph/global_cache.rs
index a2442660259..1b99cc820f1 100644
--- a/compiler/rustc_type_ir/src/search_graph/global_cache.rs
+++ b/compiler/rustc_type_ir/src/search_graph/global_cache.rs
@@ -2,6 +2,7 @@ use derive_where::derive_where;
 
 use super::{AvailableDepth, Cx, NestedGoals};
 use crate::data_structures::HashMap;
+use crate::search_graph::EvaluationResult;
 
 struct Success<X: Cx> {
     required_depth: usize,
@@ -43,28 +44,26 @@ impl<X: Cx> GlobalCache<X> {
         &mut self,
         cx: X,
         input: X::Input,
-
-        origin_result: X::Result,
+        evaluation_result: EvaluationResult<X>,
         dep_node: X::DepNodeIndex,
-
-        required_depth: usize,
-        encountered_overflow: bool,
-        nested_goals: NestedGoals<X>,
     ) {
-        let result = cx.mk_tracked(origin_result, dep_node);
+        let EvaluationResult { encountered_overflow, required_depth, heads, nested_goals, result } =
+            evaluation_result;
+        debug_assert!(heads.is_empty());
+        let result = cx.mk_tracked(result, dep_node);
         let entry = self.map.entry(input).or_default();
         if encountered_overflow {
             let with_overflow = WithOverflow { nested_goals, result };
             let prev = entry.with_overflow.insert(required_depth, with_overflow);
             if let Some(prev) = &prev {
                 assert!(cx.evaluation_is_concurrent());
-                assert_eq!(cx.get_tracked(&prev.result), origin_result);
+                assert_eq!(cx.get_tracked(&prev.result), evaluation_result.result);
             }
         } else {
             let prev = entry.success.replace(Success { required_depth, nested_goals, result });
             if let Some(prev) = &prev {
                 assert!(cx.evaluation_is_concurrent());
-                assert_eq!(cx.get_tracked(&prev.result), origin_result);
+                assert_eq!(cx.get_tracked(&prev.result), evaluation_result.result);
             }
         }
     }
diff --git a/compiler/rustc_type_ir/src/search_graph/mod.rs b/compiler/rustc_type_ir/src/search_graph/mod.rs
index 953dc2d1bc8..1a7ccb5b916 100644
--- a/compiler/rustc_type_ir/src/search_graph/mod.rs
+++ b/compiler/rustc_type_ir/src/search_graph/mod.rs
@@ -448,6 +448,43 @@ struct ProvisionalCacheEntry<X: Cx> {
     result: X::Result,
 }
 
+/// The final result of evaluating a goal.
+///
+/// We reset `encountered_overflow` when reevaluating a goal,
+/// but need to track whether we've hit the recursion limit at
+/// all for correctness.
+///
+/// We've previously simply returned the final `StackEntry` but this
+/// made it easy to accidentally drop information from the previous
+/// evaluation.
+#[derive_where(Debug; X: Cx)]
+struct EvaluationResult<X: Cx> {
+    encountered_overflow: bool,
+    required_depth: usize,
+    heads: CycleHeads,
+    nested_goals: NestedGoals<X>,
+    result: X::Result,
+}
+
+impl<X: Cx> EvaluationResult<X> {
+    fn finalize(
+        final_entry: StackEntry<X>,
+        encountered_overflow: bool,
+        result: X::Result,
+    ) -> EvaluationResult<X> {
+        EvaluationResult {
+            encountered_overflow,
+            // Unlike `encountered_overflow`, we share `heads`, `required_depth`,
+            // and `nested_goals` between evaluations.
+            required_depth: final_entry.required_depth,
+            heads: final_entry.heads,
+            nested_goals: final_entry.nested_goals,
+            // We only care about the final result.
+            result,
+        }
+    }
+}
+
 pub struct SearchGraph<D: Delegate<Cx = X>, X: Cx = <D as Delegate>::Cx> {
     root_depth: AvailableDepth,
     /// The stack of goals currently being computed.
@@ -614,12 +651,12 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
             input,
             step_kind_from_parent,
             available_depth,
+            provisional_result: None,
             required_depth: 0,
             heads: Default::default(),
             encountered_overflow: false,
             has_been_used: None,
             nested_goals: Default::default(),
-            provisional_result: None,
         });
 
         // This is for global caching, so we properly track query dependencies.
@@ -628,7 +665,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
         // not tracked by the cache key and from outside of this anon task, it
         // must not be added to the global cache. Notably, this is the case for
         // trait solver cycles participants.
-        let ((final_entry, result), dep_node) = cx.with_cached_task(|| {
+        let (evaluation_result, dep_node) = cx.with_cached_task(|| {
             self.evaluate_goal_in_task(cx, input, inspect, &mut evaluate_goal)
         });
 
@@ -636,27 +673,34 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
         // lazily update its parent goal.
         Self::update_parent_goal(
             &mut self.stack,
-            final_entry.step_kind_from_parent,
-            final_entry.required_depth,
-            &final_entry.heads,
-            final_entry.encountered_overflow,
-            UpdateParentGoalCtxt::Ordinary(&final_entry.nested_goals),
+            step_kind_from_parent,
+            evaluation_result.required_depth,
+            &evaluation_result.heads,
+            evaluation_result.encountered_overflow,
+            UpdateParentGoalCtxt::Ordinary(&evaluation_result.nested_goals),
         );
+        let result = evaluation_result.result;
 
         // We're now done with this goal. We only add the root of cycles to the global cache.
         // In case this goal is involved in a larger cycle add it to the provisional cache.
-        if final_entry.heads.is_empty() {
+        if evaluation_result.heads.is_empty() {
             if let Some((_scope, expected)) = validate_cache {
                 // Do not try to move a goal into the cache again if we're testing
                 // the global cache.
-                assert_eq!(result, expected, "input={input:?}");
+                assert_eq!(evaluation_result.result, expected, "input={input:?}");
             } else if D::inspect_is_noop(inspect) {
-                self.insert_global_cache(cx, final_entry, result, dep_node)
+                self.insert_global_cache(cx, input, evaluation_result, dep_node)
             }
         } else if D::ENABLE_PROVISIONAL_CACHE {
             debug_assert!(validate_cache.is_none(), "unexpected non-root: {input:?}");
             let entry = self.provisional_cache.entry(input).or_default();
-            let StackEntry { heads, encountered_overflow, .. } = final_entry;
+            let EvaluationResult {
+                encountered_overflow,
+                required_depth: _,
+                heads,
+                nested_goals: _,
+                result,
+            } = evaluation_result;
             let path_from_head = Self::cycle_path_kind(
                 &self.stack,
                 step_kind_from_parent,
@@ -1022,18 +1066,24 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
         input: X::Input,
         inspect: &mut D::ProofTreeBuilder,
         mut evaluate_goal: impl FnMut(&mut Self, &mut D::ProofTreeBuilder) -> X::Result,
-    ) -> (StackEntry<X>, X::Result) {
+    ) -> EvaluationResult<X> {
+        // We reset `encountered_overflow` each time we rerun this goal
+        // but need to make sure we currently propagate it to the global
+        // cache even if only some of the evaluations actually reach the
+        // recursion limit.
+        let mut encountered_overflow = false;
         let mut i = 0;
         loop {
             let result = evaluate_goal(self, inspect);
             let stack_entry = self.stack.pop();
+            encountered_overflow |= stack_entry.encountered_overflow;
             debug_assert_eq!(stack_entry.input, input);
 
             // If the current goal is not the root of a cycle, we are done.
             //
             // There are no provisional cache entries which depend on this goal.
             let Some(usage_kind) = stack_entry.has_been_used else {
-                return (stack_entry, result);
+                return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
             };
 
             // If it is a cycle head, we have to keep trying to prove it until
@@ -1049,7 +1099,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
             // final result is equal to the initial response for that case.
             if self.reached_fixpoint(cx, &stack_entry, usage_kind, result) {
                 self.rebase_provisional_cache_entries(&stack_entry, |_, result| result);
-                return (stack_entry, result);
+                return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
             }
 
             // If computing this goal results in ambiguity with no constraints,
@@ -1068,7 +1118,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
                 self.rebase_provisional_cache_entries(&stack_entry, |input, _| {
                     D::propagate_ambiguity(cx, input, result)
                 });
-                return (stack_entry, result);
+                return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
             };
 
             // If we've reached the fixpoint step limit, we bail with overflow and taint all
@@ -1080,7 +1130,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
                 self.rebase_provisional_cache_entries(&stack_entry, |input, _| {
                     D::on_fixpoint_overflow(cx, input)
                 });
-                return (stack_entry, result);
+                return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
             }
 
             // Clear all provisional cache entries which depend on a previous provisional
@@ -1089,9 +1139,22 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
 
             debug!(?result, "fixpoint changed provisional results");
             self.stack.push(StackEntry {
-                has_been_used: None,
+                input,
+                step_kind_from_parent: stack_entry.step_kind_from_parent,
+                available_depth: stack_entry.available_depth,
                 provisional_result: Some(result),
-                ..stack_entry
+                // We can keep these goals from previous iterations as they are only
+                // ever read after finalizing this evaluation.
+                required_depth: stack_entry.required_depth,
+                heads: stack_entry.heads,
+                nested_goals: stack_entry.nested_goals,
+                // We reset these two fields when rerunning this goal. We could
+                // keep `encountered_overflow` as it's only used as a performance
+                // optimization. However, given that the proof tree will likely look
+                // similar to the previous iterations when reevaluating, it's better
+                // for caching if the reevaluation also starts out with `false`.
+                encountered_overflow: false,
+                has_been_used: None,
             });
         }
     }
@@ -1107,21 +1170,11 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
     fn insert_global_cache(
         &mut self,
         cx: X,
-        final_entry: StackEntry<X>,
-        result: X::Result,
+        input: X::Input,
+        evaluation_result: EvaluationResult<X>,
         dep_node: X::DepNodeIndex,
     ) {
-        debug!(?final_entry, ?result, "insert global cache");
-        cx.with_global_cache(|cache| {
-            cache.insert(
-                cx,
-                final_entry.input,
-                result,
-                dep_node,
-                final_entry.required_depth,
-                final_entry.encountered_overflow,
-                final_entry.nested_goals,
-            )
-        })
+        debug!(?evaluation_result, "insert global cache");
+        cx.with_global_cache(|cache| cache.insert(cx, input, evaluation_result, dep_node))
     }
 }
diff --git a/compiler/rustc_type_ir/src/search_graph/stack.rs b/compiler/rustc_type_ir/src/search_graph/stack.rs
index 8bb247bf055..e0fd934df69 100644
--- a/compiler/rustc_type_ir/src/search_graph/stack.rs
+++ b/compiler/rustc_type_ir/src/search_graph/stack.rs
@@ -26,6 +26,10 @@ pub(super) struct StackEntry<X: Cx> {
     /// The available depth of a given goal, immutable.
     pub available_depth: AvailableDepth,
 
+    /// Starts out as `None` and gets set when rerunning this
+    /// goal in case we encounter a cycle.
+    pub provisional_result: Option<X::Result>,
+
     /// The maximum depth required while evaluating this goal.
     pub required_depth: usize,
 
@@ -42,10 +46,6 @@ pub(super) struct StackEntry<X: Cx> {
 
     /// The nested goals of this goal, see the doc comment of the type.
     pub nested_goals: NestedGoals<X>,
-
-    /// Starts out as `None` and gets set when rerunning this
-    /// goal in case we encounter a cycle.
-    pub provisional_result: Option<X::Result>,
 }
 
 #[derive_where(Default; X: Cx)]