about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-01-24 23:38:33 +0000
committerMichael Goulet <michael@errs.io>2024-02-06 02:22:58 +0000
commit427896dd7e39f1aaf3e3cbc15e5ddf77d45a6aec (patch)
tree6590ff48b2d3114622b983912308b2b4516e5531
parentfc4fff40385252212b9921928927568f233ba02f (diff)
downloadrust-427896dd7e39f1aaf3e3cbc15e5ddf77d45a6aec.tar.gz
rust-427896dd7e39f1aaf3e3cbc15e5ddf77d45a6aec.zip
Construct body for by-move coroutine closure output
-rw-r--r--compiler/rustc_borrowck/src/type_check/input_output.rs1
-rw-r--r--compiler/rustc_const_eval/src/interpret/terminator.rs1
-rw-r--r--compiler/rustc_hir_typeck/src/callee.rs1
-rw-r--r--compiler/rustc_hir_typeck/src/closure.rs11
-rw-r--r--compiler/rustc_hir_typeck/src/upvar.rs10
-rw-r--r--compiler/rustc_middle/src/mir/mod.rs5
-rw-r--r--compiler/rustc_middle/src/mir/mono.rs1
-rw-r--r--compiler/rustc_middle/src/mir/visit.rs1
-rw-r--r--compiler/rustc_middle/src/ty/instance.rs21
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs1
-rw-r--r--compiler/rustc_middle/src/ty/sty.rs47
-rw-r--r--compiler/rustc_mir_transform/src/coroutine.rs3
-rw-r--r--compiler/rustc_mir_transform/src/coroutine/by_move_body.rs108
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs1
-rw-r--r--compiler/rustc_mir_transform/src/inline/cycle.rs1
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs4
-rw-r--r--compiler/rustc_mir_transform/src/pass_manager.rs6
-rw-r--r--compiler/rustc_mir_transform/src/shim.rs12
-rw-r--r--compiler/rustc_monomorphize/src/collector.rs3
-rw-r--r--compiler/rustc_monomorphize/src/partitioning.rs4
-rw-r--r--compiler/rustc_smir/src/rustc_smir/convert/ty.rs1
-rw-r--r--compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs1
-rw-r--r--compiler/rustc_trait_selection/src/traits/project.rs2
-rw-r--r--tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir2
24 files changed, 233 insertions, 15 deletions
diff --git a/compiler/rustc_borrowck/src/type_check/input_output.rs b/compiler/rustc_borrowck/src/type_check/input_output.rs
index a3e5088ee09..ace9c5ae71d 100644
--- a/compiler/rustc_borrowck/src/type_check/input_output.rs
+++ b/compiler/rustc_borrowck/src/type_check/input_output.rs
@@ -85,6 +85,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                     self.tcx(),
                     ty::CoroutineArgsParts {
                         parent_args: args.parent_args(),
+                        kind_ty: Ty::from_closure_kind(self.tcx(), args.kind()),
                         resume_ty: next_ty_var(),
                         yield_ty: next_ty_var(),
                         witness: next_ty_var(),
diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs
index 4c8f68b25b5..b8d6836da14 100644
--- a/compiler/rustc_const_eval/src/interpret/terminator.rs
+++ b/compiler/rustc_const_eval/src/interpret/terminator.rs
@@ -546,6 +546,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
             | ty::InstanceDef::ReifyShim(..)
             | ty::InstanceDef::ClosureOnceShim { .. }
             | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
+            | ty::InstanceDef::CoroutineByMoveShim { .. }
             | ty::InstanceDef::FnPtrShim(..)
             | ty::InstanceDef::DropGlue(..)
             | ty::InstanceDef::CloneShim(..)
diff --git a/compiler/rustc_hir_typeck/src/callee.rs b/compiler/rustc_hir_typeck/src/callee.rs
index 1858b2770cd..730a475f630 100644
--- a/compiler/rustc_hir_typeck/src/callee.rs
+++ b/compiler/rustc_hir_typeck/src/callee.rs
@@ -183,6 +183,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     coroutine_closure_sig.to_coroutine(
                         self.tcx,
                         closure_args.parent_args(),
+                        closure_args.kind_ty(),
                         self.tcx.coroutine_for_closure(def_id),
                         tupled_upvars_ty,
                     ),
diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs
index 1d024efdd49..014293c1f83 100644
--- a/compiler/rustc_hir_typeck/src/closure.rs
+++ b/compiler/rustc_hir_typeck/src/closure.rs
@@ -175,10 +175,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                     interior,
                 ));
 
