about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs3
-rw-r--r--compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs8
-rw-r--r--compiler/rustc_const_eval/src/transform/validate.rs22
-rw-r--r--compiler/rustc_middle/src/mir/mod.rs3
-rw-r--r--compiler/rustc_middle/src/mir/pretty.rs2
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs37
-rw-r--r--compiler/rustc_middle/src/ty/sty.rs5
-rw-r--r--compiler/rustc_ty_utils/src/layout.rs4
8 files changed, 61 insertions, 23 deletions
diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
index 4792b0798df..4edef14422e 100644
--- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
+++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs
@@ -683,7 +683,8 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
         _ => unreachable!(),
     };
 
-    let coroutine_layout = cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();
+    let coroutine_layout =
+        cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();
 
     let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
     let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
diff --git a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
index 3dbe820b8ff..115d5187eaf 100644
--- a/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
+++ b/compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs
@@ -135,7 +135,7 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
     unique_type_id: UniqueTypeId<'tcx>,
 ) -> DINodeCreationResult<'ll> {
     let coroutine_type = unique_type_id.expect_ty();
-    let &ty::Coroutine(coroutine_def_id, _) = coroutine_type.kind() else {
+    let &ty::Coroutine(coroutine_def_id, coroutine_args) = coroutine_type.kind() else {
         bug!("build_coroutine_di_node() called with non-coroutine type: `{:?}`", coroutine_type)
     };
 
@@ -158,8 +158,10 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
             DIFlags::FlagZero,
         ),
         |cx, coroutine_type_di_node| {
-            let coroutine_layout =
-                cx.tcx.optimized_mir(coroutine_def_id).coroutine_layout().unwrap();
+            let coroutine_layout = cx
+                .tcx
+                .coroutine_layout(coroutine_def_id, coroutine_args.as_coroutine().kind_ty())
+                .unwrap();
 
             let Variants::Multiple { tag_encoding: TagEncoding::Direct, ref variants, .. } =
                 coroutine_type_and_layout.variants
diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs
index 08e3e42a82e..378b168a50c 100644
--- a/compiler/rustc_const_eval/src/transform/validate.rs
+++ b/compiler/rustc_const_eval/src/transform/validate.rs
@@ -101,18 +101,17 @@ impl<'tcx> MirPass<'tcx> for Validator {
         }
 
         // Enforce that coroutine-closure layouts are identical.
-        if let Some(layout) = body.coroutine_layout()
+        if let Some(layout) = body.coroutine_layout_raw()
             && let Some(by_move_body) = body.coroutine_by_move_body()
-            && let Some(by_move_layout) = by_move_body.coroutine_layout()
+            && let Some(by_move_layout) = by_move_body.coroutine_layout_raw()
         {
-            if layout != by_move_layout {
-                // If this turns out not to be true, please let compiler-errors know.
-                // It is possible to support, but requires some changes to the layout
-                // computation code.
+            // FIXME(async_closures): We could do other validation here?
+            if layout.variant_fields.len() != by_move_layout.variant_fields.len() {
                 cfg_checker.fail(
                     Location::START,
                     format!(
-                        "Coroutine layout differs from by-move coroutine layout:\n\
+                        "Coroutine layout has different number of variant fields from \
+                        by-move coroutine layout:\n\
                         layout: {layout:#?}\n\
                         by_move_layout: {by_move_layout:#?}",
                     ),
@@ -715,13 +714,14 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
                             // args of the coroutine. Otherwise, we prefer to use this body
                             // since we may be in the process of computing this MIR in the
                             // first place.
-                            let gen_body = if def_id == self.caller_body.source.def_id() {
-                                self.caller_body
+                            let layout = if def_id == self.caller_body.source.def_id() {
+                                // FIXME: This is not right for async closures.
+                                self.caller_body.coroutine_layout_raw()
                             } else {
-                                self.tcx.optimized_mir(def_id)
+                                self.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty())
                             };
 
-                            let Some(layout) = gen_body.coroutine_layout() else {
+                            let Some(layout) = layout else {
                                 self.fail(
                                     location,
                                     format!("No coroutine layout for {parent_ty:?}"),
diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs
index e4dce2bdc9e..02af55fbf0e 100644
--- a/compiler/rustc_middle/src/mir/mod.rs
+++ b/compiler/rustc_middle/src/mir/mod.rs
@@ -652,8 +652,9 @@ impl<'tcx> Body<'tcx> {
         self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty)
     }
 
+    /// Prefer going through [`TyCtxt::coroutine_layout`] rather than using this directly.
     #[inline]
-    pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> {
+    pub fn coroutine_layout_raw(&self) -> Option<&CoroutineLayout<'tcx>> {
         self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref())
     }
 
diff --git a/compiler/rustc_middle/src/mir/pretty.rs b/compiler/rustc_middle/src/mir/pretty.rs
index f0499cf344f..41df2e3b587 100644
--- a/compiler/rustc_middle/src/mir/pretty.rs
+++ b/compiler/rustc_middle/src/mir/pretty.rs
@@ -126,7 +126,7 @@ fn dump_matched_mir_node<'tcx, F>(
             Some(promoted) => write!(file, "::{promoted:?}`")?,
         }
         writeln!(file, " {disambiguator} {pass_name}")?;
-        if let Some(ref layout) = body.coroutine_layout() {
+        if let Some(ref layout) = body.coroutine_layout_raw() {
             writeln!(file, "/* coroutine_layout = {layout:#?} */")?;
         }
         writeln!(file)?;
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 6ce53ccc8cd..aad2f6a4cf8 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -60,6 +60,7 @@ pub use rustc_target::abi::{ReprFlags, ReprOptions};
 pub use rustc_type_ir::{DebugWithInfcx, InferCtxtLike, WithInfcx};
 pub use vtable::*;
 
+use std::assert_matches::assert_matches;
 use std::fmt::Debug;
 use std::hash::{Hash, Hasher};
 use std::marker::PhantomData;
@@ -1826,8 +1827,40 @@ impl<'tcx> TyCtxt<'tcx> {
 
     /// Returns layout of a coroutine. Layout might be unavailable if the
     /// coroutine is tainted by errors.
-    pub fn coroutine_layout(self, def_id: DefId) -> Option<&'tcx CoroutineLayout<'tcx>> {
-        self.optimized_mir(def_id).coroutine_layout()
+    ///
+    /// Takes `coroutine_kind` which can be acquired from the `CoroutineArgs::kind_ty`,
+    /// e.g. `args.as_coroutine().kind_ty()`.
+    pub fn coroutine_layout(
+        self,
+        def_id: DefId,
+        coroutine_kind_ty: Ty<'tcx>,
+    ) -> Option<&'tcx CoroutineLayout<'tcx>> {
+        let mir = self.optimized_mir(def_id);
+        // Regular coroutine
+        if coroutine_kind_ty.is_unit() {
+            mir.coroutine_layout_raw()
+        } else {
+            // If we have a `Coroutine` that comes from an coroutine-closure,
+            // then it may be a by-move or by-ref body.
+            let ty::Coroutine(_, identity_args) =
+                *self.type_of(def_id).instantiate_identity().kind()
+            else {
+                unreachable!();
+            };
+            let identity_kind_ty = identity_args.as_coroutine().kind_ty();
+            // If the types differ, then we must be getting the by-move body of
+            // a by-ref coroutine.
+            if identity_kind_ty == coroutine_kind_ty {
+                mir.coroutine_layout_raw()
+            } else {
+                assert_matches!(coroutine_kind_ty.to_opt_closure_kind(), Some(ClosureKind::FnOnce));
+                assert_matches!(
+                    identity_kind_ty.to_opt_closure_kind(),
+                    Some(ClosureKind::Fn | ClosureKind::FnMut)
+                );
+                mir.coroutine_by_move_body().unwrap().coroutine_layout_raw()
+            }
+        }
     }
 
     /// Given the `DefId` of an impl, returns the `DefId` of the trait it implements.
diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs
index c85ee140fa4..510a4b59520 100644
--- a/compiler/rustc_middle/src/ty/sty.rs
+++ b/compiler/rustc_middle/src/ty/sty.rs
@@ -694,7 +694,8 @@ impl<'tcx> CoroutineArgs<'tcx> {
     #[inline]
     pub fn variant_range(&self, def_id: DefId, tcx: TyCtxt<'tcx>) -> Range<VariantIdx> {
         // FIXME requires optimized MIR
-        FIRST_VARIANT..tcx.coroutine_layout(def_id).unwrap().variant_fields.next_index()
+        FIRST_VARIANT
+            ..tcx.coroutine_layout(def_id, tcx.types.unit).unwrap().variant_fields.next_index()
     }
 
     /// The discriminant for the given variant. Panics if the `variant_index` is
@@ -754,7 +755,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
         def_id: DefId,
         tcx: TyCtxt<'tcx>,
     ) -> impl Iterator<Item: Iterator<Item = Ty<'tcx>> + Captures<'tcx>> {
-        let layout = tcx.coroutine_layout(def_id).unwrap();
+        let layout = tcx.coroutine_layout(def_id, self.kind_ty()).unwrap();
         layout.variant_fields.iter().map(move |variant| {
             variant.iter().map(move |field| {
                 ty::EarlyBinder::bind(layout.field_tys[*field].ty).instantiate(tcx, self.args)
diff --git a/compiler/rustc_ty_utils/src/layout.rs b/compiler/rustc_ty_utils/src/layout.rs
index 9c3d39307b2..331970ac362 100644
--- a/compiler/rustc_ty_utils/src/layout.rs
+++ b/compiler/rustc_ty_utils/src/layout.rs
@@ -745,7 +745,7 @@ fn coroutine_layout<'tcx>(
     let tcx = cx.tcx;
     let instantiate_field = |ty: Ty<'tcx>| EarlyBinder::bind(ty).instantiate(tcx, args);
 
-    let Some(info) = tcx.coroutine_layout(def_id) else {
+    let Some(info) = tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) else {
         return Err(error(cx, LayoutError::Unknown(ty)));
     };
     let (ineligible_locals, assignments) = coroutine_saved_local_eligibility(info);
@@ -1072,7 +1072,7 @@ fn variant_info_for_coroutine<'tcx>(
         return (vec![], None);
     };
 
-    let coroutine = cx.tcx.optimized_mir(def_id).coroutine_layout().unwrap();
+    let coroutine = cx.tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()).unwrap();
     let upvar_names = cx.tcx.closure_saved_names_of_captured_variables(def_id);
 
     let mut upvars_size = Size::ZERO;