about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2025-04-22 23:31:22 +0000
committerMichael Goulet <michael@errs.io>2025-04-23 15:09:25 +0000
commitf943f73db4791d64ff83d72986da8d6250c42933 (patch)
tree4474dcc7ba139e0d9f70a7754c31694fcf7186ee /compiler/rustc_trait_selection
parent7c1661f9457825df6e6bbf4869be3cad59b608a9 (diff)
downloadrust-f943f73db4791d64ff83d72986da8d6250c42933.tar.gz
rust-f943f73db4791d64ff83d72986da8d6250c42933.zip
More
Diffstat (limited to 'compiler/rustc_trait_selection')
-rw-r--r--compiler/rustc_trait_selection/src/solve/fulfill.rs47
-rw-r--r--compiler/rustc_trait_selection/src/solve/normalize.rs42
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/mod.rs4
3 files changed, 62 insertions, 31 deletions
diff --git a/compiler/rustc_trait_selection/src/solve/fulfill.rs b/compiler/rustc_trait_selection/src/solve/fulfill.rs
index abee5ac52c1..848d0646d00 100644
--- a/compiler/rustc_trait_selection/src/solve/fulfill.rs
+++ b/compiler/rustc_trait_selection/src/solve/fulfill.rs
@@ -14,6 +14,7 @@ use rustc_middle::ty::{
 };
 use rustc_next_trait_solver::solve::{GenerateProofTree, HasChanged, SolverDelegateEvalExt as _};
 use rustc_span::Span;
+use rustc_type_ir::data_structures::DelayedSet;
 use tracing::instrument;
 
 use self::derive_errors::*;
@@ -217,26 +218,30 @@ where
         &mut self,
         infcx: &InferCtxt<'tcx>,
     ) -> PredicateObligations<'tcx> {
-        self.obligations.drain_pending(|obl| {
-            let stalled_generators = match infcx.typing_mode() {
-                TypingMode::Analysis { defining_opaque_types: _, stalled_generators } => {
-                    stalled_generators
-                }
-                TypingMode::Coherence
-                | TypingMode::Borrowck { defining_opaque_types: _ }
-                | TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ }
-                | TypingMode::PostAnalysis => return false,
-            };
-
-            if stalled_generators.is_empty() {
-                return false;
+        let stalled_generators = match infcx.typing_mode() {
+            TypingMode::Analysis { defining_opaque_types_and_generators } => {
+                defining_opaque_types_and_generators
             }
+            TypingMode::Coherence
+            | TypingMode::Borrowck { defining_opaque_types: _ }
+            | TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ }
+            | TypingMode::PostAnalysis => return Default::default(),
+        };
+
+        if stalled_generators.is_empty() {
+            return Default::default();
+        }
 
+        self.obligations.drain_pending(|obl| {
             infcx.probe(|_| {
                 infcx
                     .visit_proof_tree(
                         obl.as_goal(),
-                        &mut StalledOnCoroutines { stalled_generators, span: obl.cause.span },
+                        &mut StalledOnCoroutines {
+                            stalled_generators,
+                            span: obl.cause.span,
+                            cache: Default::default(),
+                        },
                     )
                     .is_break()
             })
@@ -244,10 +249,18 @@ where
     }
 }
 
+/// Detect if a goal is stalled on a coroutine that is owned by the current typeck root.
+///
+/// This function can (erroneously) fail to detect a predicate, i.e. it doesn't need to
+/// be complete. However, this will lead to ambiguity errors, so we want to make it
+/// accurate.
+///
+/// This function can be also return false positives, which will lead to poor diagnostics
+/// so we want to keep this visitor *precise* too.
 struct StalledOnCoroutines<'tcx> {
     stalled_generators: &'tcx ty::List<LocalDefId>,
     span: Span,
-    // TODO: Cache
+    cache: DelayedSet<Ty<'tcx>>,
 }
 
 impl<'tcx> inspect::ProofTreeVisitor<'tcx> for StalledOnCoroutines<'tcx> {
