about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection/src/solve/search_graph.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_trait_selection/src/solve/search_graph.rs')
-rw-r--r--compiler/rustc_trait_selection/src/solve/search_graph.rs178
1 files changed, 128 insertions, 50 deletions
diff --git a/compiler/rustc_trait_selection/src/solve/search_graph.rs b/compiler/rustc_trait_selection/src/solve/search_graph.rs
index 6cc674dcfed..bcd210f789b 100644
--- a/compiler/rustc_trait_selection/src/solve/search_graph.rs
+++ b/compiler/rustc_trait_selection/src/solve/search_graph.rs
@@ -47,20 +47,39 @@ struct StackEntry<I: Interner> {
     /// Whether this entry is a non-root cycle participant.
     ///
     /// We must not move the result of non-root cycle participants to the
-    /// global cache. See [SearchGraph::cycle_participants] for more details.
-    /// We store the highest stack depth of a head of a cycle this goal is involved
-    /// in. This necessary to soundly cache its provisional result.
+    /// global cache. We store the highest stack depth of a head of a cycle
+    /// this goal is involved in. This necessary to soundly cache its
+    /// provisional result.
     non_root_cycle_participant: Option<StackDepth>,
 
     encountered_overflow: bool,
 
     has_been_used: HasBeenUsed,
+
+    /// We put only the root goal of a coinductive cycle into the global cache.
+    ///
+    /// If we were to use that result when later trying to prove another cycle
+    /// participant, we can end up with unstable query results.
+    ///
+    /// See tests/ui/next-solver/coinduction/incompleteness-unstable-result.rs for
+    /// an example of where this is needed.
+    ///
+    /// There can  be multiple roots on the same stack, so we need to track
+    /// cycle participants per root:
+    /// ```plain
+    /// A :- B
+    /// B :- A, C
+    /// C :- D
+    /// D :- C
+    /// ```
+    cycle_participants: FxHashSet<CanonicalInput<I>>,
     /// Starts out as `None` and gets set when rerunning this
     /// goal in case we encounter a cycle.
     provisional_result: Option<QueryResult<I>>,
 }
 
 /// The provisional result for a goal which is not on the stack.