+                let kind_ty = match kind {
+                    hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure) => self
+                        .next_ty_var(TypeVariableOrigin {
+                            kind: TypeVariableOriginKind::ClosureSynthetic,
+                            span: expr_span,
+                        }),
+                    _ => tcx.types.unit,
+                };
+
                 let coroutine_args = ty::CoroutineArgs::new(
                     tcx,
                     ty::CoroutineArgsParts {
                         parent_args,
+                        kind_ty,
                         resume_ty,
                         yield_ty,
                         return_ty: liberated_sig.output(),
@@ -256,6 +266,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                             sig.to_coroutine(
                                 tcx,
                                 parent_args,
+                                closure_kind_ty,
                                 tcx.coroutine_for_closure(expr_def_id),
                                 coroutine_upvars_ty,
                             )
diff --git a/compiler/rustc_hir_typeck/src/upvar.rs b/compiler/rustc_hir_typeck/src/upvar.rs
index b087d6d9e57..d4e072976fa 100644
--- a/compiler/rustc_hir_typeck/src/upvar.rs
+++ b/compiler/rustc_hir_typeck/src/upvar.rs
@@ -393,6 +393,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
                 coroutine_captures_by_ref_ty,
             );
+
+            let ty::Coroutine(_, args) = *self.typeck_results.borrow().expr_ty(body.value).kind()
+            else {
+                bug!();
+            };
+            self.demand_eqtype(
+                span,
+                args.as_coroutine().kind_ty(),
+                Ty::from_closure_kind(self.tcx, closure_kind),
+            );
         }
 
         self.log_closure_min_capture_info(closure_def_id, span);
diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs
index c9e69253701..d88e9261e5a 100644
--- a/compiler/rustc_middle/src/mir/mod.rs
+++ b/compiler/rustc_middle/src/mir/mod.rs
@@ -262,6 +262,10 @@ pub struct CoroutineInfo<'tcx> {
     /// Coroutine drop glue. This field is populated after the state transform pass.
     pub coroutine_drop: Option<Body<'tcx>>,
 
+    /// The body of the coroutine, modified to take its upvars by move.
+    /// TODO:
+    pub by_move_body: Option<Body<'tcx>>,
+
     /// The layout of a coroutine. This field is populated after the state transform pass.
     pub coroutine_layout: Option<CoroutineLayout<'tcx>>,
 
@@ -281,6 +285,7 @@ impl<'tcx> CoroutineInfo<'tcx> {
             coroutine_kind,
             yield_ty: Some(yield_ty),
             resume_ty: Some(resume_ty),
+            by_move_body: None,
             coroutine_drop: None,
             coroutine_layout: None,
         }
