about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-04-15 19:44:23 -0400
committerMichael Goulet <michael@errs.io>2024-04-15 19:44:58 -0400
commitd29178c2efe7846fdb8c442219d7b70ae83e18e7 (patch)
tree89387e3ffde26939db648f1d631aedad80fa3b66 /compiler
parentaa1653e5be13cdcb818ae1f933bec7ad15e33b70 (diff)
downloadrust-d29178c2efe7846fdb8c442219d7b70ae83e18e7.tar.gz
rust-d29178c2efe7846fdb8c442219d7b70ae83e18e7.zip
Do check_coroutine_obligations once per typeck root
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_hir_analysis/src/check/check.rs47
-rw-r--r--compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs8
-rw-r--r--compiler/rustc_hir_typeck/src/writeback.rs13
-rw-r--r--compiler/rustc_interface/src/passes.rs4
-rw-r--r--compiler/rustc_middle/src/ty/typeck_results.rs9
-rw-r--r--compiler/rustc_mir_transform/src/coroutine.rs40
6 files changed, 63 insertions, 58 deletions
diff --git a/compiler/rustc_hir_analysis/src/check/check.rs b/compiler/rustc_hir_analysis/src/check/check.rs
index ae3623aa329..cb0ed68edff 100644
--- a/compiler/rustc_hir_analysis/src/check/check.rs
+++ b/compiler/rustc_hir_analysis/src/check/check.rs
@@ -10,10 +10,9 @@ use rustc_hir as hir;
 use rustc_hir::def::{CtorKind, DefKind};
 use rustc_hir::Node;
 use rustc_infer::infer::{RegionVariableOrigin, TyCtxtInferExt};
-use rustc_infer::traits::{Obligation, TraitEngineExt as _};
+use rustc_infer::traits::Obligation;
 use rustc_lint_defs::builtin::REPR_TRANSPARENT_EXTERNAL_PRIVATE_FIELDS;
 use rustc_middle::middle::stability::EvalResult;
-use rustc_middle::traits::ObligationCauseCode;
 use rustc_middle::ty::fold::BottomUpFolder;
 use rustc_middle::ty::layout::{LayoutError, MAX_SIMD_LANES};
 use rustc_middle::ty::util::{Discr, InspectCoroutineFields, IntTypeExt};
@@ -26,7 +25,7 @@ use rustc_target::abi::FieldIdx;
 use rustc_trait_selection::traits::error_reporting::on_unimplemented::OnUnimplementedDirective;
 use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
 use rustc_trait_selection::traits::outlives_bounds::InferCtxtExt as _;
-use rustc_trait_selection::traits::{self, TraitEngine, TraitEngineExt as _};
+use rustc_trait_selection::traits::{self};
 use rustc_type_ir::fold::TypeFoldable;
 
 use std::cell::LazyCell;
@@ -1541,55 +1540,31 @@ fn opaque_type_cycle_error(
     err.emit()
 }
 
