about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-05-18 10:03:53 -0400
committerMichael Goulet <michael@errs.io>2024-05-19 19:38:28 -0400
commit91685c0ef4a9d41cd593a0d2433c1b6131ecee32 (patch)
tree3d9d665d8e477c4e2fd02ad3dae051e83a6e74b9
parentd84b9037541f45dc2c52a41d723265af211c0497 (diff)
downloadrust-91685c0ef4a9d41cd593a0d2433c1b6131ecee32.tar.gz
rust-91685c0ef4a9d41cd593a0d2433c1b6131ecee32.zip
Make search graph generic over interner
-rw-r--r--compiler/rustc_middle/src/ty/context.rs4
-rw-r--r--compiler/rustc_middle/src/ty/predicate.rs6
-rw-r--r--compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs6
-rw-r--r--compiler/rustc_trait_selection/src/solve/search_graph.rs110
-rw-r--r--compiler/rustc_type_ir/src/inherent.rs1
-rw-r--r--compiler/rustc_type_ir/src/interner.rs2
6 files changed, 75 insertions, 54 deletions
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index 69681930be6..d75f250275b 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -233,6 +233,10 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
     fn parent(self, def_id: Self::DefId) -> Self::DefId {
         self.parent(def_id)
     }
+
+    fn recursion_limit(self) -> usize {
+        self.recursion_limit().0
+    }
 }
 
 impl<'tcx> rustc_type_ir::inherent::Abi<TyCtxt<'tcx>> for abi::Abi {
diff --git a/compiler/rustc_middle/src/ty/predicate.rs b/compiler/rustc_middle/src/ty/predicate.rs
index 644fca7c5fe..be91249a25f 100644
--- a/compiler/rustc_middle/src/ty/predicate.rs
+++ b/compiler/rustc_middle/src/ty/predicate.rs
@@ -37,7 +37,11 @@ pub struct Predicate<'tcx>(
     pub(super) Interned<'tcx, WithCachedTypeInfo<ty::Binder<'tcx, PredicateKind<'tcx>>>>,
 );
 