diff --git a/compiler/rustc_middle/src/mir/mono.rs b/compiler/rustc_middle/src/mir/mono.rs
index 4a29171d8bf..e6d1535fdf2 100644
--- a/compiler/rustc_middle/src/mir/mono.rs
+++ b/compiler/rustc_middle/src/mir/mono.rs
@@ -403,6 +403,7 @@ impl<'tcx> CodegenUnit<'tcx> {
                             | InstanceDef::Virtual(..)
                             | InstanceDef::ClosureOnceShim { .. }
                             | InstanceDef::ConstructCoroutineInClosureShim { .. }
+                            | InstanceDef::CoroutineByMoveShim { .. }
                             | InstanceDef::DropGlue(..)
                             | InstanceDef::CloneShim(..)
                             | InstanceDef::ThreadLocalShim(..)
diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs
index 6bc58adea0f..ce1859d6ada 100644
--- a/compiler/rustc_middle/src/mir/visit.rs
+++ b/compiler/rustc_middle/src/mir/visit.rs
@@ -346,6 +346,7 @@ macro_rules! make_mir_visitor {
                         ty::InstanceDef::ThreadLocalShim(_def_id) |
                         ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
                         ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } |
+                        ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: _def_id } |
                         ty::InstanceDef::DropGlue(_def_id, None) => {}
 
                         ty::InstanceDef::FnPtrShim(_def_id, ty) |
diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs
index 41ae136851e..44bf3c32b48 100644
--- a/compiler/rustc_middle/src/ty/instance.rs
+++ b/compiler/rustc_middle/src/ty/instance.rs
@@ -101,6 +101,9 @@ pub enum InstanceDef<'tcx> {
         target_kind: ty::ClosureKind,
     },
 
+    /// TODO:
+    CoroutineByMoveShim { coroutine_def_id: DefId },
+
     /// Compiler-generated accessor for thread locals which returns a reference to the thread local
     /// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking
     /// native support.
@@ -186,6 +189,7 @@ impl<'tcx> InstanceDef<'tcx> {
                 coroutine_closure_def_id: def_id,
                 target_kind: _,
             }
+            | ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: def_id }
             | InstanceDef::DropGlue(def_id, _)
             | InstanceDef::CloneShim(def_id, _)
             | InstanceDef::FnPtrAddrShim(def_id, _) => def_id,
@@ -206,6 +210,7 @@ impl<'tcx> InstanceDef<'tcx> {
             | InstanceDef::Intrinsic(..)
             | InstanceDef::ClosureOnceShim { .. }
             | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
+            | ty::InstanceDef::CoroutineByMoveShim { .. }
             | InstanceDef::DropGlue(..)
             | InstanceDef::CloneShim(..)
             | InstanceDef::FnPtrAddrShim(..) => None,
@@ -302,6 +307,7 @@ impl<'tcx> InstanceDef<'tcx> {
             | InstanceDef::DropGlue(_, Some(_)) => false,
             InstanceDef::ClosureOnceShim { .. }
             | InstanceDef::ConstructCoroutineInClosureShim { .. }
+            | InstanceDef::CoroutineByMoveShim { .. }
             | InstanceDef::DropGlue(..)
             | InstanceDef::Item(_)
             | InstanceDef::Intrinsic(..)
@@ -340,6 +346,7 @@ fn fmt_instance(
         InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"),
         InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"),
         InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"),
+        InstanceDef::CoroutineByMoveShim { .. } => write!(f, " - shim"),
         InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"),
         InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"),
         InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"),
@@ -631,7 +638,19 @@ impl<'tcx> Instance<'tcx> {
         };
 
         if tcx.lang_items().get(coroutine_callable_item) == Some(trait_item_id) {
-            Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args })
+            let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind()
+            else {
+                bug!()
+            };
+
+            if args.as_coroutine().kind_ty() == id_args.as_coroutine().kind_ty() {
+                Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
+            } else {
+                Some(Instance {
+                    def: ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id },
+                    args,
+                })
+            }
         } else {
             // All other methods should be defaulted methods of the built-in trait.
             // This is important for `Iterator`'s combinators, but also useful for
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 05875a9798b..9ceb3ec3f61 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -1681,6 +1681,7 @@ impl<'tcx> TyCtxt<'tcx> {
             | ty::InstanceDef::Virtual(..)
             | ty::InstanceDef::ClosureOnceShim { .. }
             | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
+            | ty::InstanceDef::CoroutineByMoveShim { .. }
             | ty::InstanceDef::DropGlue(..)
             | ty::InstanceDef::CloneShim(..)
             | ty::InstanceDef::ThreadLocalShim(..)
diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs
index e2a2e24f06d..8918a3735d6 100644
--- a/compiler/rustc_middle/src/ty/sty.rs
+++ b/compiler/rustc_middle/src/ty/sty.rs
@@ -399,6 +399,7 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
         self,
         tcx: TyCtxt<'tcx>,
         parent_args: &'tcx [GenericArg<'tcx>],
+        kind_ty: Ty<'tcx>,
         coroutine_def_id: DefId,
         tupled_upvars_ty: Ty<'tcx>,
     ) -> Ty<'tcx> {
