about summary refs log tree commit diff
diff options
context:
space:
mode:
authorlcnr <rust@lcnr.de>2024-10-10 09:26:49 +0000
committerMichael Goulet <michael@errs.io>2024-10-10 06:09:50 -0400
commitd6fd45c2e3ee24bedbc1b7643b622ab97f29537f (patch)
tree7cf49fbf1f867c4b5e86b581d14ee5fd1d42f8cb
parenta1eceec00b2684f947481696ae2322e20d59db60 (diff)
downloadrust-d6fd45c2e3ee24bedbc1b7643b622ab97f29537f.tar.gz
rust-d6fd45c2e3ee24bedbc1b7643b622ab97f29537f.zip
impossible obligations check fast path
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs72
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/mod.rs8
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/search_graph.rs3
-rw-r--r--compiler/rustc_trait_selection/src/solve.rs1
-rw-r--r--compiler/rustc_trait_selection/src/solve/fulfill.rs15
-rw-r--r--compiler/rustc_trait_selection/src/traits/coherence.rs13
-rw-r--r--compiler/rustc_type_ir/src/search_graph/mod.rs12
-rw-r--r--tests/crashes/124894.rs11
8 files changed, 90 insertions, 45 deletions
diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
index ffa800348f2..bc324dcbf51 100644
--- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
@@ -18,8 +18,8 @@ use crate::solve::inspect::{self, ProofTreeBuilder};
 use crate::solve::search_graph::SearchGraph;
 use crate::solve::{
     CanonicalInput, CanonicalResponse, Certainty, FIXPOINT_STEP_LIMIT, Goal, GoalEvaluationKind,
-    GoalSource, NestedNormalizationGoals, NoSolution, PredefinedOpaquesData, QueryResult,
-    SolverMode,
+    GoalSource, HasChanged, NestedNormalizationGoals, NoSolution, PredefinedOpaquesData,
+    QueryResult, SolverMode,
 };
 
 pub(super) mod canonical;
@@ -126,11 +126,31 @@ pub enum GenerateProofTree {
 }
 
 pub trait SolverDelegateEvalExt: SolverDelegate {
+    /// Evaluates a goal from **outside** of the trait solver.
+    ///
+    /// Using this while inside of the solver is wrong as it uses a new
+    /// search graph which would break cycle detection.
     fn evaluate_root_goal(
         &self,
         goal: Goal<Self::Interner, <Self::Interner as Interner>::Predicate>,
         generate_proof_tree: GenerateProofTree,
-    ) -> (Result<(bool, Certainty), NoSolution>, Option<inspect::GoalEvaluation<Self::Interner>>);
+    ) -> (
+        Result<(HasChanged, Certainty), NoSolution>,
+        Option<inspect::GoalEvaluation<Self::Interner>>,
+    );
+
+    /// Check whether evaluating `goal` with a depth of `root_depth` may
+    /// succeed. This only returns `false` if the goal is guaranteed to
+    /// not hold. In case evaluation overflows and fails with ambiguity this
+    /// returns `true`.
+    ///
+    /// This is only intended to be used as a performance optimization
+    /// in coherence checking.
+    fn root_goal_may_hold_with_depth(
+        &self,
+        root_depth: usize,
+        goal: Goal<Self::Interner, <Self::Interner as Interner>::Predicate>,
+    ) -> bool;
 
     // FIXME: This is only exposed because we need to use it in `analyse.rs`
     // which is not yet uplifted. Once that's done, we should remove this.
@@ -139,7 +159,7 @@ pub trait SolverDelegateEvalExt: SolverDelegate {
         goal: Goal<Self::Interner, <Self::Interner as Interner>::Predicate>,
         generate_proof_tree: GenerateProofTree,
     ) -> (
-        Result<(NestedNormalizationGoals<Self::Interner>, bool, Certainty), NoSolution>,
+        Result<(NestedNormalizationGoals<Self::Interner>, HasChanged, Certainty), NoSolution>,
         Option<inspect::GoalEvaluation<Self::Interner>>,
     );
 }
