diff options
| author | Michael Goulet <michael@errs.io> | 2024-06-14 18:25:31 -0400 |
|---|---|---|
| committer | Michael Goulet <michael@errs.io> | 2024-06-18 10:40:30 -0400 |
| commit | dba414763362a15d56992c35728242300282f0ef (patch) | |
| tree | c48ad4a0e61148d2d0734ca659a35bdf96cf3184 | |
| parent | af3d1004c766fc1413f7ad8ad052b77c077b83a1 (diff) | |
| download | rust-dba414763362a15d56992c35728242300282f0ef.tar.gz rust-dba414763362a15d56992c35728242300282f0ef.zip | |
Make SearchGraph fully generic
| -rw-r--r-- | compiler/rustc_middle/src/traits/solve.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/traits/solve/cache.rs | 28 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/ty/context.rs | 21 | ||||
| -rw-r--r-- | compiler/rustc_trait_selection/src/solve/mod.rs | 30 | ||||
| -rw-r--r-- | compiler/rustc_trait_selection/src/solve/search_graph.rs | 92 | ||||
| -rw-r--r-- | compiler/rustc_type_ir/src/inherent.rs | 29 | ||||
| -rw-r--r-- | compiler/rustc_type_ir/src/interner.rs | 20 | ||||
| -rw-r--r-- | compiler/rustc_type_ir/src/solve.rs | 22 |
8 files changed, 149 insertions, 95 deletions
diff --git a/compiler/rustc_middle/src/traits/solve.rs b/compiler/rustc_middle/src/traits/solve.rs index 9e979620a44..888e3aec5ea 100644 --- a/compiler/rustc_middle/src/traits/solve.rs +++ b/compiler/rustc_middle/src/traits/solve.rs @@ -10,7 +10,7 @@ use crate::ty::{ mod cache; -pub use cache::{CacheData, EvaluationCache}; +pub use cache::EvaluationCache; pub type Goal<'tcx, P> = ir::solve::Goal<TyCtxt<'tcx>, P>; pub type QueryInput<'tcx, P> = ir::solve::QueryInput<TyCtxt<'tcx>, P>; diff --git a/compiler/rustc_middle/src/traits/solve/cache.rs b/compiler/rustc_middle/src/traits/solve/cache.rs index dc31114b2c4..72a8d4eb405 100644 --- a/compiler/rustc_middle/src/traits/solve/cache.rs +++ b/compiler/rustc_middle/src/traits/solve/cache.rs @@ -5,6 +5,8 @@ use rustc_data_structures::sync::Lock; use rustc_query_system::cache::WithDepNode; use rustc_query_system::dep_graph::DepNodeIndex; use rustc_session::Limit; +use rustc_type_ir::solve::CacheData; + /// The trait solver cache used by `-Znext-solver`. /// /// FIXME(@lcnr): link to some official documentation of how @@ -14,17 +16,9 @@ pub struct EvaluationCache<'tcx> { map: Lock<FxHashMap<CanonicalInput<'tcx>, CacheEntry<'tcx>>>, } -#[derive(Debug, PartialEq, Eq)] -pub struct CacheData<'tcx> { - pub result: QueryResult<'tcx>, - pub proof_tree: Option<&'tcx inspect::CanonicalGoalEvaluationStep<TyCtxt<'tcx>>>, - pub additional_depth: usize, - pub encountered_overflow: bool, -} - -impl<'tcx> EvaluationCache<'tcx> { +impl<'tcx> rustc_type_ir::inherent::EvaluationCache<TyCtxt<'tcx>> for &'tcx EvaluationCache<'tcx> { /// Insert a final result into the global cache. - pub fn insert( + fn insert( &self, tcx: TyCtxt<'tcx>, key: CanonicalInput<'tcx>, @@ -48,7 +42,7 @@ impl<'tcx> EvaluationCache<'tcx> { if cfg!(debug_assertions) { drop(map); let expected = CacheData { result, proof_tree, additional_depth, encountered_overflow }; - let actual = self.get(tcx, key, [], Limit(additional_depth)); + let actual = self.get(tcx, key, [], additional_depth); if !actual.as_ref().is_some_and(|actual| expected == *actual) { bug!("failed to lookup inserted element for {key:?}: {expected:?} != {actual:?}"); } @@ -59,13 +53,13 @@ impl<'tcx> EvaluationCache<'tcx> { /// and handling root goals of coinductive cycles. /// /// If this returns `Some` the cache result can be used. - pub fn get( + fn get( &self, tcx: TyCtxt<'tcx>, key: CanonicalInput<'tcx>, stack_entries: impl IntoIterator<Item = CanonicalInput<'tcx>>, - available_depth: Limit, - ) -> Option<CacheData<'tcx>> { + available_depth: usize, + ) -> Option<CacheData<TyCtxt<'tcx>>> { let map = self.map.borrow(); let entry = map.get(&key)?; @@ -76,7 +70,7 @@ impl<'tcx> EvaluationCache<'tcx> { } if let Some(ref success) = entry.success { - if available_depth.value_within_limit(success.additional_depth) { + if Limit(available_depth).value_within_limit(success.additional_depth) { let QueryData { result, proof_tree } = success.data.get(tcx); return Some(CacheData { result, @@ -87,12 +81,12 @@ impl<'tcx> EvaluationCache<'tcx> { } } - entry.with_overflow.get(&available_depth.0).map(|e| { + entry.with_overflow.get(&available_depth).map(|e| { let QueryData { result, proof_tree } = e.get(tcx); CacheData { result, proof_tree, - additional_depth: available_depth.0, + additional_depth: available_depth, encountered_overflow: true, } }) diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index e2f15dac019..eec7fa8db1d 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -71,6 +71,7 @@ use rustc_target::abi::{FieldIdx, Layout, LayoutS, TargetDataLayout, VariantIdx} use rustc_target::spec::abi; use rustc_type_ir::fold::TypeFoldable; use rustc_type_ir::lang_items::TraitSolverLangItem; +use rustc_type_ir::solve::SolverMode; use rustc_type_ir::TyKind::*; use rustc_type_ir::{CollectAndApply, Interner, TypeFlags, WithCachedTypeInfo}; use tracing::{debug, instrument}; @@ -139,10 +140,30 @@ impl<'tcx> Interner for TyCtxt<'tcx> { type Clause = Clause<'tcx>; type Clauses = ty::Clauses<'tcx>; + type DepNodeIndex = DepNodeIndex; + fn with_cached_task<T>(self, task: impl FnOnce() -> T) -> (T, DepNodeIndex) { + self.dep_graph.with_anon_task(self, crate::dep_graph::dep_kinds::TraitSelect, task) + } + + type EvaluationCache = &'tcx solve::EvaluationCache<'tcx>; + fn evaluation_cache(self, mode: SolverMode) -> &'tcx solve::EvaluationCache<'tcx> { + match mode { + SolverMode::Normal => &self.new_solver_evaluation_cache, + SolverMode::Coherence => &self.new_solver_coherence_evaluation_cache, + } + } + fn expand_abstract_consts<T: TypeFoldable<TyCtxt<'tcx>>>(self, t: T) -> T { self.expand_abstract_consts(t) } + fn mk_external_constraints( + self, + data: ExternalConstraintsData<Self>, + ) -> ExternalConstraints<'tcx> { + self.mk_external_constraints(data) + } + fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars { self.mk_canonical_var_infos(infos) } diff --git a/compiler/rustc_trait_selection/src/solve/mod.rs b/compiler/rustc_trait_selection/src/solve/mod.rs index 4f1be5cbc85..7b6e525370c 100644 --- a/compiler/rustc_trait_selection/src/solve/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/mod.rs @@ -14,12 +14,11 @@ //! FIXME(@lcnr): Write that section. If you read this before then ask me //! about it on zulip. use rustc_hir::def_id::DefId; -use rustc_infer::infer::canonical::{Canonical, CanonicalVarValues}; +use rustc_infer::infer::canonical::Canonical; use rustc_infer::infer::InferCtxt; use rustc_infer::traits::query::NoSolution; use rustc_macros::extension; use rustc_middle::bug; -use rustc_middle::infer::canonical::CanonicalVarInfos; use rustc_middle::traits::solve::{ CanonicalResponse, Certainty, ExternalConstraintsData, Goal, GoalSource, QueryResult, Response, }; @@ -27,6 +26,8 @@ use rustc_middle::ty::{ self, AliasRelationDirection, CoercePredicate, RegionOutlivesPredicate, SubtypePredicate, Ty, TyCtxt, TypeOutlivesPredicate, UniverseIndex, }; +use rustc_type_ir::solve::SolverMode; +use rustc_type_ir::{self as ir, Interner}; mod alias_relate; mod assembly; @@ -57,19 +58,6 @@ pub use select::InferCtxtSelectExt; /// recursion limit again. However, this feels very unlikely. const FIXPOINT_STEP_LIMIT: usize = 8; -#[derive(Debug, Clone, Copy)] -enum SolverMode { - /// Ordinary trait solving, using everywhere except for coherence. - Normal, - /// Trait solving during coherence. There are a few notable differences - /// between coherence and ordinary trait solving. - /// - /// Most importantly, trait solving during coherence must not be incomplete, - /// i.e. return `Err(NoSolution)` for goals for which a solution exists. - /// This means that we must not make any guesses or arbitrary choices. - Coherence, -} - #[derive(Debug, Copy, Clone, PartialEq, Eq)] enum GoalEvaluationKind { Root, @@ -314,17 +302,17 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> { } } -fn response_no_constraints_raw<'tcx>( - tcx: TyCtxt<'tcx>, +fn response_no_constraints_raw<I: Interner>( + tcx: I, max_universe: UniverseIndex, - variables: CanonicalVarInfos<'tcx>, + variables: I::CanonicalVars, certainty: Certainty, -) -> CanonicalResponse<'tcx> { - Canonical { +) -> ir::solve::CanonicalResponse<I> { + ir::Canonical { max_universe, variables, value: Response { - var_values: CanonicalVarValues::make_identity(tcx, variables), + var_values: ir::CanonicalVarValues::make_identity(tcx, variables), // FIXME: maybe we should store the "no response" version in tcx, like // we do for tcx.types and stuff. external_constraints: tcx.mk_external_constraints(ExternalConstraintsData::default()), diff --git a/compiler/rustc_trait_selection/src/solve/search_graph.rs b/compiler/rustc_trait_selection/src/solve/search_graph.rs index 84878fea101..681061c25aa 100644 --- a/compiler/rustc_trait_selection/src/solve/search_graph.rs +++ b/compiler/rustc_trait_selection/src/solve/search_graph.rs @@ -3,14 +3,11 @@ use std::mem; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_index::Idx; use rustc_index::IndexVec; -use rustc_infer::infer::InferCtxt; -use rustc_middle::dep_graph::dep_kinds; -use rustc_middle::traits::solve::CacheData; -use rustc_middle::traits::solve::EvaluationCache; -use rustc_middle::ty::TyCtxt; +use rustc_next_trait_solver::solve::CacheData; use rustc_next_trait_solver::solve::{CanonicalInput, Certainty, QueryResult}; use rustc_session::Limit; use rustc_type_ir::inherent::*; +use rustc_type_ir::InferCtxtLike; use rustc_type_ir::Interner; use super::inspect; @@ -240,34 +237,26 @@ impl<I: Interner> SearchGraph<I> { !entry.is_empty() }); } -} -impl<'tcx> SearchGraph<TyCtxt<'tcx>> { /// The trait solver behavior is different for coherence /// so we use a separate cache. Alternatively we could use /// a single cache and share it between coherence and ordinary /// trait solving. - pub(super) fn global_cache(&self, tcx: TyCtxt<'tcx>) -> &'tcx EvaluationCache<'tcx> { - match self.mode { - SolverMode::Normal => &tcx.new_solver_evaluation_cache, - SolverMode::Coherence => &tcx.new_solver_coherence_evaluation_cache, - } + pub(super) fn global_cache(&self, tcx: I) -> I::EvaluationCache { + tcx.evaluation_cache(self.mode) } /// Probably the most involved method of the whole solver. /// /// Given some goal which is proven via the `prove_goal` closure, this /// handles caching, overflow, and coinductive cycles. - pub(super) fn with_new_goal( + pub(super) fn with_new_goal<Infcx: InferCtxtLike<Interner = I>>( &mut self, - tcx: TyCtxt<'tcx>, - input: CanonicalInput<TyCtxt<'tcx>>, - inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>, - mut prove_goal: impl FnMut( - &mut Self, - &mut ProofTreeBuilder<InferCtxt<'tcx>>, - ) -> QueryResult<TyCtxt<'tcx>>, - ) -> QueryResult<TyCtxt<'tcx>> { + tcx: I, + input: CanonicalInput<I>, + inspect: &mut ProofTreeBuilder<Infcx>, + mut prove_goal: impl FnMut(&mut Self, &mut ProofTreeBuilder<Infcx>) -> QueryResult<I>, + ) -> QueryResult<I> { self.check_invariants(); // Check for overflow. let Some(available_depth) = Self::allowed_depth_for_nested(tcx, &self.stack) else { @@ -361,21 +350,20 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> { // not tracked by the cache key and from outside of this anon task, it // must not be added to the global cache. Notably, this is the case for // trait solver cycles participants. - let ((final_entry, result), dep_node) = - tcx.dep_graph.with_anon_task(tcx, dep_kinds::TraitSelect, || { - for _ in 0..FIXPOINT_STEP_LIMIT { - match self.fixpoint_step_in_task(tcx, input, inspect, &mut prove_goal) { - StepResult::Done(final_entry, result) => return (final_entry, result), - StepResult::HasChanged => debug!("fixpoint changed provisional results"), - } + let ((final_entry, result), dep_node) = tcx.with_cached_task(|| { + for _ in 0..FIXPOINT_STEP_LIMIT { + match self.fixpoint_step_in_task(tcx, input, inspect, &mut prove_goal) { + StepResult::Done(final_entry, result) => return (final_entry, result), + StepResult::HasChanged => debug!("fixpoint changed provisional results"), } + } - debug!("canonical cycle overflow"); - let current_entry = self.pop_stack(); - debug_assert!(current_entry.has_been_used.is_empty()); - let result = Self::response_no_constraints(tcx, input, Certainty::overflow(false)); - (current_entry, result) - }); + debug!("canonical cycle overflow"); + let current_entry = self.pop_stack(); + debug_assert!(current_entry.has_been_used.is_empty()); + let result = Self::response_no_constraints(tcx, input, Certainty::overflow(false)); + (current_entry, result) + }); let proof_tree = inspect.finalize_canonical_goal_evaluation(tcx); @@ -423,16 +411,17 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> { /// Try to fetch a previously computed result from the global cache, /// making sure to only do so if it would match the result of reevaluating /// this goal. - fn lookup_global_cache( + fn lookup_global_cache<Infcx: InferCtxtLike<Interner = I>>( &mut self, - tcx: TyCtxt<'tcx>, - input: CanonicalInput<TyCtxt<'tcx>>, + tcx: I, + input: CanonicalInput<I>, available_depth: Limit, - inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>, - ) -> Option<QueryResult<TyCtxt<'tcx>>> { + inspect: &mut ProofTreeBuilder<Infcx>, + ) -> Option<QueryResult<I>> { let CacheData { result, proof_tree, additional_depth, encountered_overflow } = self .global_cache(tcx) - .get(tcx, input, self.stack.iter().map(|e| e.input), available_depth)?; + // TODO: Awkward `Limit -> usize -> Limit`. + .get(tcx, input, self.stack.iter().map(|e| e.input), available_depth.0)?; // If we're building a proof tree and the current cache entry does not // contain a proof tree, we do not use the entry but instead recompute @@ -465,21 +454,22 @@ enum StepResult<I: Interner> { HasChanged, } -impl<'tcx> SearchGraph<TyCtxt<'tcx>> { +impl<I: Interner> SearchGraph<I> { /// When we encounter a coinductive cycle, we have to fetch the /// result of that cycle while we are still computing it. Because /// of this we continuously recompute the cycle until the result /// of the previous iteration is equal to the final result, at which /// point we are done. - fn fixpoint_step_in_task<F>( + fn fixpoint_step_in_task<Infcx, F>( &mut self, - tcx: TyCtxt<'tcx>, - input: CanonicalInput<TyCtxt<'tcx>>, - inspect: &mut ProofTreeBuilder<InferCtxt<'tcx>>, + tcx: I, + input: CanonicalInput<I>, + inspect: &mut ProofTreeBuilder<Infcx>, prove_goal: &mut F, - ) -> StepResult<TyCtxt<'tcx>> + ) -> StepResult<I> where - F: FnMut(&mut Self, &mut ProofTreeBuilder<InferCtxt<'tcx>>) -> QueryResult<TyCtxt<'tcx>>, + Infcx: InferCtxtLike<Interner = I>, + F: FnMut(&mut Self, &mut ProofTreeBuilder<Infcx>) -> QueryResult<I>, { let result = prove_goal(self, inspect); let stack_entry = self.pop_stack(); @@ -533,15 +523,13 @@ impl<'tcx> SearchGraph<TyCtxt<'tcx>> { } fn response_no_constraints( - tcx: TyCtxt<'tcx>, - goal: CanonicalInput<TyCtxt<'tcx>>, + tcx: I, + goal: CanonicalInput<I>, certainty: Certainty, - ) -> QueryResult<TyCtxt<'tcx>> { + ) -> QueryResult<I> { Ok(super::response_no_constraints_raw(tcx, goal.max_universe, goal.variables, certainty)) } -} -impl<I: Interner> SearchGraph<I> { #[allow(rustc::potential_query_instability)] fn check_invariants(&self) { if !cfg!(debug_assertions) { diff --git a/compiler/rustc_type_ir/src/inherent.rs b/compiler/rustc_type_ir/src/inherent.rs index 6b84592978a..4afb9a2339b 100644 --- a/compiler/rustc_type_ir/src/inherent.rs +++ b/compiler/rustc_type_ir/src/inherent.rs @@ -8,9 +8,11 @@ use std::hash::Hash; use std::ops::Deref; use rustc_ast_ir::Mutability; +use rustc_data_structures::fx::FxHashSet; use crate::fold::{TypeFoldable, TypeSuperFoldable}; use crate::relate::Relate; +use crate::solve::{CacheData, CanonicalInput, QueryResult}; use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable}; use crate::{self as ty, CollectAndApply, Interner, UpcastFrom}; @@ -363,3 +365,30 @@ pub trait Features<I: Interner>: Copy { fn coroutine_clone(self) -> bool; } + +pub trait EvaluationCache<I: Interner> { + /// Insert a final result into the global cache. + fn insert( + &self, + tcx: I, + key: CanonicalInput<I>, + proof_tree: Option<I::CanonicalGoalEvaluationStepRef>, + additional_depth: usize, + encountered_overflow: bool, + cycle_participants: FxHashSet<CanonicalInput<I>>, + dep_node: I::DepNodeIndex, + result: QueryResult<I>, + ); + + /// Try to fetch a cached result, checking the recursion limit + /// and handling root goals of coinductive cycles. + /// + /// If this returns `Some` the cache result can be used. + fn get( + &self, + tcx: I, + key: CanonicalInput<I>, + stack_entries: impl IntoIterator<Item = CanonicalInput<I>>, + available_depth: usize, + ) -> Option<CacheData<I>>; +} diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index 11c1f73fef3..b099f63d382 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -10,6 +10,7 @@ use crate::ir_print::IrPrint; use crate::lang_items::TraitSolverLangItem; use crate::relate::Relate; use crate::solve::inspect::CanonicalGoalEvaluationStep; +use crate::solve::{ExternalConstraintsData, SolverMode}; use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable}; use crate::{self as ty}; @@ -45,16 +46,26 @@ pub trait Interner: + Default; type BoundVarKind: Copy + Debug + Hash + Eq; - type CanonicalVars: Copy + Debug + Hash + Eq + IntoIterator<Item = ty::CanonicalVarInfo<Self>>; type PredefinedOpaques: Copy + Debug + Hash + Eq; type DefiningOpaqueTypes: Copy + Debug + Hash + Default + Eq + TypeVisitable<Self>; - type ExternalConstraints: Copy + Debug + Hash + Eq; type CanonicalGoalEvaluationStepRef: Copy + Debug + Hash + Eq + Deref<Target = CanonicalGoalEvaluationStep<Self>>; + type CanonicalVars: Copy + Debug + Hash + Eq + IntoIterator<Item = ty::CanonicalVarInfo<Self>>; + fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars; + + type ExternalConstraints: Copy + Debug + Hash + Eq; + fn mk_external_constraints( + self, + data: ExternalConstraintsData<Self>, + ) -> Self::ExternalConstraints; + + type DepNodeIndex; + fn with_cached_task<T>(self, task: impl FnOnce() -> T) -> (T, Self::DepNodeIndex); + // Kinds of tys type Ty: Ty<Self>; type Tys: Tys<Self>; @@ -97,9 +108,10 @@ pub trait Interner: type Clause: Clause<Self>; type Clauses: Copy + Debug + Hash + Eq + TypeSuperVisitable<Self> + Flags; - fn expand_abstract_consts<T: TypeFoldable<Self>>(self, t: T) -> T; + type EvaluationCache: EvaluationCache<Self>; + fn evaluation_cache(self, mode: SolverMode) -> Self::EvaluationCache; - fn mk_canonical_var_infos(self, infos: &[ty::CanonicalVarInfo<Self>]) -> Self::CanonicalVars; + fn expand_abstract_consts<T: TypeFoldable<Self>>(self, t: T) -> T; type GenericsOf: GenericsOf<Self>; fn generics_of(self, def_id: Self::DefId) -> Self::GenericsOf; diff --git a/compiler/rustc_type_ir/src/solve.rs b/compiler/rustc_type_ir/src/solve.rs index 99d2fa74494..fc4df7ede9d 100644 --- a/compiler/rustc_type_ir/src/solve.rs +++ b/compiler/rustc_type_ir/src/solve.rs @@ -57,6 +57,19 @@ pub enum Reveal { All, } +#[derive(Debug, Clone, Copy)] +pub enum SolverMode { + /// Ordinary trait solving, using everywhere except for coherence. + Normal, + /// Trait solving during coherence. There are a few notable differences + /// between coherence and ordinary trait solving. + /// + /// Most importantly, trait solving during coherence must not be incomplete, + /// i.e. return `Err(NoSolution)` for goals for which a solution exists. + /// This means that we must not make any guesses or arbitrary choices. + Coherence, +} + pub type CanonicalInput<I, T = <I as Interner>::Predicate> = Canonical<I, QueryInput<I, T>>; pub type CanonicalResponse<I> = Canonical<I, Response<I>>; /// The result of evaluating a canonical query. @@ -356,3 +369,12 @@ impl MaybeCause { } } } + +#[derive(derivative::Derivative)] +#[derivative(PartialEq(bound = ""), Eq(bound = ""), Debug(bound = ""))] +pub struct CacheData<I: Interner> { + pub result: QueryResult<I>, + pub proof_tree: Option<I::CanonicalGoalEvaluationStepRef>, + pub additional_depth: usize, + pub encountered_overflow: bool, +} |