@@ -406,6 +407,7 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
             tcx,
             ty::CoroutineArgsParts {
                 parent_args,
+                kind_ty,
                 resume_ty: self.resume_ty,
                 yield_ty: self.yield_ty,
                 return_ty: self.return_ty,
@@ -436,7 +438,13 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
             env_region,
         );
 
-        self.to_coroutine(tcx, parent_args, coroutine_def_id, tupled_upvars_ty)
+        self.to_coroutine(
+            tcx,
+            parent_args,
+            Ty::from_closure_kind(tcx, closure_kind),
+            coroutine_def_id,
+            tupled_upvars_ty,
+        )
     }
 
     /// Given a closure kind, compute the tupled upvars that the given coroutine would return.
@@ -488,6 +496,8 @@ pub struct CoroutineArgs<'tcx> {
 pub struct CoroutineArgsParts<'tcx> {
     /// This is the args of the typeck root.
     pub parent_args: &'tcx [GenericArg<'tcx>],
+    // TODO: why
+    pub kind_ty: Ty<'tcx>,
     pub resume_ty: Ty<'tcx>,
     pub yield_ty: Ty<'tcx>,
     pub return_ty: Ty<'tcx>,
@@ -506,6 +516,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
     pub fn new(tcx: TyCtxt<'tcx>, parts: CoroutineArgsParts<'tcx>) -> CoroutineArgs<'tcx> {
         CoroutineArgs {
             args: tcx.mk_args_from_iter(parts.parent_args.iter().copied().chain([
+                parts.kind_ty.into(),
                 parts.resume_ty.into(),
                 parts.yield_ty.into(),
                 parts.return_ty.into(),
@@ -519,16 +530,23 @@ impl<'tcx> CoroutineArgs<'tcx> {
     /// The ordering assumed here must match that used by `CoroutineArgs::new` above.
     fn split(self) -> CoroutineArgsParts<'tcx> {
         match self.args[..] {
-            [ref parent_args @ .., resume_ty, yield_ty, return_ty, witness, tupled_upvars_ty] => {
-                CoroutineArgsParts {
-                    parent_args,
-                    resume_ty: resume_ty.expect_ty(),
-                    yield_ty: yield_ty.expect_ty(),
-                    return_ty: return_ty.expect_ty(),
-                    witness: witness.expect_ty(),
-                    tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
-                }
-            }
+            [
+                ref parent_args @ ..,
+                kind_ty,
+                resume_ty,
+                yield_ty,
+                return_ty,
+                witness,
+                tupled_upvars_ty,
+            ] => CoroutineArgsParts {
+                parent_args,
+                kind_ty: kind_ty.expect_ty(),
+                resume_ty: resume_ty.expect_ty(),
+                yield_ty: yield_ty.expect_ty(),
+                return_ty: return_ty.expect_ty(),
+                witness: witness.expect_ty(),
+                tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
+            },
             _ => bug!("coroutine args missing synthetics"),
         }
     }
@@ -538,6 +556,11 @@ impl<'tcx> CoroutineArgs<'tcx> {
         self.split().parent_args
     }
 
+    // TODO:
+    pub fn kind_ty(self) -> Ty<'tcx> {
+        self.split().kind_ty
+    }
+
     /// This describes the types that can be contained in a coroutine.
     /// It will be a type variable initially and unified in the last stages of typeck of a body.
     /// It contains a tuple of all the types that could end up on a coroutine frame.
@@ -1628,7 +1651,7 @@ impl<'tcx> Ty<'tcx> {
     ) -> Ty<'tcx> {
         debug_assert_eq!(
             coroutine_args.len(),
-            tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 5,
+            tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 6,
             "coroutine constructed with incorrect number of substitutions"
         );
         Ty::new(tcx, Coroutine(def_id, coroutine_args))
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index bde879f6067..297b2fa143d 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -50,6 +50,9 @@
 //! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing.
 //! Otherwise it drops all the values in scope at the last suspension point.
 
+mod by_move_body;
+pub use by_move_body::ByMoveBody;
+
 use crate::abort_unwinding_calls;
 use crate::deref_separator::deref_finder;
 use crate::errors;
diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
new file mode 100644
index 00000000000..4e3e70bdafe
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
@@ -0,0 +1,108 @@
+use rustc_data_structures::fx::FxIndexSet;
+use rustc_hir as hir;
+use rustc_middle::mir::visit::MutVisitor;
+use rustc_middle::mir::{self, MirPass};
+use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt};
+use rustc_target::abi::FieldIdx;
+
+pub struct ByMoveBody;
+
+impl<'tcx> MirPass<'tcx> for ByMoveBody {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) {
+        let Some(coroutine_def_id) = body.source.def_id().as_local() else {
+            return;
+        };
+        let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) =
+            tcx.coroutine_kind(coroutine_def_id)
+        else {
+            return;
+        };
+        let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
+        let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() };
+        if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce {
+            return;
+        }
+
+        let mut by_ref_fields = FxIndexSet::default();
+        let by_move_upvars = Ty::new_tup_from_iter(
+            tcx,
+            tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
+                if capture.is_by_ref() {
+                    by_ref_fields.insert(FieldIdx::from_usize(idx));
+                }
+                capture.place.ty()
+            }),
+        );
+        let by_move_coroutine_ty = Ty::new_coroutine(
+            tcx,
+            coroutine_def_id.to_def_id(),
+            ty::CoroutineArgs::new(
+                tcx,
+                ty::CoroutineArgsParts {
+                    parent_args: args.as_coroutine().parent_args(),
+                    kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
+                    resume_ty: args.as_coroutine().resume_ty(),
+                    yield_ty: args.as_coroutine().yield_ty(),
+                    return_ty: args.as_coroutine().return_ty(),
+                    witness: args.as_coroutine().witness(),
+                    tupled_upvars_ty: by_move_upvars,
+                },
+            )
+            .args,
+        );
+
+        let mut by_move_body = body.clone();
+        MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
+        by_move_body.source = mir::MirSource {
+            instance: InstanceDef::CoroutineByMoveShim {
+                coroutine_def_id: coroutine_def_id.to_def_id(),
+            },
+            promoted: None,
+        };
+
+        body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);
+    }
+}
+
+struct MakeByMoveBody<'tcx> {
+    tcx: TyCtxt<'tcx>,
+    by_ref_fields: FxIndexSet<FieldIdx>,
+    by_move_coroutine_ty: Ty<'tcx>,
+}
+
+impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn visit_place(
+        &mut self,
+        place: &mut mir::Place<'tcx>,
+        context: mir::visit::PlaceContext,
+        location: mir::Location,
+    ) {
+        if place.local == ty::CAPTURE_STRUCT_LOCAL
+            && !place.projection.is_empty()
+            && let mir::ProjectionElem::Field(idx, ty) = place.projection[0]
+            && self.by_ref_fields.contains(&idx)
+        {
+            let (begin, end) = place.projection[1..].split_first().unwrap();
+            assert_eq!(*begin, mir::ProjectionElem::Deref);
+            *place = mir::Place {
+                local: place.local,
+                projection: self.tcx.mk_place_elems_from_iter(
+                    [mir::ProjectionElem::Field(idx, ty.builtin_deref(true).unwrap().ty)]
+                        .into_iter()
+                        .chain(end.iter().copied()),
+                ),
+            };
+        }
+        self.super_place(place, context, location);
+    }
+
+    fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
+        if local == ty::CAPTURE_STRUCT_LOCAL {
+            local_decl.ty = self.by_move_coroutine_ty;
+        }
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 7c731b070a7..24bc84a235c 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -318,6 +318,7 @@ impl<'tcx> Inliner<'tcx> {
             | InstanceDef::FnPtrShim(..)
             | InstanceDef::ClosureOnceShim { .. }
             | InstanceDef::ConstructCoroutineInClosureShim { .. }
+            | InstanceDef::CoroutineByMoveShim { .. }
             | InstanceDef::DropGlue(..)
             | InstanceDef::CloneShim(..)
             | InstanceDef::ThreadLocalShim(..)
diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs
index 3f3dc9145b6..77ff780393e 100644
--- a/compiler/rustc_mir_transform/src/inline/cycle.rs
+++ b/compiler/rustc_mir_transform/src/inline/cycle.rs
@@ -88,6 +88,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
                 | InstanceDef::FnPtrShim(..)
                 | InstanceDef::ClosureOnceShim { .. }
                 | InstanceDef::ConstructCoroutineInClosureShim { .. }
+                | InstanceDef::CoroutineByMoveShim { .. }
                 | InstanceDef::ThreadLocalShim { .. }
                 | InstanceDef::CloneShim(..) => {}
 
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 69f93fa3a0e..031515ea958 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -307,6 +307,10 @@ fn mir_const(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> {
             &Lint(check_packed_ref::CheckPackedRef),
             &Lint(check_const_item_mutation::CheckConstItemMutation),
             &Lint(function_item_references::FunctionItemReferences),
+            // If this is an async closure's output coroutine, generate
+            // by-move and by-mut bodies if needed. We do this first so
+            // they can be optimized in lockstep with their parent bodies.
+            &coroutine::ByMoveBody,
             // What we need to do constant evaluation.
             &simplify::SimplifyCfg::Initial,
             &rustc_peek::SanityCheck, // Just a lint
diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs
index c1ef2b9f887..c7e770904fb 100644
--- a/compiler/rustc_mir_transform/src/pass_manager.rs
+++ b/compiler/rustc_mir_transform/src/pass_manager.rs
@@ -189,6 +189,12 @@ fn run_passes_inner<'tcx>(
 
         body.pass_count = 1;
     }
+
+    if let Some(coroutine) = body.coroutine.as_mut()
+        && let Some(by_move_body) = coroutine.by_move_body.as_mut()
+    {
+        run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
+    }
 }
 
 pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: String) {
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index 29b83f58ef5..668ccdd8735 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -81,6 +81,18 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
             }
         },
 