-impl<'tcx> rustc_type_ir::inherent::Predicate<TyCtxt<'tcx>> for Predicate<'tcx> {}
+impl<'tcx> rustc_type_ir::inherent::Predicate<TyCtxt<'tcx>> for Predicate<'tcx> {
+    fn is_coinductive(self, interner: TyCtxt<'tcx>) -> bool {
+        self.is_coinductive(interner)
+    }
+}
 
 impl<'tcx> rustc_type_ir::visit::Flags for Predicate<'tcx> {
     fn flags(&self) -> TypeFlags {
diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
index 70308d4359d..9771e3f4e30 100644
--- a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
@@ -85,7 +85,7 @@ pub struct EvalCtxt<'a, 'tcx> {
     /// new placeholders to the caller.
     pub(super) max_input_universe: ty::UniverseIndex,
 
-    pub(super) search_graph: &'a mut SearchGraph<'tcx>,
+    pub(super) search_graph: &'a mut SearchGraph<TyCtxt<'tcx>>,
 
     nested_goals: NestedGoals<TyCtxt<'tcx>>,
 
@@ -225,7 +225,7 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
     /// and registering opaques from the canonicalized input.
     fn enter_canonical<R>(
         tcx: TyCtxt<'tcx>,
-        search_graph: &'a mut search_graph::SearchGraph<'tcx>,
+        search_graph: &'a mut search_graph::SearchGraph<TyCtxt<'tcx>>,
         canonical_input: CanonicalInput<'tcx>,
         canonical_goal_evaluation: &mut ProofTreeBuilder<TyCtxt<'tcx>>,
         f: impl FnOnce(&mut EvalCtxt<'_, 'tcx>, Goal<'tcx, ty::Predicate<'tcx>>) -> R,
@@ -287,7 +287,7 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
     #[instrument(level = "debug", skip(tcx, search_graph, goal_evaluation), ret)]
     fn evaluate_canonical_goal(
         tcx: TyCtxt<'tcx>,
-        search_graph: &'a mut search_graph::SearchGraph<'tcx>,
+        search_graph: &'a mut search_graph::SearchGraph<TyCtxt<'tcx>>,
         canonical_input: CanonicalInput<'tcx>,
         goal_evaluation: &mut ProofTreeBuilder<TyCtxt<'tcx>>,
     ) -> QueryResult<'tcx> {
diff --git a/compiler/rustc_trait_selection/src/solve/search_graph.rs b/compiler/rustc_trait_selection/src/solve/search_graph.rs
index 0164d44667c..5a5df439a78 100644
--- a/compiler/rustc_trait_selection/src/solve/search_graph.rs
+++ b/compiler/rustc_trait_selection/src/solve/search_graph.rs
@@ -1,18 +1,21 @@
-use crate::solve::FIXPOINT_STEP_LIMIT;
+use std::mem;
 
-use super::inspect;
-use super::inspect::ProofTreeBuilder;
-use super::SolverMode;
-use rustc_data_structures::fx::FxHashMap;
-use rustc_data_structures::fx::FxHashSet;
+use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_index::Idx;
 use rustc_index::IndexVec;
 use rustc_middle::dep_graph::dep_kinds;
 use rustc_middle::traits::solve::CacheData;
-use rustc_middle::traits::solve::{CanonicalInput, Certainty, EvaluationCache, QueryResult};
+use rustc_middle::traits::solve::EvaluationCache;
 use rustc_middle::ty::TyCtxt;
+use rustc_next_trait_solver::solve::{CanonicalInput, Certainty, QueryResult};
 use rustc_session::Limit;
-use std::mem;
+use rustc_type_ir::inherent::*;
+use rustc_type_ir::Interner;
+
+use super::inspect;
+use super::inspect::ProofTreeBuilder;
+use super::SolverMode;
+use crate::solve::FIXPOINT_STEP_LIMIT;
 
 rustc_index::newtype_index! {
     #[orderable]
@@ -30,9 +33,10 @@ bitflags::bitflags! {
     }
 }
 
-#[derive(Debug)]
-struct StackEntry<'tcx> {
-    input: CanonicalInput<'tcx>,
+#[derive(derivative::Derivative)]
+#[derivative(Debug(bound = ""))]
+struct StackEntry<I: Interner> {
+    input: CanonicalInput<I>,
 
     available_depth: Limit,
 
@@ -53,11 +57,11 @@ struct StackEntry<'tcx> {
     has_been_used: HasBeenUsed,
     /// Starts out as `None` and gets set when rerunning this
     /// goal in case we encounter a cycle.
-    provisional_result: Option<QueryResult<'tcx>>,
+    provisional_result: Option<QueryResult<I>>,
 }
 
 /// The provisional result for a goal which is not on the stack.
-struct DetachedEntry<'tcx> {
+struct DetachedEntry<I: Interner> {
     /// The head of the smallest non-trivial cycle involving this entry.
     ///
     /// Given the following rules, when proving `A` the head for
@@ -68,7 +72,7 @@ struct DetachedEntry<'tcx> {
     /// C :- A + B + C
     /// ```
     head: StackDepth,
-    result: QueryResult<'tcx>,
+    result: QueryResult<I>,
 }
 
 /// Stores the stack depth of a currently evaluated goal *and* already
@@ -83,14 +87,15 @@ struct DetachedEntry<'tcx> {
 ///
 /// The provisional cache can theoretically result in changes to the observable behavior,
 /// see tests/ui/traits/next-solver/cycles/provisional-cache-impacts-behavior.rs.
-#[derive(Default)]
-struct ProvisionalCacheEntry<'tcx> {
+#[derive(derivative::Derivative)]
+#[derivative(Default(bound = ""))]
+struct ProvisionalCacheEntry<I: Interner> {
     stack_depth: Option<StackDepth>,
-    with_inductive_stack: Option<DetachedEntry<'tcx>>,
-    with_coinductive_stack: Option<DetachedEntry<'tcx>>,
+    with_inductive_stack: Option<DetachedEntry<I>>,
+    with_coinductive_stack: Option<DetachedEntry<I>>,
 }
 
-impl<'tcx> ProvisionalCacheEntry<'tcx> {
+impl<I: Interner> ProvisionalCacheEntry<I> {
     fn is_empty(&self) -> bool {
         self.stack_depth.is_none()
             && self.with_inductive_stack.is_none()
@@ -98,13 +103,13 @@ impl<'tcx> ProvisionalCacheEntry<'tcx> {
     }
 }
 
-pub(super) struct SearchGraph<'tcx> {
+pub(super) struct SearchGraph<I: Interner> {
     mode: SolverMode,
     /// The stack of goals currently being computed.
     ///
     /// An element is *deeper* in the stack if its index is *lower*.
-    stack: IndexVec<StackDepth, StackEntry<'tcx>>,
-    provisional_cache: FxHashMap<CanonicalInput<'tcx>, ProvisionalCacheEntry<'tcx>>,
+    stack: IndexVec<StackDepth, StackEntry<I>>,
+    provisional_cache: FxHashMap<CanonicalInput<I>, ProvisionalCacheEntry<I>>,
     /// We put only the root goal of a coinductive cycle into the global cache.
     ///
     /// If we were to use that result when later trying to prove another cycle
@@ -112,11 +117,11 @@ pub(super) struct SearchGraph<'tcx> {
     ///
     /// See tests/ui/next-solver/coinduction/incompleteness-unstable-result.rs for
     /// an example of where this is needed.
-    cycle_participants: FxHashSet<CanonicalInput<'tcx>>,
+    cycle_participants: FxHashSet<CanonicalInput<I>>,
 }
 
-impl<'tcx> SearchGraph<'tcx> {
-    pub(super) fn new(mode: SolverMode) -> SearchGraph<'tcx> {
+impl<I: Interner> SearchGraph<I> {
+    pub(super) fn new(mode: SolverMode) -> SearchGraph<I> {
         Self {
             mode,
             stack: Default::default(),
@@ -144,7 +149,7 @@ impl<'tcx> SearchGraph<'tcx> {
     ///
     /// Directly popping from the stack instead of using this method
     /// would cause us to not track overflow and recursion depth correctly.
-    fn pop_stack(&mut self) -> StackEntry<'tcx> {
+    fn pop_stack(&mut self) -> StackEntry<I> {
         let elem = self.stack.pop().unwrap();
         if let Some(last) = self.stack.raw.last_mut() {
             last.reached_depth = last.reached_depth.max(elem.reached_depth);
@@ -153,17 +158,6 @@ impl<'tcx> SearchGraph<'tcx> {
         elem
     }
 
-    /// The trait solver behavior is different for coherence
-    /// so we use a separate cache. Alternatively we could use
-    /// a single cache and share it between coherence and ordinary
-    /// trait solving.
-    pub(super) fn global_cache(&self, tcx: TyCtxt<'tcx>) -> &'tcx EvaluationCache<'tcx> {
-        match self.mode {
-            SolverMode::Normal => &tcx.new_solver_evaluation_cache,
-            SolverMode::Coherence => &tcx.new_solver_coherence_evaluation_cache,
-        }
-    }
-
     pub(super) fn is_empty(&self) -> bool {
         if self.stack.is_empty() {
             debug_assert!(self.provisional_cache.is_empty());
@@ -181,8 +175,8 @@ impl<'tcx> SearchGraph<'tcx> {
     /// the remaining depth of all nested goals to prevent hangs
     /// in case there is exponential blowup.
     fn allowed_depth_for_nested(
-        tcx: TyCtxt<'tcx>,
-        stack: &IndexVec<StackDepth, StackEntry<'tcx>>,
+        tcx: I,
+        stack: &IndexVec<StackDepth, StackEntry<I>>,
     ) -> Option<Limit> {
         if let Some(last) = stack.raw.last() {
             if last.available_depth.0 == 0 {
@@ -195,13 +189,13 @@ impl<'tcx> SearchGraph<'tcx> {
                 Limit(last.available_depth.0 - 1)
             })
         } else {
-            Some(tcx.recursion_limit())
+            Some(Limit(tcx.recursion_limit()))
         }
     }
 
     fn stack_coinductive_from(
-        tcx: TyCtxt<'tcx>,
-        stack: &IndexVec<StackDepth, StackEntry<'tcx>>,
+        tcx: I,
+        stack: &IndexVec<StackDepth, StackEntry<I>>,
         head: StackDepth,
     ) -> bool {
         stack.raw[head.index()..]
@@ -220,8 +214,8 @@ impl<'tcx> SearchGraph<'tcx> {
     // we reach a fixpoint and all other cycle participants to make sure that
     // their result does not get moved to the global cache.
     fn tag_cycle_participants(
-        stack: &mut IndexVec<StackDepth, StackEntry<'tcx>>,
-        cycle_participants: &mut FxHashSet<CanonicalInput<'tcx>>,
+        stack: &mut IndexVec<StackDepth, StackEntry<I>>,
+        cycle_participants: &mut FxHashSet<CanonicalInput<I>>,
         usage_kind: HasBeenUsed,
         head: StackDepth,
     ) {
@@ -234,7 +228,7 @@ impl<'tcx> SearchGraph<'tcx> {
     }
 
     fn clear_dependent_provisional_results(
-        provisional_cache: &mut FxHashMap<CanonicalInput<'tcx>, ProvisionalCacheEntry<'tcx>>,
+        provisional_cache: &mut FxHashMap<CanonicalInput<I>, ProvisionalCacheEntry<I>>,
         head: StackDepth,
     ) {
         #[allow(rustc::potential_query_instability)]
@@ -244,6 +238,19 @@ impl<'tcx> SearchGraph<'tcx> {
             !entry.is_empty()
         });
     }
+}
+
+impl<'tcx> SearchGraph<TyCtxt<'tcx>> {
+    /// The trait solver behavior is different for coherence
+    /// so we use a separate cache. Alternatively we could use
+    /// a single cache and share it between coherence and ordinary
+    /// trait solving.
+    pub(super) fn global_cache(&self, tcx: TyCtxt<'tcx>) -> &'tcx EvaluationCache<'tcx> {
+        match self.mode {
+            SolverMode::Normal => &tcx.new_solver_evaluation_cache,
+            SolverMode::Coherence => &tcx.new_solver_coherence_evaluation_cache,
+        }
+    }
 
     /// Probably the most involved method of the whole solver.
     ///
@@ -252,10 +259,13 @@ impl<'tcx> SearchGraph<'tcx> {
     pub(super) fn with_new_goal(
         &mut self,
         tcx: TyCtxt<'tcx>,
-        input: CanonicalInput<'tcx>,
+        input: CanonicalInput<TyCtxt<'tcx>>,
         inspect: &mut ProofTreeBuilder<TyCtxt<'tcx>>,
-        mut prove_goal: impl FnMut(&mut Self, &mut ProofTreeBuilder<TyCtxt<'tcx>>) -> QueryResult<'tcx>,
-    ) -> QueryResult<'tcx> {
+        mut prove_goal: impl FnMut(
+            &mut Self,
+            &mut ProofTreeBuilder<TyCtxt<'tcx>>,
+        ) -> QueryResult<TyCtxt<'tcx>>,
+    ) -> QueryResult<TyCtxt<'tcx>> {
         // Check for overflow.
         let Some(available_depth) = Self::allowed_depth_for_nested(tcx, &self.stack) else {
             if let Some(last) = self.stack.raw.last_mut() {
@@ -489,9 +499,9 @@ impl<'tcx> SearchGraph<'tcx> {
 
     fn response_no_constraints(
         tcx: TyCtxt<'tcx>,
-        goal: CanonicalInput<'tcx>,
+        goal: CanonicalInput<TyCtxt<'tcx>>,
         certainty: Certainty,
-    ) -> QueryResult<'tcx> {
+    ) -> QueryResult<TyCtxt<'tcx>> {
         Ok(super::response_no_constraints_raw(tcx, goal.max_universe, goal.variables, certainty))
     }
 }
diff --git a/compiler/rustc_type_ir/src/inherent.rs b/compiler/rustc_type_ir/src/inherent.rs
index 5289dfd932f..19c76fb165a 100644
--- a/compiler/rustc_type_ir/src/inherent.rs
+++ b/compiler/rustc_type_ir/src/inherent.rs
@@ -96,6 +96,7 @@ pub trait GenericArgs<I: Interner<GenericArgs = Self>>:
 pub trait Predicate<I: Interner<Predicate = Self>>:
     Copy + Debug + Hash + Eq + TypeSuperVisitable<I> + TypeSuperFoldable<I> + Flags
 {
+    fn is_coinductive(self, interner: I) -> bool;
 }
 
 /// Common capabilities of placeholder kinds
diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs
index 9acf7c04dd6..0b51d2e75f4 100644
--- a/compiler/rustc_type_ir/src/interner.rs
+++ b/compiler/rustc_type_ir/src/interner.rs
@@ -124,6 +124,8 @@ pub trait Interner:
     ) -> Self::GenericArgs;
 
     fn parent(self, def_id: Self::DefId) -> Self::DefId;
+
+    fn recursion_limit(self) -> usize;
 }
 
 /// Imagine you have a function `F: FnOnce(&[T]) -> R`, plus an iterator `iter`