about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_trait_selection')
-rw-r--r--compiler/rustc_trait_selection/src/lib.rs1
-rw-r--r--compiler/rustc_trait_selection/src/solve/inspect/analyse.rs3
-rw-r--r--compiler/rustc_trait_selection/src/solve/inspect/build.rs5
-rw-r--r--compiler/rustc_trait_selection/src/solve/search_graph.rs215
4 files changed, 162 insertions, 62 deletions
diff --git a/compiler/rustc_trait_selection/src/lib.rs b/compiler/rustc_trait_selection/src/lib.rs
index de2577cca49..552c28c0586 100644
--- a/compiler/rustc_trait_selection/src/lib.rs
+++ b/compiler/rustc_trait_selection/src/lib.rs
@@ -19,6 +19,7 @@
 #![feature(control_flow_enum)]
 #![feature(extract_if)]
 #![feature(let_chains)]
+#![feature(option_take_if)]
 #![feature(if_let_guard)]
 #![feature(never_type)]
 #![feature(type_alias_impl_trait)]
diff --git a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
index 6db53d6ddc4..f33d0f397ce 100644
--- a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
+++ b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
@@ -171,7 +171,8 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
         let mut candidates = vec![];
         let last_eval_step = match self.evaluation.evaluation.kind {
             inspect::CanonicalGoalEvaluationKind::Overflow
-            | inspect::CanonicalGoalEvaluationKind::CycleInStack => {
+            | inspect::CanonicalGoalEvaluationKind::CycleInStack
+            | inspect::CanonicalGoalEvaluationKind::ProvisionalCacheHit => {
                 warn!("unexpected root evaluation: {:?}", self.evaluation);
                 return vec![];
             }
diff --git a/compiler/rustc_trait_selection/src/solve/inspect/build.rs b/compiler/rustc_trait_selection/src/solve/inspect/build.rs
index d8caef5b03f..b587a93b24c 100644
--- a/compiler/rustc_trait_selection/src/solve/inspect/build.rs
+++ b/compiler/rustc_trait_selection/src/solve/inspect/build.rs
@@ -118,6 +118,7 @@ pub(in crate::solve) enum WipGoalEvaluationKind<'tcx> {
 pub(in crate::solve) enum WipCanonicalGoalEvaluationKind<'tcx> {
     Overflow,
     CycleInStack,
+    ProvisionalCacheHit,
     Interned { revisions: &'tcx [inspect::GoalEvaluationStep<'tcx>] },
 }
 
@@ -126,6 +127,7 @@ impl std::fmt::Debug for WipCanonicalGoalEvaluationKind<'_> {
         match self {
             Self::Overflow => write!(f, "Overflow"),
             Self::CycleInStack => write!(f, "CycleInStack"),
+            Self::ProvisionalCacheHit => write!(f, "ProvisionalCacheHit"),
             Self::Interned { revisions: _ } => f.debug_struct("Interned").finish_non_exhaustive(),
         }
     }
@@ -151,6 +153,9 @@ impl<'tcx> WipCanonicalGoalEvaluation<'tcx> {
             WipCanonicalGoalEvaluationKind::CycleInStack => {
                 inspect::CanonicalGoalEvaluationKind::CycleInStack
             }
+            WipCanonicalGoalEvaluationKind::ProvisionalCacheHit => {
+                inspect::CanonicalGoalEvaluationKind::ProvisionalCacheHit
+            }
             WipCanonicalGoalEvaluationKind::Interned { revisions } => {
                 inspect::CanonicalGoalEvaluationKind::Evaluation { revisions }
             }
diff --git a/compiler/rustc_trait_selection/src/solve/search_graph.rs b/compiler/rustc_trait_selection/src/solve/search_graph.rs
index 523025cabe2..ead4ab5723d 100644
--- a/compiler/rustc_trait_selection/src/solve/search_graph.rs
+++ b/compiler/rustc_trait_selection/src/solve/search_graph.rs
@@ -11,7 +11,6 @@ use rustc_middle::traits::solve::{CanonicalInput, Certainty, EvaluationCache, Qu
 use rustc_middle::ty;
 use rustc_middle::ty::TyCtxt;
 use rustc_session::Limit;
-use std::collections::hash_map::Entry;
 use std::mem;
 
 rustc_index::newtype_index! {
@@ -30,7 +29,7 @@ struct StackEntry<'tcx> {
     ///
     /// If so, it must not be moved to the global cache. See
     /// [SearchGraph::cycle_participants] for more details.
-    non_root_cycle_participant: bool,
+    non_root_cycle_participant: Option<StackDepth>,
 
     encountered_overflow: bool,
     has_been_used: bool,
@@ -39,6 +38,34 @@ struct StackEntry<'tcx> {
     provisional_result: Option<QueryResult<'tcx>>,
 }
 
+struct DetachedEntry<'tcx> {
+    /// The head of the smallest non-trivial cycle involving this entry.
+    ///
+    /// Given the following rules, when proving `A` the head for
+    /// the provisional entry of `C` would be `B`.
+    ///
+    ///     A :- B
+    ///     B :- C
+    ///     C :- A + B + C
+    head: StackDepth,
+    result: QueryResult<'tcx>,
+}
+
+#[derive(Default)]
+struct ProvisionalCacheEntry<'tcx> {
+    stack_depth: Option<StackDepth>,
+    with_inductive_stack: Option<DetachedEntry<'tcx>>,
+    with_coinductive_stack: Option<DetachedEntry<'tcx>>,
+}
+
+impl<'tcx> ProvisionalCacheEntry<'tcx> {
+    fn is_empty(&self) -> bool {
+        self.stack_depth.is_none()
+            && self.with_inductive_stack.is_none()
+            && self.with_coinductive_stack.is_none()
+    }
+}
+
 pub(super) struct SearchGraph<'tcx> {
     mode: SolverMode,
     local_overflow_limit: usize,
