about summary refs log tree commit diff
diff options
context:
space:
mode:
authorlcnr <rust@lcnr.de>2025-09-26 13:48:22 +0200
committerlcnr <rust@lcnr.de>2025-09-26 15:29:01 +0200
commit4a0c7cc730a5434574c41a6073d9efb92141af1a (patch)
treefae2d074f884559e0c01b9eb7d0e8855b55958a9
parent5b9007bfc358817cf066ee27c6b440440727d3a7 (diff)
downloadrust-4a0c7cc730a5434574c41a6073d9efb92141af1a.tar.gz
rust-4a0c7cc730a5434574c41a6073d9efb92141af1a.zip
fix cycle head usages tracking
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/search_graph.rs22
-rw-r--r--compiler/rustc_type_ir/src/search_graph/mod.rs141
-rw-r--r--tests/ui/traits/next-solver/cycles/ignore-head-usages-provisional-cache.rs55
3 files changed, 166 insertions, 52 deletions
diff --git a/compiler/rustc_next_trait_solver/src/solve/search_graph.rs b/compiler/rustc_next_trait_solver/src/solve/search_graph.rs
index aa9dfc9a9a2..4ae2af59a70 100644
--- a/compiler/rustc_next_trait_solver/src/solve/search_graph.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/search_graph.rs
@@ -74,13 +74,21 @@ where
         }
     }
 
-    fn is_initial_provisional_result(
-        cx: Self::Cx,
-        kind: PathKind,
-        input: CanonicalInput<I>,
-        result: QueryResult<I>,
-    ) -> bool {
-        Self::initial_provisional_result(cx, kind, input) == result
+    fn is_initial_provisional_result(result: QueryResult<I>) -> Option<PathKind> {
+        match result {
+            Ok(response) => {
+                if has_no_inference_or_external_constraints(response) {
+                    if response.value.certainty == Certainty::Yes {
+                        return Some(PathKind::Coinductive);
+                    } else if response.value.certainty == Certainty::overflow(false) {
+                        return Some(PathKind::Unknown);
+                    }
+                }
+
+                None
+            }
+            Err(NoSolution) => Some(PathKind::Inductive),
+        }
     }
 
     fn on_stack_overflow(cx: I, input: CanonicalInput<I>) -> QueryResult<I> {
diff --git a/compiler/rustc_type_ir/src/search_graph/mod.rs b/compiler/rustc_type_ir/src/search_graph/mod.rs
index 8f8f019510f..a7902cd0210 100644
--- a/compiler/rustc_type_ir/src/search_graph/mod.rs
+++ b/compiler/rustc_type_ir/src/search_graph/mod.rs
@@ -86,12 +86,7 @@ pub trait Delegate: Sized {
         kind: PathKind,
         input: <Self::Cx as Cx>::Input,
     ) -> <Self::Cx as Cx>::Result;
-    fn is_initial_provisional_result(
-        cx: Self::Cx,
-        kind: PathKind,
-        input: <Self::Cx as Cx>::Input,
-        result: <Self::Cx as Cx>::Result,
-    ) -> bool;
+    fn is_initial_provisional_result(result: <Self::Cx as Cx>::Result) -> Option<PathKind>;
     fn on_stack_overflow(cx: Self::Cx, input: <Self::Cx as Cx>::Input) -> <Self::Cx as Cx>::Result;
     fn on_fixpoint_overflow(
         cx: Self::Cx,
@@ -215,6 +210,27 @@ impl HeadUsages {
         let HeadUsages { inductive, unknown, coinductive, forced_ambiguity } = self;
         inductive == 0 && unknown == 0 && coinductive == 0 && forced_ambiguity == 0
     }
+
+    fn is_single(self, path_kind: PathKind) -> bool {
+        match path_kind {
+            PathKind::Inductive => matches!(
+                self,
+                HeadUsages { inductive: _, unknown: 0, coinductive: 0, forced_ambiguity: 0 },
+            ),
+            PathKind::Unknown => matches!(
+                self,
+                HeadUsages { inductive: 0, unknown: _, coinductive: 0, forced_ambiguity: 0 },
+            ),
+            PathKind::Coinductive => matches!(
+                self,
+                HeadUsages { inductive: 0, unknown: 0, coinductive: _, forced_ambiguity: 0 },
+            ),
+            PathKind::ForcedAmbiguity => matches!(
+                self,
+                HeadUsages { inductive: 0, unknown: 0, coinductive: 0, forced_ambiguity: _ },
+            ),
+        }
+    }
 }
 
 #[derive(Debug, Default)]
@@ -888,7 +904,29 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
             !entries.is_empty()
         });
     }
