about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2025-07-09 23:57:41 +0000
committerMichael Goulet <michael@errs.io>2025-07-15 16:02:26 +0000
commit3e7dfaa51027d33f3613ed0203df495784b21f19 (patch)
tree63a90b592fbc8232615a216e994ea651091df32a
parent7976ca2442db9b69a7949b5fd6afcf8894441afa (diff)
downloadrust-3e7dfaa51027d33f3613ed0203df495784b21f19.tar.gz
rust-3e7dfaa51027d33f3613ed0203df495784b21f19.zip
Deduce outlives obligations from WF of coroutine interior types
-rw-r--r--compiler/rustc_middle/src/query/mod.rs2
-rw-r--r--compiler/rustc_middle/src/ty/context.rs11
-rw-r--r--compiler/rustc_traits/src/coroutine_witnesses.rs57
3 files changed, 67 insertions, 3 deletions
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index 935cc889574..f605da24bc4 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -988,7 +988,7 @@ rustc_queries! {
     }
 
     query coroutine_hidden_types(
-        def_id: DefId
+        def_id: DefId,
     ) -> ty::EarlyBinder<'tcx, ty::Binder<'tcx, ty::CoroutineWitnessTypes<TyCtxt<'tcx>>>> {
         desc { "looking up the hidden types stored across await points in a coroutine" }
     }
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index 45e311e95e4..5a1d8d215d2 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -3112,6 +3112,17 @@ impl<'tcx> TyCtxt<'tcx> {
         T::collect_and_apply(iter, |xs| self.mk_bound_variable_kinds(xs))
     }
 
+    pub fn mk_outlives_from_iter<I, T>(self, iter: I) -> T::Output
+    where
+        I: Iterator<Item = T>,
+        T: CollectAndApply<
+                ty::OutlivesPredicate<'tcx, ty::GenericArg<'tcx>>,
+                &'tcx ty::List<ty::OutlivesPredicate<'tcx, ty::GenericArg<'tcx>>>,
+            >,
+    {
+        T::collect_and_apply(iter, |xs| self.mk_outlives(xs))
+    }
+
     /// Emit a lint at `span` from a lint struct (some type that implements `LintDiagnostic`,
     /// typically generated by `#[derive(LintDiagnostic)]`).
     #[track_caller]
diff --git a/compiler/rustc_traits/src/coroutine_witnesses.rs b/compiler/rustc_traits/src/coroutine_witnesses.rs
index c1573ba7641..d79777437e6 100644
--- a/compiler/rustc_traits/src/coroutine_witnesses.rs
+++ b/compiler/rustc_traits/src/coroutine_witnesses.rs
@@ -1,5 +1,10 @@
 use rustc_hir::def_id::DefId;
-use rustc_middle::ty::{self, TyCtxt, fold_regions};
+use rustc_infer::infer::TyCtxtInferExt;
+use rustc_infer::infer::canonical::query_response::make_query_region_constraints;
+use rustc_infer::infer::resolve::OpportunisticRegionResolver;
+use rustc_infer::traits::{Obligation, ObligationCause};
+use rustc_middle::ty::{self, Ty, TyCtxt, TypeFoldable, TypeVisitableExt, fold_regions};
+use rustc_trait_selection::traits::{ObligationCtxt, with_replaced_escaping_bound_vars};
 
 /// Return the set of types that should be taken into account when checking
 /// trait bounds on a coroutine's internal state. This properly replaces
@@ -29,8 +34,56 @@ pub(crate) fn coroutine_hidden_types<'tcx>(
                 ty
             }),
     );
+
+    let assumptions = compute_assumptions(tcx, def_id, bound_tys);
+
     ty::EarlyBinder::bind(ty::Binder::bind_with_vars(
-        ty::CoroutineWitnessTypes { types: bound_tys, assumptions: ty::List::empty() },
+        ty::CoroutineWitnessTypes { types: bound_tys, assumptions },
         tcx.mk_bound_variable_kinds(&vars),
     ))
 }
+
+fn compute_assumptions<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    def_id: DefId,
+    bound_tys: &'tcx ty::List<Ty<'tcx>>,
+) -> &'tcx ty::List<ty::OutlivesPredicate<'tcx, ty::GenericArg<'tcx>>> {
+    let infcx = tcx.infer_ctxt().build(ty::TypingMode::Analysis {
+        defining_opaque_types_and_generators: ty::List::empty(),
+    });
+    with_replaced_escaping_bound_vars(&infcx, &mut vec![None], bound_tys, |bound_tys| {
+        let param_env = tcx.param_env(def_id);
+        let ocx = ObligationCtxt::new(&infcx);
+
+        ocx.register_obligations(bound_tys.iter().map(|ty| {
+            Obligation::new(
+                tcx,
+                ObligationCause::dummy(),
+                param_env,
+                ty::ClauseKind::WellFormed(ty.into()),
+            )
+        }));
+        let _errors = ocx.select_all_or_error();
+
+        let region_obligations = infcx.take_registered_region_obligations();
+        let region_constraints = infcx.take_and_reset_region_constraints();
+
+        let outlives = make_query_region_constraints(
+            tcx,
+            region_obligations,
+            &region_constraints,
+        )
+        .outlives
+        .fold_with(&mut OpportunisticRegionResolver::new(&infcx));
+
+        tcx.mk_outlives_from_iter(
+            outlives
+                .into_iter()
+                .map(|(o, _)| o)
+                // FIXME(higher_ranked_auto): We probably should deeply resolve these before
+                // filtering out infers which only correspond to unconstrained infer regions
+                // which we can sometimes get.
+                .filter(|o| !o.has_infer()),
+        )
+    })
+}