@@ -46,7 +73,7 @@ pub(super) struct SearchGraph<'tcx> {
     ///
     /// An element is *deeper* in the stack if its index is *lower*.
     stack: IndexVec<StackDepth, StackEntry<'tcx>>,
-    stack_entries: FxHashMap<CanonicalInput<'tcx>, StackDepth>,
+    provisional_cache: FxHashMap<CanonicalInput<'tcx>, ProvisionalCacheEntry<'tcx>>,
     /// We put only the root goal of a coinductive cycle into the global cache.
     ///
     /// If we were to use that result when later trying to prove another cycle
@@ -63,7 +90,7 @@ impl<'tcx> SearchGraph<'tcx> {
             mode,
             local_overflow_limit: tcx.recursion_limit().0.checked_ilog2().unwrap_or(0) as usize,
             stack: Default::default(),
-            stack_entries: Default::default(),
+            provisional_cache: Default::default(),
             cycle_participants: Default::default(),
         }
     }
@@ -93,7 +120,6 @@ impl<'tcx> SearchGraph<'tcx> {
     /// would cause us to not track overflow and recursion depth correctly.
     fn pop_stack(&mut self) -> StackEntry<'tcx> {
         let elem = self.stack.pop().unwrap();
-        assert!(self.stack_entries.remove(&elem.input).is_some());
         if let Some(last) = self.stack.raw.last_mut() {
             last.reached_depth = last.reached_depth.max(elem.reached_depth);
             last.encountered_overflow |= elem.encountered_overflow;
@@ -114,7 +140,7 @@ impl<'tcx> SearchGraph<'tcx> {
 
     pub(super) fn is_empty(&self) -> bool {
         if self.stack.is_empty() {
-            debug_assert!(self.stack_entries.is_empty());
+            debug_assert!(self.provisional_cache.is_empty());
             debug_assert!(self.cycle_participants.is_empty());
             true
         } else {
@@ -156,6 +182,40 @@ impl<'tcx> SearchGraph<'tcx> {
         }
     }
 
+    fn stack_coinductive_from(
+        tcx: TyCtxt<'tcx>,
+        stack: &IndexVec<StackDepth, StackEntry<'tcx>>,
+        head: StackDepth,
+    ) -> bool {
+        stack.raw[head.index()..]
+            .iter()
+            .all(|entry| entry.input.value.goal.predicate.is_coinductive(tcx))
+    }
+
+    fn tag_cycle_participants(
+        stack: &mut IndexVec<StackDepth, StackEntry<'tcx>>,
+        cycle_participants: &mut FxHashSet<CanonicalInput<'tcx>>,
+        head: StackDepth,
+    ) {
+        stack[head].has_been_used = true;
+        for entry in &mut stack.raw[head.index() + 1..] {
+            entry.non_root_cycle_participant = entry.non_root_cycle_participant.max(Some(head));
+            cycle_participants.insert(entry.input);
+        }
+    }
+
+    fn clear_dependent_provisional_results(
+        provisional_cache: &mut FxHashMap<CanonicalInput<'tcx>, ProvisionalCacheEntry<'tcx>>,
+        head: StackDepth,
+    ) {
+        #[allow(rustc::potential_query_instability)]
+        provisional_cache.retain(|_, entry| {
+            entry.with_coinductive_stack.take_if(|p| p.head == head);
+            entry.with_inductive_stack.take_if(|p| p.head == head);
+            !entry.is_empty()
+        });
+    }
+
     /// Probably the most involved method of the whole solver.
     ///
     /// Given some goal which is proven via the `prove_goal` closure, this
@@ -210,23 +270,36 @@ impl<'tcx> SearchGraph<'tcx> {
             return result;
         }
 
-        // Check whether we're in a cycle.
-        match self.stack_entries.entry(input) {
-            // No entry, we push this goal on the stack and try to prove it.
-            Entry::Vacant(v) => {
-                let depth = self.stack.next_index();
-                let entry = StackEntry {
-                    input,
-                    available_depth,
-                    reached_depth: depth,
-                    non_root_cycle_participant: false,
-                    encountered_overflow: false,
-                    has_been_used: false,
-                    provisional_result: None,
-                };
-                assert_eq!(self.stack.push(entry), depth);
-                v.insert(depth);
-            }
+        // Check whether the goal is in the provisional cache.
+        let cache_entry = self.provisional_cache.entry(input).or_default();
+        if let Some(with_coinductive_stack) = &mut cache_entry.with_coinductive_stack
+            && Self::stack_coinductive_from(tcx, &self.stack, with_coinductive_stack.head)
+        {
+            // We have a nested goal which is already in the provisional cache, use
+            // its result.
+            inspect
+                .goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::ProvisionalCacheHit);
+            Self::tag_cycle_participants(
+                &mut self.stack,
+                &mut self.cycle_participants,
+                with_coinductive_stack.head,
+            );
+            return with_coinductive_stack.result;
+        } else if let Some(with_inductive_stack) = &mut cache_entry.with_inductive_stack
+            && !Self::stack_coinductive_from(tcx, &self.stack, with_inductive_stack.head)
+        {
+            // We have a nested goal which is already in the provisional cache, use
+            // its result.
+            inspect
+                .goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::ProvisionalCacheHit);
+            Self::tag_cycle_participants(
+                &mut self.stack,
+                &mut self.cycle_participants,
+                with_inductive_stack.head,
+            );
+            return with_inductive_stack.result;
+        } else if let Some(stack_depth) = cache_entry.stack_depth {
+            debug!("encountered cycle with depth {stack_depth:?}");
             // We have a nested goal which relies on a goal `root` deeper in the stack.
             //
             // We first store that we may have to reprove `root` in case the provisional
@@ -236,40 +309,37 @@ impl<'tcx> SearchGraph<'tcx> {
             //
             // Finally we can return either the provisional response for that goal if we have a
             // coinductive cycle or an ambiguous result if the cycle is inductive.
-            Entry::Occupied(entry) => {
-                inspect.goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::CycleInStack);
-
-                let stack_depth = *entry.get();
-                debug!("encountered cycle with depth {stack_depth:?}");
-                // We start by tagging all non-root cycle participants.
-                let participants = self.stack.raw.iter_mut().skip(stack_depth.as_usize() + 1);
-                for entry in participants {
-                    entry.non_root_cycle_participant = true;
-                    self.cycle_participants.insert(entry.input);
-                }
-
-                // If we're in a cycle, we have to retry proving the cycle head
-                // until we reach a fixpoint. It is not enough to simply retry the
-                // `root` goal of this cycle.
-                //
-                // See tests/ui/traits/next-solver/cycles/fixpoint-rerun-all-cycle-heads.rs
-                // for an example.
-                self.stack[stack_depth].has_been_used = true;
-                return if let Some(result) = self.stack[stack_depth].provisional_result {
-                    result
+            inspect.goal_evaluation_kind(inspect::WipCanonicalGoalEvaluationKind::CycleInStack);
+            Self::tag_cycle_participants(
+                &mut self.stack,
+                &mut self.cycle_participants,
+                stack_depth,
+            );
+            return if let Some(result) = self.stack[stack_depth].provisional_result {
+                result
+            } else {
+                // If we don't have a provisional result yet we're in the first iteration,
+                // so we start with no constraints.
+                if Self::stack_coinductive_from(tcx, &self.stack, stack_depth) {
+                    Self::response_no_constraints(tcx, input, Certainty::Yes)
                 } else {
-                    // If we don't have a provisional result yet we're in the first iteration,
-                    // so we start with no constraints.
-                    let is_inductive = self.stack.raw[stack_depth.index()..]
-                        .iter()
-                        .any(|entry| !entry.input.value.goal.predicate.is_coinductive(tcx));
-                    if is_inductive {
-                        Self::response_no_constraints(tcx, input, Certainty::OVERFLOW)
-                    } else {
-                        Self::response_no_constraints(tcx, input, Certainty::Yes)
-                    }
-                };
-            }
+                    Self::response_no_constraints(tcx, input, Certainty::OVERFLOW)
+                }
+            };
+        } else {
+            // No entry, we push this goal on the stack and try to prove it.
+            let depth = self.stack.next_index();
+            let entry = StackEntry {
+                input,
+                available_depth,
+                reached_depth: depth,
+                non_root_cycle_participant: None,
+                encountered_overflow: false,
+                has_been_used: false,
+                provisional_result: None,
+            };
+            assert_eq!(self.stack.push(entry), depth);
+            cache_entry.stack_depth = Some(depth);
         }
 
         // This is for global caching, so we properly track query dependencies.
@@ -285,11 +355,22 @@ impl<'tcx> SearchGraph<'tcx> {
                 for _ in 0..self.local_overflow_limit() {
                     let result = prove_goal(self, inspect);
 
-                    // Check whether the current goal is the root of a cycle and whether
-                    // we have to rerun because its provisional result differed from the
-                    // final result.
+                    // Check whether the current goal is the root of a cycle.
+                    // If so, we have to retry proving the cycle head
+                    // until its result reaches a fixpoint. We need to do so for
+                    // all cycle heads, not only for the root.
+                    //
+                    // See tests/ui/traits/next-solver/cycles/fixpoint-rerun-all-cycle-heads.rs
+                    // for an example.
                     let stack_entry = self.pop_stack();
                     debug_assert_eq!(stack_entry.input, input);
+                    if stack_entry.has_been_used {
+                        Self::clear_dependent_provisional_results(
+                            &mut self.provisional_cache,
+                            self.stack.next_index(),
+                        );
+                    }
+
                     if stack_entry.has_been_used
                         && stack_entry.provisional_result.map_or(true, |r| r != result)
                     {
@@ -299,7 +380,7 @@ impl<'tcx> SearchGraph<'tcx> {
                             provisional_result: Some(result),
                             ..stack_entry
                         });
-                        assert_eq!(self.stack_entries.insert(input, depth), None);
+                        debug_assert_eq!(self.provisional_cache[&input].stack_depth, Some(depth));
                     } else {
                         return (stack_entry, result);
                     }
@@ -307,6 +388,7 @@ impl<'tcx> SearchGraph<'tcx> {
 
                 debug!("canonical cycle overflow");
                 let current_entry = self.pop_stack();
+                debug_assert!(!current_entry.has_been_used);
                 let result = Self::response_no_constraints(tcx, input, Certainty::OVERFLOW);
                 (current_entry, result)
             });
@@ -319,7 +401,17 @@ impl<'tcx> SearchGraph<'tcx> {
         //
         // It is not possible for any nested goal to depend on something deeper on the
         // stack, as this would have also updated the depth of the current goal.
-        if !final_entry.non_root_cycle_participant {
+        if let Some(head) = final_entry.non_root_cycle_participant {
+            let coinductive_stack = Self::stack_coinductive_from(tcx, &self.stack, head);
+
+            let entry = self.provisional_cache.get_mut(&input).unwrap();
+            entry.stack_depth = None;
+            if coinductive_stack {
+                entry.with_coinductive_stack = Some(DetachedEntry { head, result });
+            } else {
+                entry.with_inductive_stack = Some(DetachedEntry { head, result });
+            }
+        } else {
             // When encountering a cycle, both inductive and coinductive, we only
             // move the root into the global cache. We also store all other cycle
             // participants involved.
@@ -328,6 +420,7 @@ impl<'tcx> SearchGraph<'tcx> {
             // participant is on the stack. This is necessary to prevent unstable
             // results. See the comment of `SearchGraph::cycle_participants` for
             // more details.
+            self.provisional_cache.remove(&input);
             let reached_depth = final_entry.reached_depth.as_usize() - self.stack.len();
             let cycle_participants = mem::take(&mut self.cycle_participants);
             self.global_cache(tcx).insert(