+}
+
+/// We need to rebase provisional cache entries when popping one of their cycle
+/// heads from the stack. This may not necessarily mean that we've actually
+/// reached a fixpoint for that cycle head, which impacts the way we rebase
+/// provisional cache entries.
+enum RebaseReason {
+    NoCycleUsages,
+    Ambiguity,
+    Overflow,
+    /// We've actually reached a fixpoint.
+    ///
+    /// This either happens in the first evaluation step for the cycle head.
+    /// In this case the used provisional result depends on the cycle `PathKind`.
+    /// We store this path kind to check whether the the provisional cache entry
+    /// we're rebasing relied on the same cycles.
+    ///
+    /// In later iterations cycles always return `stack_entry.provisional_result`
+    /// so we no longer depend on the `PathKind`. We store `None` in that case.
+    ReachedFixpoint(Option<PathKind>),
+}
 
+impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
     /// A necessary optimization to handle complex solver cycles. A provisional cache entry
     /// relies on a set of cycle heads and the path towards these heads. When popping a cycle
     /// head from the stack after we've finished computing it, we can't be sure that the
@@ -908,8 +946,9 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
     /// to me.
     fn rebase_provisional_cache_entries(
         &mut self,
+        cx: X,
         stack_entry: &StackEntry<X>,
-        mut mutate_result: impl FnMut(X::Input, X::Result) -> X::Result,
+        rebase_reason: RebaseReason,
     ) {
         let popped_head_index = self.stack.next_index();
         #[allow(rustc::potential_query_instability)]
@@ -927,6 +966,10 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
                     return true;
                 };
 
+                let Some(new_highest_head_index) = heads.opt_highest_cycle_head_index() else {
+                    return false;
+                };
+
                 // We're rebasing an entry `e` over a head `p`. This head
                 // has a number of own heads `h` it depends on.
                 //
@@ -977,22 +1020,37 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
                         let eph = ep.extend_with_paths(ph);
                         heads.insert(head_index, eph, head.usages);
                     }
-                }
 
-                let Some(head_index) = heads.opt_highest_cycle_head_index() else {
-                    return false;
-                };
+                    // The provisional cache entry does depend on the provisional result
+                    // of the popped cycle head. We need to mutate the result of our
+                    // provisional cache entry in case we did not reach a fixpoint.
+                    match rebase_reason {
+                        // If the cycle head does not actually depend on itself, then
+                        // the provisional result used by the provisional cache entry
+                        // is not actually equal to the final provisional result. We
+                        // need to discard the provisional cache entry in this case.
+                        RebaseReason::NoCycleUsages => return false,
+                        RebaseReason::Ambiguity => {
+                            *result = D::propagate_ambiguity(cx, input, *result);
+                        }
+                        RebaseReason::Overflow => *result = D::on_fixpoint_overflow(cx, input),
+                        RebaseReason::ReachedFixpoint(None) => {}
+                        RebaseReason::ReachedFixpoint(Some(path_kind)) => {
+                            if !popped_head.usages.is_single(path_kind) {
+                                return false;
+                            }
+                        }
+                    };
+                }
 
                 // We now care about the path from the next highest cycle head to the
                 // provisional cache entry.
                 *path_from_head = path_from_head.extend(Self::cycle_path_kind(
                     &self.stack,
                     stack_entry.step_kind_from_parent,
-                    head_index,
+                    new_highest_head_index,
                 ));
-                // Mutate the result of the provisional cache entry in case we did
-                // not reach a fixpoint.
-                *result = mutate_result(input, *result);
+
                 true
             });
             !entries.is_empty()