@@ -149,31 +169,41 @@ where
     D: SolverDelegate<Interner = I>,
     I: Interner,
 {
-    /// Evaluates a goal from **outside** of the trait solver.
-    ///
-    /// Using this while inside of the solver is wrong as it uses a new
-    /// search graph which would break cycle detection.
     #[instrument(level = "debug", skip(self))]
     fn evaluate_root_goal(
         &self,
         goal: Goal<I, I::Predicate>,
         generate_proof_tree: GenerateProofTree,
-    ) -> (Result<(bool, Certainty), NoSolution>, Option<inspect::GoalEvaluation<I>>) {
-        EvalCtxt::enter_root(self, generate_proof_tree, |ecx| {
+    ) -> (Result<(HasChanged, Certainty), NoSolution>, Option<inspect::GoalEvaluation<I>>) {
+        EvalCtxt::enter_root(self, self.cx().recursion_limit(), generate_proof_tree, |ecx| {
             ecx.evaluate_goal(GoalEvaluationKind::Root, GoalSource::Misc, goal)
         })
     }
 
+    fn root_goal_may_hold_with_depth(
+        &self,
+        root_depth: usize,
+        goal: Goal<Self::Interner, <Self::Interner as Interner>::Predicate>,
+    ) -> bool {
+        self.probe(|| {
+            EvalCtxt::enter_root(self, root_depth, GenerateProofTree::No, |ecx| {
+                ecx.evaluate_goal(GoalEvaluationKind::Root, GoalSource::Misc, goal)
+            })
+            .0
+        })
+        .is_ok()
+    }
+
     #[instrument(level = "debug", skip(self))]
     fn evaluate_root_goal_raw(
         &self,
         goal: Goal<I, I::Predicate>,
         generate_proof_tree: GenerateProofTree,
     ) -> (
-        Result<(NestedNormalizationGoals<I>, bool, Certainty), NoSolution>,
+        Result<(NestedNormalizationGoals<I>, HasChanged, Certainty), NoSolution>,
         Option<inspect::GoalEvaluation<I>>,
     ) {
-        EvalCtxt::enter_root(self, generate_proof_tree, |ecx| {
+        EvalCtxt::enter_root(self, self.cx().recursion_limit(), generate_proof_tree, |ecx| {
             ecx.evaluate_goal_raw(GoalEvaluationKind::Root, GoalSource::Misc, goal)
         })
     }
@@ -197,10 +227,11 @@ where
     /// over using this manually (such as [`SolverDelegateEvalExt::evaluate_root_goal`]).
     pub(super) fn enter_root<R>(
         delegate: &D,
+        root_depth: usize,
         generate_proof_tree: GenerateProofTree,
         f: impl FnOnce(&mut EvalCtxt<'_, D>) -> R,
     ) -> (R, Option<inspect::GoalEvaluation<I>>) {
-        let mut search_graph = SearchGraph::new(delegate.solver_mode());
+        let mut search_graph = SearchGraph::new(delegate.solver_mode(), root_depth);
 
         let mut ecx = EvalCtxt {
             delegate,
@@ -339,7 +370,7 @@ where
         goal_evaluation_kind: GoalEvaluationKind,
         source: GoalSource,
         goal: Goal<I, I::Predicate>,
-    ) -> Result<(bool, Certainty), NoSolution> {
+    ) -> Result<(HasChanged, Certainty), NoSolution> {
         let (normalization_nested_goals, has_changed, certainty) =
             self.evaluate_goal_raw(goal_evaluation_kind, source, goal)?;
         assert!(normalization_nested_goals.is_empty());
@@ -360,7 +391,7 @@ where
         goal_evaluation_kind: GoalEvaluationKind,
         _source: GoalSource,
         goal: Goal<I, I::Predicate>,
-    ) -> Result<(NestedNormalizationGoals<I>, bool, Certainty), NoSolution> {
+    ) -> Result<(NestedNormalizationGoals<I>, HasChanged, Certainty), NoSolution> {
         let (orig_values, canonical_goal) = self.canonicalize_goal(goal);
         let mut goal_evaluation =
             self.inspect.new_goal_evaluation(goal, &orig_values, goal_evaluation_kind);
@@ -378,8 +409,13 @@ where
             Ok(response) => response,
         };
 
-        let has_changed = !response.value.var_values.is_identity_modulo_regions()
-            || !response.value.external_constraints.opaque_types.is_empty();
+        let has_changed = if !response.value.var_values.is_identity_modulo_regions()
+            || !response.value.external_constraints.opaque_types.is_empty()
+        {
+            HasChanged::Yes
+        } else {
+            HasChanged::No
+        };
 
         let (normalization_nested_goals, certainty) =
             self.instantiate_and_apply_query_response(goal.param_env, orig_values, response);
@@ -552,7 +588,7 @@ where
         for (source, goal) in goals.goals {
             let (has_changed, certainty) =
                 self.evaluate_goal(GoalEvaluationKind::Nested, source, goal)?;
-            if has_changed {
+            if has_changed == HasChanged::Yes {
                 unchanged_certainty = None;
             }
 
diff --git a/compiler/rustc_next_trait_solver/src/solve/mod.rs b/compiler/rustc_next_trait_solver/src/solve/mod.rs
index 309ab7f28d1..97f7c71f3fc 100644
--- a/compiler/rustc_next_trait_solver/src/solve/mod.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/mod.rs
@@ -48,6 +48,14 @@ enum GoalEvaluationKind {
     Nested,
 }
 
+/// Whether evaluating this goal ended up changing the
+/// inference state.
+#[derive(PartialEq, Eq, Debug, Hash, Clone, Copy)]
+pub enum HasChanged {
+    Yes,
+    No,
+}
+
 // FIXME(trait-system-refactor-initiative#117): we don't detect whether a response
 // ended up pulling down any universes.
 fn has_no_inference_or_external_constraints<I: Interner>(
diff --git a/compiler/rustc_next_trait_solver/src/solve/search_graph.rs b/compiler/rustc_next_trait_solver/src/solve/search_graph.rs
index e47cc03f5ad..0e3f179b0c8 100644
--- a/compiler/rustc_next_trait_solver/src/solve/search_graph.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/search_graph.rs
@@ -40,9 +40,6 @@ where
     }
 
     const DIVIDE_AVAILABLE_DEPTH_ON_OVERFLOW: usize = 4;
-    fn recursion_limit(cx: I) -> usize {
-        cx.recursion_limit()
-    }
 
     fn initial_provisional_result(
         cx: I,
diff --git a/compiler/rustc_trait_selection/src/solve.rs b/compiler/rustc_trait_selection/src/solve.rs
index e47f5389cd1..d425ab50ae0 100644
--- a/compiler/rustc_trait_selection/src/solve.rs
+++ b/compiler/rustc_trait_selection/src/solve.rs
@@ -6,6 +6,7 @@ pub mod inspect;
 mod normalize;
 mod select;
 
+pub(crate) use delegate::SolverDelegate;
 pub use fulfill::{FulfillmentCtxt, NextSolverError};
 pub(crate) use normalize::deeply_normalize_for_diagnostics;
 pub use normalize::{deeply_normalize, deeply_normalize_with_skipped_universes};
diff --git a/compiler/rustc_trait_selection/src/solve/fulfill.rs b/compiler/rustc_trait_selection/src/solve/fulfill.rs
index c6e3ba3c957..081d7a6a769 100644
--- a/compiler/rustc_trait_selection/src/solve/fulfill.rs
+++ b/compiler/rustc_trait_selection/src/solve/fulfill.rs
@@ -12,7 +12,7 @@ use rustc_infer::traits::{
 use rustc_middle::bug;
 use rustc_middle::ty::error::{ExpectedFound, TypeError};
 use rustc_middle::ty::{self, TyCtxt};
-use rustc_next_trait_solver::solve::{GenerateProofTree, SolverDelegateEvalExt as _};
+use rustc_next_trait_solver::solve::{GenerateProofTree, HasChanged, SolverDelegateEvalExt as _};
 use tracing::instrument;
 
 use super::Certainty;
@@ -86,10 +86,7 @@ impl<'tcx> ObligationStorage<'tcx> {
                 let result = <&SolverDelegate<'tcx>>::from(infcx)
                     .evaluate_root_goal(goal, GenerateProofTree::No)
                     .0;
-                match result {
-                    Ok((has_changed, _)) => has_changed,
-                    _ => false,
-                }
+                matches!(result, Ok((HasChanged::Yes, _)))
             }));
         })
     }
@@ -113,7 +110,7 @@ impl<'tcx, E: 'tcx> FulfillmentCtxt<'tcx, E> {
         &self,
         infcx: &InferCtxt<'tcx>,
         obligation: &PredicateObligation<'tcx>,
-        result: &Result<(bool, Certainty), NoSolution>,
+        result: &Result<(HasChanged, Certainty), NoSolution>,
     ) {
         if let Some(inspector) = infcx.obligation_inspector.get() {
             let result = match result {
@@ -181,7 +178,11 @@ where
                         continue;
                     }
                 };
-                has_changed |= changed;
+
+                if changed == HasChanged::Yes {
+                    has_changed = true;
+                }
+
                 match certainty {
                     Certainty::Yes => {}
                     Certainty::Maybe(_) => self.obligations.register(obligation),
diff --git a/compiler/rustc_trait_selection/src/traits/coherence.rs b/compiler/rustc_trait_selection/src/traits/coherence.rs
index 27d2a3c15b9..b29e41beab5 100644
--- a/compiler/rustc_trait_selection/src/traits/coherence.rs
+++ b/compiler/rustc_trait_selection/src/traits/coherence.rs
@@ -19,6 +19,7 @@ use rustc_middle::ty::fast_reject::DeepRejectCtxt;
 use rustc_middle::ty::visit::{TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor};
 use rustc_middle::ty::{self, Ty, TyCtxt};
 pub use rustc_next_trait_solver::coherence::*;
+use rustc_next_trait_solver::solve::SolverDelegateEvalExt;
 use rustc_span::symbol::sym;
 use rustc_span::{DUMMY_SP, Span};
 use tracing::{debug, instrument, warn};
@@ -28,7 +29,7 @@ use crate::error_reporting::traits::suggest_new_overflow_limit;
 use crate::infer::InferOk;
 use crate::infer::outlives::env::OutlivesEnvironment;
 use crate::solve::inspect::{InspectGoal, ProofTreeInferCtxtExt, ProofTreeVisitor};
-use crate::solve::{deeply_normalize_for_diagnostics, inspect};
+use crate::solve::{SolverDelegate, deeply_normalize_for_diagnostics, inspect};
 use crate::traits::query::evaluate_obligation::InferCtxtExt;
 use crate::traits::select::IntercrateAmbiguityCause;
 use crate::traits::{
@@ -333,6 +334,16 @@ fn impl_intersection_has_impossible_obligation<'a, 'cx, 'tcx>(
     let infcx = selcx.infcx;
 
     if infcx.next_trait_solver() {
+        // A fast path optimization, try evaluating all goals with
+        // a very low recursion depth and bail if any of them don't
+        // hold.
+        if !obligations.iter().all(|o| {
+            <&SolverDelegate<'tcx>>::from(infcx)
+                .root_goal_may_hold_with_depth(8, Goal::new(infcx.tcx, o.param_env, o.predicate))
+        }) {
+            return IntersectionHasImpossibleObligations::Yes;
+        }
+
         let ocx = ObligationCtxt::new_with_diagnostics(infcx);
         ocx.register_obligations(obligations.iter().cloned());
         let errors_and_ambiguities = ocx.select_all_or_error();
diff --git a/compiler/rustc_type_ir/src/search_graph/mod.rs b/compiler/rustc_type_ir/src/search_graph/mod.rs
index ac4d0795a92..f4fb03562de 100644
--- a/compiler/rustc_type_ir/src/search_graph/mod.rs
+++ b/compiler/rustc_type_ir/src/search_graph/mod.rs
@@ -81,7 +81,6 @@ pub trait Delegate {
     fn inspect_is_noop(inspect: &mut Self::ProofTreeBuilder) -> bool;
 
     const DIVIDE_AVAILABLE_DEPTH_ON_OVERFLOW: usize;
-    fn recursion_limit(cx: Self::Cx) -> usize;
 
     fn initial_provisional_result(
         cx: Self::Cx,
@@ -156,7 +155,7 @@ impl AvailableDepth {
     /// the remaining depth of all nested goals to prevent hangs
     /// in case there is exponential blowup.
     fn allowed_depth_for_nested<D: Delegate>(
-        cx: D::Cx,
+        root_depth: AvailableDepth,
         stack: &IndexVec<StackDepth, StackEntry<D::Cx>>,
     ) -> Option<AvailableDepth> {
         if let Some(last) = stack.raw.last() {
@@ -170,7 +169,7 @@ impl AvailableDepth {
                 AvailableDepth(last.available_depth.0 - 1)
             })
         } else {
-            Some(AvailableDepth(D::recursion_limit(cx)))
+            Some(root_depth)
         }
     }
 
@@ -360,6 +359,7 @@ struct ProvisionalCacheEntry<X: Cx> {
 
 pub struct SearchGraph<D: Delegate<Cx = X>, X: Cx = <D as Delegate>::Cx> {
     mode: SolverMode,
+    root_depth: AvailableDepth,
     /// The stack of goals currently being computed.
     ///
     /// An element is *deeper* in the stack if its index is *lower*.
@@ -374,9 +374,10 @@ pub struct SearchGraph<D: Delegate<Cx = X>, X: Cx = <D as Delegate>::Cx> {
 }
 
 impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
-    pub fn new(mode: SolverMode) -> SearchGraph<D> {
+    pub fn new(mode: SolverMode, root_depth: usize) -> SearchGraph<D> {
         Self {
             mode,
+            root_depth: AvailableDepth(root_depth),
             stack: Default::default(),
             provisional_cache: Default::default(),
             _marker: PhantomData,
@@ -460,7 +461,8 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
         inspect: &mut D::ProofTreeBuilder,
         mut evaluate_goal: impl FnMut(&mut Self, &mut D::ProofTreeBuilder) -> X::Result,
     ) -> X::Result {
-        let Some(available_depth) = AvailableDepth::allowed_depth_for_nested::<D>(cx, &self.stack)
+        let Some(available_depth) =
+            AvailableDepth::allowed_depth_for_nested::<D>(self.root_depth, &self.stack)
         else {
             return self.handle_overflow(cx, input, inspect);
         };
diff --git a/tests/crashes/124894.rs b/tests/crashes/124894.rs
deleted file mode 100644
index 230cf4a89c1..00000000000
--- a/tests/crashes/124894.rs
+++ /dev/null
@@ -1,11 +0,0 @@
-//@ known-bug: rust-lang/rust#124894
-//@ compile-flags: -Znext-solver=coherence
-
-#![feature(generic_const_exprs)]
-
-pub trait IsTrue<const mem: bool> {}
-impl<T> IsZST for T where (): IsTrue<{ std::mem::size_of::<T>() == 0 }> {}
-
-pub trait IsZST {}
-
-impl IsZST for IsZST {}