about summary refs log tree commit diff
diff options
context:
space:
mode:
authorGuillaume Gomez <guillaume1.gomez@gmail.com>2024-03-27 10:13:43 +0100
committerGitHub <noreply@github.com>2024-03-27 10:13:43 +0100
commitbffeb052d1de036e6cc11db1c80b7b07e50d90d8 (patch)
treeba3ebc5f9317712f54c893f47a8855f1ec63161a
parent8a7f285cbc6a6878a359572d751b665a316f3b16 (diff)
parent9bda9ac76e8d232c6cf0efde55dace718c1d428c (diff)
downloadrust-bffeb052d1de036e6cc11db1c80b7b07e50d90d8.tar.gz
rust-bffeb052d1de036e6cc11db1c80b7b07e50d90d8.zip
Rollup merge of #123021 - compiler-errors:coroutine-layout-lol, r=oli-obk
Make `TyCtxt::coroutine_layout` take coroutine's kind parameter

For coroutines that come from coroutine-closures (i.e. async closures), we may have two kinds of bodies stored in the coroutine; one that takes the closure's captures by reference, and one that takes the captures by move.

These currently have identical layouts, but if we do any optimization for these layouts that are related to the upvars, then they will diverge -- e.g. https://github.com/rust-lang/rust/pull/120168#discussion_r1536943728.

This PR relaxes the assertion I added in #121122, and instead make the `TyCtxt::coroutine_layout` method take the `coroutine_kind_ty` argument from the coroutine, which will allow us to differentiate these by-move and by-ref bodies.
-rw-r--r--compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs3
-rw-r--r--compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs8
-rw-r--r--compiler/rustc_const_eval/src/transform/validate.rs22
-rw-r--r--compiler/rustc_middle/src/mir/mod.rs3
-rw-r--r--compiler/rustc_middle/src/mir/pretty.rs2
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs37
-rw-r--r--compiler/rustc_middle/src/ty/sty.rs5
-rw-r--r--compiler/rustc_ty_utils/src/layout.rs4
8 files changed, 61 insertions, 23 deletions
diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
index 4792b0798df..4edef14422e 100644
--- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
+++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
@@ -683,7 +683,8 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
         _ => unreachable!(),
     };
 
-    let coroutine_layout = cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();
+    let coroutine_layout =
+        cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();
 
     let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
     let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
index 3dbe820b8ff..115d5187eaf 100644
--- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
+++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
@@ -135,7 +135,7 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
     unique_type_id: UniqueTypeId<'tcx>,
 ) -> DINodeCreationResult<'ll> {
     let coroutine_type = unique_type_id.expect_ty();
-    let &ty::Coroutine(coroutine_def_id, _) = coroutine_type.kind() else {
+    let &ty::Coroutine(coroutine_def_id, coroutine_args) = coroutine_type.kind() else {
         bug!("build_coroutine_di_node() called with non-coroutine type: `{:?}`", coroutine_type)
     };
 
@@ -158,8 +158,10 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
             DIFlags::FlagZero,
         ),
         |cx, coroutine_type_di_node| {
-            let coroutine_layout =
-                cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();
+            let coroutine_layout = cx
+                .tcx
+                .coroutine_layout(coroutine_def_id, coroutine_args.as_coroutine().kind_ty())
+                .unwrap();
 
             let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } =
                 coroutine_type_and_layout.variants
diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs
index 08e3e42a82e..378b168a50c 100644
--- a/compiler/rustc_const_eval/src/transform/validate.rs
+++ b/compiler/rustc_const_eval/src/transform/validate.rs
@@ -101,18 +101,17 @@ impl<'tcx> MirPass<'tcx> for Validator {
         }
 
         // Enforce that coroutine-closure layouts are identical.
-        if let Some(layout) = body.coroutine_layout()
+        if let Some(layout) = body.coroutine_layout_raw()
             && let Some(by_move_body) = body.coroutine_by_move_body()