-// FIXME(@lcnr): This should not be computed per coroutine, but instead once for
-// each typeck root.
 pub(super) fn check_coroutine_obligations(
     tcx: TyCtxt<'_>,
     def_id: LocalDefId,
 ) -> Result<(), ErrorGuaranteed> {
-    debug_assert!(tcx.is_coroutine(def_id.to_def_id()));
+    debug_assert!(!tcx.is_typeck_child(def_id.to_def_id()));
 
-    let typeck = tcx.typeck(def_id);
-    let param_env = tcx.param_env(typeck.hir_owner.def_id);
+    let typeck_results = tcx.typeck(def_id);
+    let param_env = tcx.param_env(def_id);
 
-    let coroutine_stalled_predicates = &typeck.coroutine_stalled_predicates[&def_id];
-    debug!(?coroutine_stalled_predicates);
+    debug!(?typeck_results.coroutine_stalled_predicates);
 
     let infcx = tcx
         .infer_ctxt()
         // typeck writeback gives us predicates with their regions erased.
         // As borrowck already has checked lifetimes, we do not need to do it again.
         .ignoring_regions()
-        // Bind opaque types to type checking root, as they should have been checked by borrowck,
-        // but may show up in some cases, like when (root) obligations are stalled in the new solver.
-        .with_opaque_type_inference(typeck.hir_owner.def_id)
+        .with_opaque_type_inference(def_id)
         .build();
 
-    let mut fulfillment_cx = <dyn TraitEngine<'_>>::new(&infcx);
-    for (predicate, cause) in coroutine_stalled_predicates {
-        let obligation = Obligation::new(tcx, cause.clone(), param_env, *predicate);
-        fulfillment_cx.register_predicate_obligation(&infcx, obligation);
-    }
-
-    if (tcx.features().unsized_locals || tcx.features().unsized_fn_params)
-        && let Some(coroutine) = tcx.mir_coroutine_witnesses(def_id)
-    {
-        for field_ty in coroutine.field_tys.iter() {
-            fulfillment_cx.register_bound(
-                &infcx,
-                param_env,
-                field_ty.ty,
-                tcx.require_lang_item(hir::LangItem::Sized, Some(field_ty.source_info.span)),
-                ObligationCause::new(
-                    field_ty.source_info.span,
-                    def_id,
-                    ObligationCauseCode::SizedCoroutineInterior(def_id),
-                ),
-            );
-        }
+    let ocx = ObligationCtxt::new(&infcx);
+    for (predicate, cause) in &typeck_results.coroutine_stalled_predicates {
+        ocx.register_obligation(Obligation::new(tcx, cause.clone(), param_env, *predicate));
     }
 
-    let errors = fulfillment_cx.select_all_or_error(&infcx);
+    let errors = ocx.select_all_or_error();
     debug!(?errors);
     if !errors.is_empty() {
         return Err(infcx.err_ctxt().report_fulfillment_errors(errors));
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
index 1024dec1d9f..14e4236ad78 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
@@ -575,12 +575,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             obligations
                 .extend(self.fulfillment_cx.borrow_mut().drain_unstalled_obligations(&self.infcx));
 
-            let obligations = obligations.into_iter().map(|o| (o.predicate, o.cause)).collect();
-            debug!(?obligations);
-            self.typeck_results
-                .borrow_mut()
-                .coroutine_stalled_predicates
-                .insert(expr_def_id, obligations);
+            let obligations = obligations.into_iter().map(|o| (o.predicate, o.cause));
+            self.typeck_results.borrow_mut().coroutine_stalled_predicates.extend(obligations);
         }
     }
 
diff --git a/compiler/rustc_hir_typeck/src/writeback.rs b/compiler/rustc_hir_typeck/src/writeback.rs
index 73613923e69..fa2bbc3dbd3 100644
--- a/compiler/rustc_hir_typeck/src/writeback.rs
+++ b/compiler/rustc_hir_typeck/src/writeback.rs
@@ -550,15 +550,10 @@ impl<'cx, 'tcx> WritebackCx<'cx, 'tcx> {
     fn visit_coroutine_interior(&mut self) {
         let fcx_typeck_results = self.fcx.typeck_results.borrow();
         assert_eq!(fcx_typeck_results.hir_owner, self.typeck_results.hir_owner);
-        self.tcx().with_stable_hashing_context(move |ref hcx| {
-            for (&expr_def_id, predicates) in
-                fcx_typeck_results.coroutine_stalled_predicates.to_sorted(hcx, false).into_iter()
-            {
-                let predicates =
-                    self.resolve(predicates.clone(), &self.fcx.tcx.def_span(expr_def_id));
-                self.typeck_results.coroutine_stalled_predicates.insert(expr_def_id, predicates);
-            }
-        })
+        for (predicate, cause) in &fcx_typeck_results.coroutine_stalled_predicates {
+            let (predicate, cause) = self.resolve((*predicate, cause.clone()), &cause.span);
+            self.typeck_results.coroutine_stalled_predicates.insert((predicate, cause));
+        }
     }
 
     #[instrument(skip(self), level = "debug")]
diff --git a/compiler/rustc_interface/src/passes.rs b/compiler/rustc_interface/src/passes.rs
index 91cef02c7d1..b593c41a8ea 100644
--- a/compiler/rustc_interface/src/passes.rs
+++ b/compiler/rustc_interface/src/passes.rs
@@ -759,7 +759,9 @@ fn run_required_analyses(tcx: TyCtxt<'_>) {
     tcx.hir().par_body_owners(|def_id| {
         if tcx.is_coroutine(def_id.to_def_id()) {
             tcx.ensure().mir_coroutine_witnesses(def_id);
-            tcx.ensure().check_coroutine_obligations(def_id);
+            tcx.ensure().check_coroutine_obligations(
+                tcx.typeck_root_def_id(def_id.to_def_id()).expect_local(),
+            );
         }
     });
     sess.time("layout_testing", || layout_test::test_layout(tcx));
diff --git a/compiler/rustc_middle/src/ty/typeck_results.rs b/compiler/rustc_middle/src/ty/typeck_results.rs
index aa905466fbc..904d56c516d 100644
--- a/compiler/rustc_middle/src/ty/typeck_results.rs
+++ b/compiler/rustc_middle/src/ty/typeck_results.rs
@@ -7,10 +7,8 @@ use crate::{
         GenericArgs, GenericArgsRef, Ty, UserArgs,
     },
 };
-use rustc_data_structures::{
-    fx::FxIndexMap,
-    unord::{ExtendUnord, UnordItems, UnordSet},
-};
+use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
+use rustc_data_structures::unord::{ExtendUnord, UnordItems, UnordSet};
 use rustc_errors::ErrorGuaranteed;
 use rustc_hir::{
     self as hir,
@@ -201,8 +199,7 @@ pub struct TypeckResults<'tcx> {
 
     /// Stores the predicates that apply on coroutine witness types.
     /// formatting modified file tests/ui/coroutine/retain-resume-ref.rs
-    pub coroutine_stalled_predicates:
-        LocalDefIdMap<Vec<(ty::Predicate<'tcx>, ObligationCause<'tcx>)>>,
+    pub coroutine_stalled_predicates: FxIndexSet<(ty::Predicate<'tcx>, ObligationCause<'tcx>)>,
 
     /// We sometimes treat byte string literals (which are of type `&[u8; N]`)
     /// as `&[u8]`, depending on the pattern in which they are used.
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index e2a911f0dc7..b745d97567d 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -80,6 +80,10 @@ use rustc_span::symbol::sym;
 use rustc_span::Span;
 use rustc_target::abi::{FieldIdx, VariantIdx};
 use rustc_target::spec::PanicStrategy;
+use rustc_trait_selection::infer::TyCtxtInferExt as _;
+use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
+use rustc_trait_selection::traits::ObligationCtxt;
+use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};
 use std::{iter, ops};
 
 pub struct StateTransform;
@@ -1584,10 +1588,46 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(
     let (_, coroutine_layout, _) = compute_layout(liveness_info, body);
 
     check_suspend_tys(tcx, &coroutine_layout, body);
+    check_field_tys_sized(tcx, &coroutine_layout, def_id);
 
     Some(coroutine_layout)
 }
 
+fn check_field_tys_sized<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    coroutine_layout: &CoroutineLayout<'tcx>,
+    def_id: LocalDefId,
+) {
+    // No need to check if unsized_locals/unsized_fn_params is disabled,
+    // since we will error during typeck.
+    if !tcx.features().unsized_locals && !tcx.features().unsized_fn_params {
+        return;
+    }
+
+    let infcx = tcx.infer_ctxt().ignoring_regions().build();
+    let param_env = tcx.param_env(def_id);
+
+    let ocx = ObligationCtxt::new(&infcx);
+    for field_ty in &coroutine_layout.field_tys {
+        ocx.register_bound(
+            ObligationCause::new(
+                field_ty.source_info.span,
+                def_id,
+                ObligationCauseCode::SizedCoroutineInterior(def_id),
+            ),
+            param_env,
+            field_ty.ty,
+            tcx.require_lang_item(hir::LangItem::Sized, Some(field_ty.source_info.span)),
+        );
+    }
+
+    let errors = ocx.select_all_or_error();
+    debug!(?errors);
+    if !errors.is_empty() {
+        infcx.err_ctxt().report_fulfillment_errors(errors);
+    }
+}
+
 impl<'tcx> MirPass<'tcx> for StateTransform {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         let Some(old_yield_ty) = body.yield_ty() else {