@@ -1209,33 +1267,19 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
     /// Whether we've reached a fixpoint when evaluating a cycle head.
     fn reached_fixpoint(
         &mut self,
-        cx: X,
         stack_entry: &StackEntry<X>,
         usages: HeadUsages,
         result: X::Result,
-    ) -> bool {
+    ) -> Result<Option<PathKind>, ()> {
         let provisional_result = stack_entry.provisional_result;
-        if usages.is_empty() {
-            true
-        } else if let Some(provisional_result) = provisional_result {
-            provisional_result == result
+        if let Some(provisional_result) = provisional_result {
+            if provisional_result == result { Ok(None) } else { Err(()) }
+        } else if let Some(path_kind) = D::is_initial_provisional_result(result)
+            .filter(|&path_kind| usages.is_single(path_kind))
+        {
+            Ok(Some(path_kind))
         } else {
-            let check = |k| D::is_initial_provisional_result(cx, k, stack_entry.input, result);
-            match usages {
-                HeadUsages { inductive: _, unknown: 0, coinductive: 0, forced_ambiguity: 0 } => {
-                    check(PathKind::Inductive)
-                }
-                HeadUsages { inductive: 0, unknown: _, coinductive: 0, forced_ambiguity: 0 } => {
-                    check(PathKind::Unknown)
-                }
-                HeadUsages { inductive: 0, unknown: 0, coinductive: _, forced_ambiguity: 0 } => {
-                    check(PathKind::Coinductive)
-                }
-                HeadUsages { inductive: 0, unknown: 0, coinductive: 0, forced_ambiguity: _ } => {
-                    check(PathKind::ForcedAmbiguity)
-                }
-                _ => false,
-            }
+            Err(())
         }
     }
 
@@ -1280,8 +1324,19 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
             // is equal to the provisional result of the previous iteration, or because
             // this was only the head of either coinductive or inductive cycles, and the
             // final result is equal to the initial response for that case.
-            if self.reached_fixpoint(cx, &stack_entry, usages, result) {
-                self.rebase_provisional_cache_entries(&stack_entry, |_, result| result);
+            if let Ok(fixpoint) = self.reached_fixpoint(&stack_entry, usages, result) {
+                self.rebase_provisional_cache_entries(
+                    cx,
+                    &stack_entry,
+                    RebaseReason::ReachedFixpoint(fixpoint),
+                );
+                return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
+            } else if usages.is_empty() {
+                self.rebase_provisional_cache_entries(
+                    cx,
+                    &stack_entry,
+                    RebaseReason::NoCycleUsages,
+                );
                 return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
             }
 
@@ -1298,9 +1353,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
             // we also taint all provisional cache entries which depend on the
             // current goal.
             if D::is_ambiguous_result(result) {
-                self.rebase_provisional_cache_entries(&stack_entry, |input, _| {
-                    D::propagate_ambiguity(cx, input, result)
-                });
+                self.rebase_provisional_cache_entries(cx, &stack_entry, RebaseReason::Ambiguity);
                 return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
             };
 
@@ -1310,9 +1363,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
             if i >= D::FIXPOINT_STEP_LIMIT {
                 debug!("canonical cycle overflow");
                 let result = D::on_fixpoint_overflow(cx, input);
-                self.rebase_provisional_cache_entries(&stack_entry, |input, _| {
-                    D::on_fixpoint_overflow(cx, input)
-                });
+                self.rebase_provisional_cache_entries(cx, &stack_entry, RebaseReason::Overflow);
                 return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
             }
 
diff --git a/tests/ui/traits/next-solver/cycles/ignore-head-usages-provisional-cache.rs b/tests/ui/traits/next-solver/cycles/ignore-head-usages-provisional-cache.rs
new file mode 100644
index 00000000000..9a33ae53644
--- /dev/null
+++ b/tests/ui/traits/next-solver/cycles/ignore-head-usages-provisional-cache.rs
@@ -0,0 +1,55 @@
+//@ compile-flags: -Znext-solver
+//@ check-pass
+
+// A regression test for trait-system-refactor-initiative#232. We've
+// previously incorrectly rebased provisional cache entries even if
+// the cycle head didn't reach a fixpoint as it did not depend on any
+// cycles itself.
+//
+// Just because the result of a goal does not depend on its own provisional
+// result, it does not mean its nested goals don't depend on its result.
+struct B;
+struct C;
+struct D;
+
+pub trait Trait {
+    type Output;
+}
+macro_rules! k {
+    ($t:ty) => {
+        <$t as Trait>::Output
+    };
+}
+
+trait CallB<T1, T2> {
+    type Output;
+    type Return;
+}
+
+trait CallC<T1> {
+    type Output;
+    type Return;
+}
+
+trait CallD<T1, T2> {
+    type Output;
+}
+
+fn foo<X, Y>()
+where
+    X: Trait,
+    Y: Trait,
+    D: CallD<k![X], k![Y]>,
+    C: CallC<<D as CallD<k![X], k![Y]>>::Output>,
+    <C as CallC<<D as CallD<k![X], k![Y]>>::Output>>::Output: Trait,
+    B: CallB<
+            <C as CallC<<D as CallD<k![X], k![Y]>>::Output>>::Return,
+            <C as CallC<<D as CallD<k![X], k![Y]>>::Output>>::Output,
+        >,
+    <B as CallB<
+        <C as CallC<<D as CallD<k![X], k![Y]>>::Output>>::Return,
+        <C as CallC<<D as CallD<k![X], k![Y]>>::Output>>::Output,
+    >>::Output: Trait<Output = ()>,
+{
+}
+fn main() {}