+        ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id } => {
+            return tcx
+                .optimized_mir(coroutine_def_id)
+                .coroutine
+                .as_ref()
+                .unwrap()
+                .by_move_body
+                .as_ref()
+                .unwrap()
+                .clone();
+        }
+
         ty::InstanceDef::DropGlue(def_id, ty) => {
             // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
             // of this function. Is this intentional?
diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs
index 698fd634114..cf3c8e1fdd3 100644
--- a/compiler/rustc_monomorphize/src/collector.rs
+++ b/compiler/rustc_monomorphize/src/collector.rs
@@ -983,7 +983,8 @@ fn visit_instance_use<'tcx>(
         | ty::InstanceDef::VTableShim(..)
         | ty::InstanceDef::ReifyShim(..)
         | ty::InstanceDef::ClosureOnceShim { .. }
-        | InstanceDef::ConstructCoroutineInClosureShim { .. }
+        | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
+        | ty::InstanceDef::CoroutineByMoveShim { .. }
         | ty::InstanceDef::Item(..)
         | ty::InstanceDef::FnPtrShim(..)
         | ty::InstanceDef::CloneShim(..)
diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs
index 46bd33c89e7..22b35c4344b 100644
--- a/compiler/rustc_monomorphize/src/partitioning.rs
+++ b/compiler/rustc_monomorphize/src/partitioning.rs
@@ -620,7 +620,8 @@ fn characteristic_def_id_of_mono_item<'tcx>(
                 | ty::InstanceDef::ReifyShim(..)
                 | ty::InstanceDef::FnPtrShim(..)
                 | ty::InstanceDef::ClosureOnceShim { .. }
-                | InstanceDef::ConstructCoroutineInClosureShim { .. }
+                | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
+                | ty::InstanceDef::CoroutineByMoveShim { .. }
                 | ty::InstanceDef::Intrinsic(..)
                 | ty::InstanceDef::DropGlue(..)
                 | ty::InstanceDef::Virtual(..)
@@ -785,6 +786,7 @@ fn mono_item_visibility<'tcx>(
         | InstanceDef::Intrinsic(..)
         | InstanceDef::ClosureOnceShim { .. }
         | InstanceDef::ConstructCoroutineInClosureShim { .. }
+        | InstanceDef::CoroutineByMoveShim { .. }
         | InstanceDef::DropGlue(..)
         | InstanceDef::CloneShim(..)
         | InstanceDef::FnPtrAddrShim(..) => return Visibility::Hidden,
diff --git a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs
index e0e9815cf40..3c1858e920b 100644
--- a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs
+++ b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs
@@ -800,6 +800,7 @@ impl<'tcx> Stable<'tcx> for ty::Instance<'tcx> {
             | ty::InstanceDef::FnPtrAddrShim(..)
             | ty::InstanceDef::ClosureOnceShim { .. }
             | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
+            | ty::InstanceDef::CoroutineByMoveShim { .. }
             | ty::InstanceDef::ThreadLocalShim(..)
             | ty::InstanceDef::DropGlue(..)
             | ty::InstanceDef::CloneShim(..)
diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
index c35134c78eb..0699026117d 100644
--- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
+++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
@@ -366,6 +366,7 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc
                         let coroutine_ty = sig.to_coroutine(
                             tcx,
                             args.parent_args(),
+                            Ty::from_closure_kind(tcx, goal_kind),
                             tcx.coroutine_for_closure(def_id),
                             tupled_upvars_ty,
                         );
diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs
index 648f14beaa7..db1e89ae72f 100644
--- a/compiler/rustc_trait_selection/src/traits/project.rs
+++ b/compiler/rustc_trait_selection/src/traits/project.rs
@@ -2505,6 +2505,7 @@ fn confirm_async_closure_candidate<'cx, 'tcx>(
                 let coroutine_ty = sig.to_coroutine(
                     tcx,
                     args.parent_args(),
+                    Ty::from_closure_kind(tcx, goal_kind),
                     tcx.coroutine_for_closure(def_id),
                     tupled_upvars_ty,
                 );
@@ -2533,6 +2534,7 @@ fn confirm_async_closure_candidate<'cx, 'tcx>(
                 let coroutine_ty = sig.to_coroutine(
                     tcx,
                     args.parent_args(),
+                    Ty::from_closure_kind(tcx, goal_kind),
                     tcx.coroutine_for_closure(def_id),
                     tupled_upvars_ty,
                 );
diff --git a/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir b/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir
index 3c0d4008c90..9c8cf8763fd 100644
--- a/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir
+++ b/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir
@@ -5,6 +5,7 @@
             ty: Coroutine(
                 DefId(0:4 ~ async_await[ccf8]::a::{closure#0}),
                 [
+                (),
                 std::future::ResumeTy,
                 (),
                 (),
@@ -22,6 +23,7 @@
             ty: Coroutine(
                 DefId(0:4 ~ async_await[ccf8]::a::{closure#0}),
                 [
+                (),
                 std::future::ResumeTy,
                 (),
                 (),