+#[derive(Debug)]
 struct DetachedEntry<I: Interner> {
     /// The head of the smallest non-trivial cycle involving this entry.
     ///
@@ -110,24 +129,11 @@ pub(super) struct SearchGraph<I: Interner> {
     /// An element is *deeper* in the stack if its index is *lower*.
     stack: IndexVec<StackDepth, StackEntry<I>>,
     provisional_cache: FxHashMap<CanonicalInput<I>, ProvisionalCacheEntry<I>>,
-    /// We put only the root goal of a coinductive cycle into the global cache.
-    ///
-    /// If we were to use that result when later trying to prove another cycle
-    /// participant, we can end up with unstable query results.
-    ///
-    /// See tests/ui/next-solver/coinduction/incompleteness-unstable-result.rs for
-    /// an example of where this is needed.
-    cycle_participants: FxHashSet<CanonicalInput<I>>,
 }
 
 impl<I: Interner> SearchGraph<I> {
     pub(super) fn new(mode: SolverMode) -> SearchGraph<I> {
-        Self {
-            mode,
-            stack: Default::default(),
-            provisional_cache: Default::default(),
-            cycle_participants: Default::default(),
-        }
+        Self { mode, stack: Default::default(), provisional_cache: Default::default() }
     }
 
     pub(super) fn solver_mode(&self) -> SolverMode {
@@ -149,13 +155,7 @@ impl<I: Interner> SearchGraph<I> {
     }
 
     pub(super) fn is_empty(&self) -> bool {
-        if self.stack.is_empty() {
-            debug_assert!(self.provisional_cache.is_empty());
-            debug_assert!(self.cycle_participants.is_empty());
-            true
-        } else {
-            false
-        }
+        self.stack.is_empty()
     }
 
     /// Returns the remaining depth allowed for nested goals.
@@ -205,15 +205,26 @@ impl<I: Interner> SearchGraph<I> {
     // their result does not get moved to the global cache.
     fn tag_cycle_participants(
         stack: &mut IndexVec<StackDepth, StackEntry<I>>,
-        cycle_participants: &mut FxHashSet<CanonicalInput<I>>,
         usage_kind: HasBeenUsed,
         head: StackDepth,
     ) {
         stack[head].has_been_used |= usage_kind;
         debug_assert!(!stack[head].has_been_used.is_empty());
-        for entry in &mut stack.raw[head.index() + 1..] {
+
+        // The current root of these cycles. Note that this may not be the final
+        // root in case a later goal depends on a goal higher up the stack.
+        let mut current_root = head;
+        while let Some(parent) = stack[current_root].non_root_cycle_participant {
+            current_root = parent;
+            debug_assert!(!stack[current_root].has_been_used.is_empty());
+        }
+
+        let (stack, cycle_participants) = stack.raw.split_at_mut(head.index() + 1);
+        let current_cycle_root = &mut stack[current_root.as_usize()];
+        for entry in cycle_participants {
             entry.non_root_cycle_participant = entry.non_root_cycle_participant.max(Some(head));
-            cycle_participants.insert(entry.input);
+            current_cycle_root.cycle_participants.insert(entry.input);
+            current_cycle_root.cycle_participants.extend(mem::take(&mut entry.cycle_participants));
         }
     }
 
@@ -256,6 +267,7 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
             &mut ProofTreeBuilder<TyCtxt<'tcx>>,
         ) -> QueryResult<TyCtxt<'tcx>>,
     ) -> QueryResult<TyCtxt<'tcx>> {
+        self.check_invariants();
         // Check for overflow.
         let Some(available_depth) = Self::allowed_depth_for_nested(tcx, &self.stack) else {
             if let Some(last) = self.stack.raw.last_mut() {
@@ -292,12 +304,7 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
             // already set correctly while computing the cache entry.
             inspect
                 .goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::ProvisionalCacheHit);
-            Self::tag_cycle_participants(
-                &mut self.stack,
-                &mut self.cycle_participants,
-                HasBeenUsed::empty(),
-                entry.head,
-            );
+            Self::tag_cycle_participants(&mut self.stack, HasBeenUsed::empty(), entry.head);
             return entry.result;
         } else if let Some(stack_depth) = cache_entry.stack_depth {
             debug!("encountered cycle with depth {stack_depth:?}");
@@ -314,12 +321,7 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
             } else {
                 HasBeenUsed::INDUCTIVE_CYCLE
             };
-            Self::tag_cycle_participants(
-                &mut self.stack,
-                &mut self.cycle_participants,
-                usage_kind,
-                stack_depth,
-            );
+            Self::tag_cycle_participants(&mut self.stack, usage_kind, stack_depth);
 
             // Return the provisional result or, if we're in the first iteration,
             // start with no constraints.
@@ -340,6 +342,7 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
                 non_root_cycle_participant: None,
                 encountered_overflow: false,
                 has_been_used: HasBeenUsed::empty(),
+                cycle_participants: Default::default(),
                 provisional_result: None,
             };
             assert_eq!(self.stack.push(entry), depth);
@@ -386,14 +389,13 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
         } else {
             self.provisional_cache.remove(&input);
             let reached_depth = final_entry.reached_depth.as_usize() - self.stack.len();
-            let cycle_participants = mem::take(&mut self.cycle_participants);
             // When encountering a cycle, both inductive and coinductive, we only
             // move the root into the global cache. We also store all other cycle
             // participants involved.
             //
             // We must not use the global cache entry of a root goal if a cycle
             // participant is on the stack. This is necessary to prevent unstable
-            // results. See the comment of `SearchGraph::cycle_participants` for
+            // results. See the comment of `StackEntry::cycle_participants` for
             // more details.
             self.global_cache(tcx).insert(
                 tcx,
@@ -401,12 +403,14 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
                 proof_tree,
                 reached_depth,
                 final_entry.encountered_overflow,
-                cycle_participants,
+                final_entry.cycle_participants,
                 dep_node,
                 result,
             )
         }
 
+        self.check_invariants();
+
         result
     }
 
@@ -416,10 +420,10 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
     fn lookup_global_cache(
         &mut self,
         tcx: TyCtxt<'tcx>,
-        input: CanonicalInput<'tcx>,
+        input: CanonicalInput<TyCtxt<'tcx>>,
         available_depth: Limit,
         inspect: &mut ProofTreeBuilder<TyCtxt<'tcx>>,
-    ) -> Option<QueryResult<'tcx>> {
+    ) -> Option<QueryResult<TyCtxt<'tcx>>> {
         let CacheData { result, proof_tree, additional_depth, encountered_overflow } = self
             .global_cache(tcx)
             .get(tcx, input, self.stack.iter().map(|e| e.input), available_depth)?;
