about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2025-05-29 12:34:24 +0000
committerMichael Goulet <michael@errs.io>2025-07-17 17:38:23 +0000
commit72bc11d14688599fbcaedf2be9175aa374e1c0d9 (patch)
tree5c046b0b180136252da4661ade3b2813646db1fb
parent96171dc78f23430a28ed8eb6a3879758d3d0d3d5 (diff)
downloadrust-72bc11d14688599fbcaedf2be9175aa374e1c0d9.tar.gz
rust-72bc11d14688599fbcaedf2be9175aa374e1c0d9.zip
Unstall obligations by looking for coroutines in old solver
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs51
-rw-r--r--compiler/rustc_trait_selection/src/solve/fulfill.rs16
-rw-r--r--compiler/rustc_trait_selection/src/traits/fulfill.rs39
3 files changed, 52 insertions, 54 deletions
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
index af1fc045ac8..0b3d50ff219 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
@@ -625,50 +625,23 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         // trigger query cycle ICEs, as doing so requires MIR.
         self.select_obligations_where_possible(|_| {});
 
-        let coroutines = std::mem::take(&mut *self.deferred_coroutine_interiors.borrow_mut());
-        debug!(?coroutines);
-
-        let mut obligations = vec![];
-
-        if !self.next_trait_solver() {
-            for &(coroutine_def_id, interior) in coroutines.iter() {
-                debug!(?coroutine_def_id);
-
-                // Create the `CoroutineWitness` type that we will unify with `interior`.
-                let args = ty::GenericArgs::identity_for_item(
-                    self.tcx,
-                    self.tcx.typeck_root_def_id(coroutine_def_id.to_def_id()),
-                );
-                let witness =
-                    Ty::new_coroutine_witness(self.tcx, coroutine_def_id.to_def_id(), args);
-
-                // Unify `interior` with `witness` and collect all the resulting obligations.
-                let span = self.tcx.hir_body_owned_by(coroutine_def_id).value.span;
-                let ty::Infer(ty::InferTy::TyVar(_)) = interior.kind() else {
-                    span_bug!(span, "coroutine interior witness not infer: {:?}", interior.kind())
-                };
-                let ok = self
-                    .at(&self.misc(span), self.param_env)
-                    // Will never define opaque types, as all we do is instantiate a type variable.
-                    .eq(DefineOpaqueTypes::Yes, interior, witness)
-                    .expect("Failed to unify coroutine interior type");
-
-                obligations.extend(ok.obligations);
-            }
-        }
+        let ty::TypingMode::Analysis { defining_opaque_types_and_generators } = self.typing_mode()
+        else {
+            bug!();
+        };
 
-        if !coroutines.is_empty() {
-            obligations.extend(
+        if defining_opaque_types_and_generators
+            .iter()
+            .any(|def_id| self.tcx.is_coroutine(def_id.to_def_id()))
+        {
+            self.typeck_results.borrow_mut().coroutine_stalled_predicates.extend(
                 self.fulfillment_cx
                     .borrow_mut()
-                    .drain_stalled_obligations_for_coroutines(&self.infcx),
+                    .drain_stalled_obligations_for_coroutines(&self.infcx)
+                    .into_iter()
+                    .map(|o| (o.predicate, o.cause)),
             );
         }
-
-        self.typeck_results
-            .borrow_mut()
-            .coroutine_stalled_predicates
-            .extend(obligations.into_iter().map(|o| (o.predicate, o.cause)));
     }
 
     #[instrument(skip(self), level = "debug")]
diff --git a/compiler/rustc_trait_selection/src/solve/fulfill.rs b/compiler/rustc_trait_selection/src/solve/fulfill.rs
index 72770535b3e..3ce0f025512 100644
--- a/compiler/rustc_trait_selection/src/solve/fulfill.rs
+++ b/compiler/rustc_trait_selection/src/solve/fulfill.rs
@@ -255,7 +255,7 @@ where
         &mut self,
         infcx: &InferCtxt<'tcx>,
     ) -> PredicateObligations<'tcx> {