-            && let Some(by_move_layout) = by_move_body.coroutine_layout()
+            && let Some(by_move_layout) = by_move_body.coroutine_layout_raw()
         {
-            if layout != by_move_layout {
-                // If this turns out not to be true, please let compiler-errors know.
-                // It is possible to support, but requires some changes to the layout
-                // computation code.
+            // FIXME(async_closures): We could do other validation here?
+            if layout.variant_fields.len() != by_move_layout.variant_fields.len() {
                 cfg_checker.fail(
                     Location::START,
                     format!(
-                        "Coroutine layout differs from by-move coroutine layout:\n\
+                        "Coroutine layout has different number of variant fields from \
+                        by-move coroutine layout:\n\
                         layout: {layout:#?}\n\
                         by_move_layout: {by_move_layout:#?}",
                     ),
@@ -715,13 +714,14 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
                             // args of the coroutine. Otherwise, we prefer to use this body
                             // since we may be in the process of computing this MIR in the
                             // first place.
-                            let gen_body = if def_id == self.caller_body.source.def_id() {
-                                self.caller_body
+                            let layout = if def_id == self.caller_body.source.def_id() {
+                                // FIXME: This is not right for async closures.
+                                self.caller_body.coroutine_layout_raw()
                             } else {
-                                self.tcx.optimized_mir(def_id)
+                                self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty())
                             };
 
-                            let Some(layout) = gen_body.coroutine_layout() else {
+                            let Some(layout) = layout else {
                                 self.fail(
                                     location,
                                     format!("No coroutine layout for {parent_ty:?}"),
diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs
index e4dce2bdc9e..02af55fbf0e 100644
--- a/compiler/rustc_middle/src/mir/mod.rs
+++ b/compiler/rustc_middle/src/mir/mod.rs
@@ -652,8 +652,9 @@ impl<'tcx> Body<'tcx> {
         self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty)
     }
 
+    /// Prefer going through [`TyCtxt::coroutine_layout`] rather than using this directly.
     #[inline]
-    pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> {
+    pub fn coroutine_layout_raw(&self) -> Option<&CoroutineLayout<'tcx>> {
         self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref())
     }
 
diff --git a/compiler/rustc_middle/src/mir/pretty.rs b/compiler/rustc_middle/src/mir/pretty.rs
index f0499cf344f..41df2e3b587 100644
--- a/compiler/rustc_middle/src/mir/pretty.rs
+++ b/compiler/rustc_middle/src/mir/pretty.rs
@@ -126,7 +126,7 @@ fn dump_matched_mir_node<'tcx, F>(
             Some(promoted) => write!(file, "::{promoted:?}`")?,
         }
         writeln!(file, " {disambiguator} {pass_name}")?;
-        if let Some(ref layout) = body.coroutine_layout() {
+        if let Some(ref layout) = body.coroutine_layout_raw() {
             writeln!(file, "/* coroutine_layout = {layout:#?} */")?;
         }
         writeln!(file)?;
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 6ce53ccc8cd..aad2f6a4cf8 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -60,6 +60,7 @@ pub use rustc_target::abi::{ReprFlags, ReprOptions};
 pub use rustc_type_ir::{DebugWithInfcx, InferCtxtLike, WithInfcx};
 pub use vtable::*;
 
+use std::assert_matches::assert_matches;
 use std::fmt::Debug;
 use std::hash::{Hash, Hasher};
 use std::marker::PhantomData;
@@ -1826,8 +1827,40 @@ impl<'tcx> TyCtxt<'tcx> {
 
     /// Returns layout of a coroutine. Layout might be unavailable if the
     /// coroutine is tainted by errors.
-    pub fn coroutine_layout(self, def_id: DefId) -> Option<&'tcx CoroutineLayout<'tcx>> {
-        self.optimized_mir(def_id).coroutine_layout()
+    ///
+    /// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
+    /// e.g. `args.as_coroutine().kind_ty()`.
+    pub fn coroutine_layout(
+        self,
+        def_id: DefId,
+        coroutine_kind_ty: Ty<'tcx>,
+    ) -> Option<&'tcx CoroutineLayout<'tcx>> {
+        let mir = self.optimized_mir(def_id);
+        // Regular coroutine
+        if coroutine_kind_ty.is_unit() {
+            mir.coroutine_layout_raw()
+        } else {
+            // If we have a `Coroutine` that comes from an coroutine-closure,
+            // then it may be a by-move or by-ref body.
+            let ty::Coroutine(_, identity_args) =
+                *self.type_of(def_id).instantiate_identity().kind()
+            else {
+                unreachable!();
+            };
+            let identity_kind_ty = identity_args.as_coroutine().kind_ty();
+            // If the types differ, then we must be getting the by-move body of
+            // a by-ref coroutine.
+            if identity_kind_ty == coroutine_kind_ty {
+                mir.coroutine_layout_raw()
+            } else {
+                assert_matches!(coroutine_kind_ty.to_opt_closure_kind(), Some(ClosureKind::FnOnce));
+                assert_matches!(
+                    identity_kind_ty.to_opt_closure_kind(),
+                    Some(ClosureKind::Fn | ClosureKind::FnMut)
+                );
+                mir.coroutine_by_move_body().unwrap().coroutine_layout_raw()
+            }
+        }
     }
 
     /// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.
diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs
index c85ee140fa4..510a4b59520 100644
--- a/compiler/rustc_middle/src/ty/sty.rs
+++ b/compiler/rustc_middle/src/ty/sty.rs
@@ -694,7 +694,8 @@ impl<'tcx> CoroutineArgs<'tcx> {
     #[inline]
     pub fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range<VariantIdx> {
         // FIXME requires optimized MIR
-        FIRST_VARIANT..tcx.coroutine_layout(def_id).unwrap().variant_fields.next_index()
+        FIRST_VARIANT
+            ..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index()
     }
 
     /// The discriminant for the given variant. Panics if the `variant_index` is
@@ -754,7 +755,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
         def_id: DefId,
         tcx: TyCtxt<'tcx>,
     ) -> impl Iterator<Item: Iterator<Item = Ty<'tcx>> + Captures<'tcx>> {
-        let layout = tcx.coroutine_layout(def_id).unwrap();
+        let layout = tcx.coroutine_layout(def_id, self.kind_ty()).unwrap();
         layout.variant_fields.iter().map(move |variant| {
             variant.iter().map(move |field| {
                 ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)
diff --git a/compiler/rustc_ty_utils/src/layout.rs b/compiler/rustc_ty_utils/src/layout.rs
index 9c3d39307b2..331970ac362 100644
--- a/compiler/rustc_ty_utils/src/layout.rs
+++ b/compiler/rustc_ty_utils/src/layout.rs
@@ -745,7 +745,7 @@ fn coroutine_layout<'tcx>(
     let tcx = cx.tcx;
     let instantiate_field = |ty: Ty<'tcx>| EarlyBinder::bind(ty).instantiate(tcx, args);
 
-    let Some(info) = tcx.coroutine_layout(def_id) else {
+    let Some(info) = tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) else {
         return Err(error(cx, LayoutError::Unknown(ty)));
     };
     let (ineligible_locals, assignments) = coroutine_saved_local_eligibility(info);
@@ -1072,7 +1072,7 @@ fn variant_info_for_coroutine<'tcx>(
         return (vec![], None);
     };
 
-    let coroutine = cx.tcx.optimized_mir(def_id).coroutine_layout().unwrap();
+    let coroutine = cx.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()).unwrap();
     let upvar_names = cx.tcx.closure_saved_names_of_captured_variables(def_id);
 
     let mut upvars_size = Size::ZERO;