@@ -272,6 +285,10 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for StalledOnCoroutines<'tcx> {
     type Result = ControlFlow<()>;
 
     fn visit_ty(&mut self, ty: Ty<'tcx>) -> Self::Result {
+        if !self.cache.insert(ty) {
+            return ControlFlow::Continue(());
+        }
+
         if let ty::CoroutineWitness(def_id, _) = *ty.kind()
             && def_id.as_local().is_some_and(|def_id| self.stalled_generators.contains(&def_id))
         {
diff --git a/compiler/rustc_trait_selection/src/solve/normalize.rs b/compiler/rustc_trait_selection/src/solve/normalize.rs
index 65ab14ae07c..5f1e63ab225 100644
--- a/compiler/rustc_trait_selection/src/solve/normalize.rs
+++ b/compiler/rustc_trait_selection/src/solve/normalize.rs
@@ -1,6 +1,5 @@
 use std::assert_matches::assert_matches;
 use std::fmt::Debug;
-use std::marker::PhantomData;
 
 use rustc_data_structures::stack::ensure_sufficient_stack;
 use rustc_infer::infer::InferCtxt;
@@ -60,7 +59,8 @@ where
 /// entered before passing `value` to the function. This is currently needed for
 /// `normalize_erasing_regions`, which skips binders as it walks through a type.
 ///
-/// TODO: doc
+/// This returns a set of stalled obligations if the typing mode of the underlying infcx
+/// has any stalled coroutine def ids.
 pub fn deeply_normalize_with_skipped_universes_and_ambiguous_goals<'tcx, T, E>(
     at: At<'_, 'tcx>,
     value: T,
@@ -72,16 +72,10 @@ where
 {
     let fulfill_cx = FulfillmentCtxt::new(at.infcx);
     let mut folder =
-        NormalizationFolder { at, fulfill_cx, depth: 0, universes, _errors: PhantomData };
+        NormalizationFolder { at, fulfill_cx, depth: 0, universes, stalled_goals: vec![] };
     let value = value.try_fold_with(&mut folder)?;
-    let goals = folder
-        .fulfill_cx
-        .drain_stalled_obligations_for_coroutines(at.infcx)
-        .into_iter()
-        .map(|obl| obl.as_goal())
-        .collect();
     let errors = folder.fulfill_cx.select_all_or_error(at.infcx);
-    if errors.is_empty() { Ok((value, goals)) } else { Err(errors) }
+    if errors.is_empty() { Ok((value, folder.stalled_goals)) } else { Err(errors) }
 }
 
 struct NormalizationFolder<'me, 'tcx, E> {
@@ -89,7 +83,7 @@ struct NormalizationFolder<'me, 'tcx, E> {
     fulfill_cx: FulfillmentCtxt<'tcx, E>,
     depth: usize,
     universes: Vec<Option<UniverseIndex>>,
-    _errors: PhantomData<E>,
+    stalled_goals: Vec<Goal<'tcx, ty::Predicate<'tcx>>>,
 }
 
 impl<'tcx, E> NormalizationFolder<'_, 'tcx, E>
@@ -130,10 +124,7 @@ where
         );
 
         self.fulfill_cx.register_predicate_obligation(infcx, obligation);
-        let errors = self.fulfill_cx.select_where_possible(infcx);
-        if !errors.is_empty() {
-            return Err(errors);
-        }
+        self.select_all_and_stall_coroutine_predicates()?;
 
         // Alias is guaranteed to be fully structurally resolved,
         // so we can super fold here.
@@ -184,6 +175,27 @@ where
         self.depth -= 1;
         Ok(result)
     }
+
+    fn select_all_and_stall_coroutine_predicates(&mut self) -> Result<(), Vec<E>> {
+        let errors = self.fulfill_cx.select_where_possible(self.at.infcx);
+        if !errors.is_empty() {
+            return Err(errors);
+        }
+
+        self.stalled_goals.extend(
+            self.fulfill_cx
+                .drain_stalled_obligations_for_coroutines(self.at.infcx)
+                .into_iter()
+                .map(|obl| obl.as_goal()),
+        );
+
+        let errors = self.fulfill_cx.collect_remaining_errors(self.at.infcx);
+        if !errors.is_empty() {
+            return Err(errors);
+        }
+
+        Ok(())
+    }
 }
 
 impl<'tcx, E> FallibleTypeFolder<TyCtxt<'tcx>> for NormalizationFolder<'_, 'tcx, E>
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index 5255b57c791..c7ce13c8014 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -1498,7 +1498,9 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
             // However, if we disqualify *all* goals from being cached, perf suffers.
             // This is likely fixed by better caching in general in the new solver.
             // See: <https://github.com/rust-lang/rust/issues/132064>.
-            TypingMode::Analysis { defining_opaque_types, stalled_generators: _ }
+            TypingMode::Analysis {
+                defining_opaque_types_and_generators: defining_opaque_types,
+            }
             | TypingMode::Borrowck { defining_opaque_types } => {
                 defining_opaque_types.is_empty() || !pred.has_opaque_types()
             }