@@ -450,12 +454,12 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
     }
 }
 
-enum StepResult<'tcx> {
-    Done(StackEntry<'tcx>, QueryResult<'tcx>),
+enum StepResult<I: Interner> {
+    Done(StackEntry<I>, QueryResult<I>),
     HasChanged,
 }
 
-impl<'tcx> SearchGraph<'tcx> {
+impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
     /// When we encounter a coinductive cycle, we have to fetch the
     /// result of that cycle while we are still computing it. Because
     /// of this we continuously recompute the cycle until the result
@@ -464,12 +468,12 @@ impl<'tcx> SearchGraph<'tcx> {
     fn fixpoint_step_in_task<F>(
         &mut self,
         tcx: TyCtxt<'tcx>,
-        input: CanonicalInput<'tcx>,
+        input: CanonicalInput<TyCtxt<'tcx>>,
         inspect: &mut ProofTreeBuilder<TyCtxt<'tcx>>,
         prove_goal: &mut F,
-    ) -> StepResult<'tcx>
+    ) -> StepResult<TyCtxt<'tcx>>
     where
-        F: FnMut(&mut Self, &mut ProofTreeBuilder<TyCtxt<'tcx>>) -> QueryResult<'tcx>,
+        F: FnMut(&mut Self, &mut ProofTreeBuilder<TyCtxt<'tcx>>) -> QueryResult<TyCtxt<'tcx>>,
     {
         let result = prove_goal(self, inspect);
         let stack_entry = self.pop_stack();
@@ -530,3 +534,77 @@ impl<'tcx> SearchGraph<'tcx> {
         Ok(super::response_no_constraints_raw(tcx, goal.max_universe, goal.variables, certainty))
     }
 }
+
+impl<I: Interner> SearchGraph<I> {
+    #[allow(rustc::potential_query_instability)]
+    fn check_invariants(&self) {
+        if !cfg!(debug_assertions) {
+            return;
+        }
+
+        let SearchGraph { mode: _, stack, provisional_cache } = self;
+        if stack.is_empty() {
+            assert!(provisional_cache.is_empty());
+        }
+
+        for (depth, entry) in stack.iter_enumerated() {
+            let StackEntry {
+                input,
+                available_depth: _,
+                reached_depth: _,
+                non_root_cycle_participant,
+                encountered_overflow: _,
+                has_been_used,
+                ref cycle_participants,
+                provisional_result,
+            } = *entry;
+            let cache_entry = provisional_cache.get(&entry.input).unwrap();
+            assert_eq!(cache_entry.stack_depth, Some(depth));
+            if let Some(head) = non_root_cycle_participant {
+                assert!(head < depth);
+                assert!(cycle_participants.is_empty());
+                assert_ne!(stack[head].has_been_used, HasBeenUsed::empty());
+
+                let mut current_root = head;
+                while let Some(parent) = stack[current_root].non_root_cycle_participant {
+                    current_root = parent;
+                }
+                assert!(stack[current_root].cycle_participants.contains(&input));
+            }
+
+            if !cycle_participants.is_empty() {
+                assert!(provisional_result.is_some() || !has_been_used.is_empty());
+                for entry in stack.iter().take(depth.as_usize()) {
+                    assert_eq!(cycle_participants.get(&entry.input), None);
+                }
+            }
+        }
+
+        for (&input, entry) in &self.provisional_cache {
+            let ProvisionalCacheEntry { stack_depth, with_coinductive_stack, with_inductive_stack } =
+                entry;
+            assert!(
+                stack_depth.is_some()
+                    || with_coinductive_stack.is_some()
+                    || with_inductive_stack.is_some()
+            );
+
+            if let &Some(stack_depth) = stack_depth {
+                assert_eq!(stack[stack_depth].input, input);
+            }
+
+            let check_detached = |detached_entry: &DetachedEntry<I>| {
+                let DetachedEntry { head, result: _ } = *detached_entry;
+                assert_ne!(stack[head].has_been_used, HasBeenUsed::empty());
+            };
+
+            if let Some(with_coinductive_stack) = with_coinductive_stack {
+                check_detached(with_coinductive_stack);
+            }
+
+            if let Some(with_inductive_stack) = with_inductive_stack {
+                check_detached(with_inductive_stack);
+            }
+        }
+    }
+}