-        let stalled_generators = match infcx.typing_mode() {
+        let stalled_coroutines = match infcx.typing_mode() {
             TypingMode::Analysis { defining_opaque_types_and_generators } => {
                 defining_opaque_types_and_generators
             }
@@ -265,7 +265,7 @@ where
             | TypingMode::PostAnalysis => return Default::default(),
         };
 
-        if stalled_generators.is_empty() {
+        if stalled_coroutines.is_empty() {
             return Default::default();
         }
 
@@ -276,7 +276,7 @@ where
                         .visit_proof_tree(
                             obl.as_goal(),
                             &mut StalledOnCoroutines {
-                                stalled_generators,
+                                stalled_coroutines,
                                 span: obl.cause.span,
                                 cache: Default::default(),
                             },
@@ -298,10 +298,10 @@ where
 ///
 /// This function can be also return false positives, which will lead to poor diagnostics
 /// so we want to keep this visitor *precise* too.
-struct StalledOnCoroutines<'tcx> {
-    stalled_generators: &'tcx ty::List<LocalDefId>,
-    span: Span,
-    cache: DelayedSet<Ty<'tcx>>,
+pub struct StalledOnCoroutines<'tcx> {
+    pub stalled_coroutines: &'tcx ty::List<LocalDefId>,
+    pub span: Span,
+    pub cache: DelayedSet<Ty<'tcx>>,
 }
 
 impl<'tcx> inspect::ProofTreeVisitor<'tcx> for StalledOnCoroutines<'tcx> {
@@ -331,7 +331,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for StalledOnCoroutines<'tcx> {
         }
 
         if let ty::CoroutineWitness(def_id, _) = *ty.kind()
-            && def_id.as_local().is_some_and(|def_id| self.stalled_generators.contains(&def_id))
+            && def_id.as_local().is_some_and(|def_id| self.stalled_coroutines.contains(&def_id))
         {
             ControlFlow::Break(())
         } else if ty.has_coroutines() {
diff --git a/compiler/rustc_trait_selection/src/traits/fulfill.rs b/compiler/rustc_trait_selection/src/traits/fulfill.rs
index 2b5a41ef5a7..c6c68b1c401 100644
--- a/compiler/rustc_trait_selection/src/traits/fulfill.rs
+++ b/compiler/rustc_trait_selection/src/traits/fulfill.rs
@@ -3,6 +3,7 @@ use std::marker::PhantomData;
 use rustc_data_structures::obligation_forest::{
     Error, ForestObligation, ObligationForest, ObligationProcessor, Outcome, ProcessResult,
 };
+use rustc_hir::def_id::LocalDefId;
 use rustc_infer::infer::DefineOpaqueTypes;
 use rustc_infer::traits::{
     FromSolverError, PolyTraitObligation, PredicateObligations, ProjectionCacheKey, SelectionError,
@@ -12,8 +13,9 @@ use rustc_middle::bug;
 use rustc_middle::ty::abstract_const::NotConstEvaluatable;
 use rustc_middle::ty::error::{ExpectedFound, TypeError};
 use rustc_middle::ty::{
-    self, Binder, Const, GenericArgsRef, TypeVisitableExt, TypingMode, may_use_unstable_feature,
+    self, Binder, Const, GenericArgsRef, TypeVisitable, TypeVisitableExt, TypingMode,
 };
+use rustc_span::DUMMY_SP;
 use thin_vec::{ThinVec, thin_vec};
 use tracing::{debug, debug_span, instrument};
 
@@ -26,6 +28,7 @@ use super::{
 };
 use crate::error_reporting::InferCtxtErrorExt;
 use crate::infer::{InferCtxt, TyOrConstInferVar};
+use crate::solve::StalledOnCoroutines;
 use crate::traits::normalize::normalize_with_depth_to;
 use crate::traits::project::{PolyProjectionObligation, ProjectionCacheKeyExt as _};
 use crate::traits::query::evaluate_obligation::InferCtxtExt;
@@ -168,8 +171,25 @@ where
         &mut self,
         infcx: &InferCtxt<'tcx>,
     ) -> PredicateObligations<'tcx> {
-        let mut processor =
-            DrainProcessor { removed_predicates: PredicateObligations::new(), infcx };
+        let stalled_coroutines = match infcx.typing_mode() {
+            TypingMode::Analysis { defining_opaque_types_and_generators } => {
+                defining_opaque_types_and_generators
+            }
+            TypingMode::Coherence
+            | TypingMode::Borrowck { defining_opaque_types: _ }
+            | TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ }
+            | TypingMode::PostAnalysis => return Default::default(),
+        };
+
+        if stalled_coroutines.is_empty() {
+            return Default::default();
+        }
+
+        let mut processor = DrainProcessor {
+            infcx,
+            removed_predicates: PredicateObligations::new(),
+            stalled_coroutines,
+        };
         let outcome: Outcome<_, _> = self.predicates.process_obligations(&mut processor);
         assert!(outcome.errors.is_empty());
         return processor.removed_predicates;
@@ -177,6 +197,7 @@ where
         struct DrainProcessor<'a, 'tcx> {
             infcx: &'a InferCtxt<'tcx>,
             removed_predicates: PredicateObligations<'tcx>,
+            stalled_coroutines: &'tcx ty::List<LocalDefId>,
         }
 
         impl<'tcx> ObligationProcessor for DrainProcessor<'_, 'tcx> {
@@ -185,10 +206,14 @@ where
             type OUT = Outcome<Self::Obligation, Self::Error>;
 
             fn needs_process_obligation(&self, pending_obligation: &Self::Obligation) -> bool {
-                pending_obligation
-                    .stalled_on
-                    .iter()
-                    .any(|&var| self.infcx.ty_or_const_infer_var_changed(var))
+                self.infcx
+                    .resolve_vars_if_possible(pending_obligation.obligation.predicate)
+                    .visit_with(&mut StalledOnCoroutines {
+                        stalled_coroutines: self.stalled_coroutines,
+                        span: DUMMY_SP,
+                        cache: Default::default(),
+                    })
+                    .is_break()
             }
 
             fn process_obligation(