about summary refs log tree commit diff
path: root/compiler/rustc_mir_transform/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--compiler/rustc_mir_transform/src/coroutine/by_move_body.rs45
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs2
-rw-r--r--compiler/rustc_mir_transform/src/inline/cycle.rs2
-rw-r--r--compiler/rustc_mir_transform/src/pass_manager.rs11
-rw-r--r--compiler/rustc_mir_transform/src/shim.rs77
5 files changed, 111 insertions, 26 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
index 1cc0a5026d1..fcd4715b9e8 100644
--- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
+++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
@@ -6,7 +6,7 @@
 use rustc_data_structures::fx::FxIndexSet;
 use rustc_hir as hir;
 use rustc_middle::mir::visit::MutVisitor;
-use rustc_middle::mir::{self, MirPass};
+use rustc_middle::mir::{self, dump_mir, MirPass};
 use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt};
 use rustc_target::abi::FieldIdx;
 
@@ -24,7 +24,9 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
         };
         let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
         let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() };
-        if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce {
+
+        let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
+        if coroutine_kind == ty::ClosureKind::FnOnce {
             return;
         }
 
@@ -58,14 +60,49 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
 
         let mut by_move_body = body.clone();
         MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
+        dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
         by_move_body.source = mir::MirSource {
-            instance: InstanceDef::CoroutineByMoveShim {
+            instance: InstanceDef::CoroutineKindShim {
                 coroutine_def_id: coroutine_def_id.to_def_id(),
+                target_kind: ty::ClosureKind::FnOnce,
             },
             promoted: None,
         };
-
         body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);
+
+        // If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body.
+        // This is actually just a copy of the by-ref body, but with a different self type.
+        // FIXME(async_closures): We could probably unify this with the by-ref body somehow.
+        if coroutine_kind == ty::ClosureKind::Fn {
+            let by_mut_coroutine_ty = Ty::new_coroutine(
+                tcx,
+                coroutine_def_id.to_def_id(),
+                ty::CoroutineArgs::new(
+                    tcx,
+                    ty::CoroutineArgsParts {
+                        parent_args: args.as_coroutine().parent_args(),
+                        kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut),
+                        resume_ty: args.as_coroutine().resume_ty(),
+                        yield_ty: args.as_coroutine().yield_ty(),
+                        return_ty: args.as_coroutine().return_ty(),
+                        witness: args.as_coroutine().witness(),
+                        tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(),
+                    },
+                )
+                .args,
+            );
+            let mut by_mut_body = body.clone();
+            by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty;
+            dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(()));
+            by_mut_body.source = mir::MirSource {
+                instance: InstanceDef::CoroutineKindShim {
+                    coroutine_def_id: coroutine_def_id.to_def_id(),
+                    target_kind: ty::ClosureKind::FnMut,
+                },
+                promoted: None,
+            };
+            body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body);
+        }
     }
 }
 
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 24bc84a235c..e77553a03d6 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -318,7 +318,7 @@ impl<'tcx> Inliner<'tcx> {
             | InstanceDef::FnPtrShim(..)
             | InstanceDef::ClosureOnceShim { .. }
             | InstanceDef::ConstructCoroutineInClosureShim { .. }
-            | InstanceDef::CoroutineByMoveShim { .. }
+            | InstanceDef::CoroutineKindShim { .. }
             | InstanceDef::DropGlue(..)
             | InstanceDef::CloneShim(..)
             | InstanceDef::ThreadLocalShim(..)
diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs
index 77ff780393e..5b03bc361dd 100644
--- a/compiler/rustc_mir_transform/src/inline/cycle.rs
+++ b/compiler/rustc_mir_transform/src/inline/cycle.rs
@@ -88,7 +88,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
                 | InstanceDef::FnPtrShim(..)
                 | InstanceDef::ClosureOnceShim { .. }
                 | InstanceDef::ConstructCoroutineInClosureShim { .. }
-                | InstanceDef::CoroutineByMoveShim { .. }
+                | InstanceDef::CoroutineKindShim { .. }
                 | InstanceDef::ThreadLocalShim { .. }
                 | InstanceDef::CloneShim(..) => {}
 
diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs
index c7e770904fb..605e1ad46d7 100644
--- a/compiler/rustc_mir_transform/src/pass_manager.rs
+++ b/compiler/rustc_mir_transform/src/pass_manager.rs
@@ -190,10 +190,13 @@ fn run_passes_inner<'tcx>(
         body.pass_count = 1;
     }
 
-    if let Some(coroutine) = body.coroutine.as_mut()
-        && let Some(by_move_body) = coroutine.by_move_body.as_mut()
-    {
-        run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
+    if let Some(coroutine) = body.coroutine.as_mut() {
+        if let Some(by_move_body) = coroutine.by_move_body.as_mut() {
+            run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each);
+        }
+        if let Some(by_mut_body) = coroutine.by_mut_body.as_mut() {
+            run_passes_inner(tcx, by_mut_body, passes, phase_change, validate_each);
+        }
     }
 }
 
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index 668ccdd8735..7b6de3a5439 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -72,32 +72,70 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
         } => match target_kind {
             ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"),
             ty::ClosureKind::FnMut => {
-                let body = build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id);
-                // No need to optimize the body, it has already been optimized.
-                return body;
+                // No need to optimize the body, it has already been optimized
+                // since we steal it from the `AsyncFn::call` body and just fix
+                // the return type.
+                return build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id);
             }
             ty::ClosureKind::FnOnce => {
                 build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
             }
         },
 
-        ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id } => {
-            return tcx
-                .optimized_mir(coroutine_def_id)
-                .coroutine
-                .as_ref()
-                .unwrap()
-                .by_move_body
-                .as_ref()
-                .unwrap()
-                .clone();
-        }
+        ty::InstanceDef::CoroutineKindShim { coroutine_def_id, target_kind } => match target_kind {
+            ty::ClosureKind::Fn => unreachable!(),
+            ty::ClosureKind::FnMut => {
+                return tcx
+                    .optimized_mir(coroutine_def_id)
+                    .coroutine_by_mut_body()
+                    .unwrap()
+                    .clone();
+            }
+            ty::ClosureKind::FnOnce => {
+                return tcx
+                    .optimized_mir(coroutine_def_id)
+                    .coroutine_by_move_body()
+                    .unwrap()
+                    .clone();
+            }
+        },
 
         ty::InstanceDef::DropGlue(def_id, ty) => {
             // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
             // of this function. Is this intentional?
             if let Some(ty::Coroutine(coroutine_def_id, args)) = ty.map(Ty::kind) {
-                let body = tcx.optimized_mir(*coroutine_def_id).coroutine_drop().unwrap();
+                let coroutine_body = tcx.optimized_mir(*coroutine_def_id);
+
+                let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind()
+                else {
+                    bug!()
+                };
+
+                // If this is a regular coroutine, grab its drop shim. If this is a coroutine
+                // that comes from a coroutine-closure, and the kind ty differs from the "maximum"
+                // kind that it supports, then grab the appropriate drop shim. This ensures that
+                // the future returned by `<[coroutine-closure] as AsyncFnOnce>::call_once` will
+                // drop the coroutine-closure's upvars.
+                let body = if id_args.as_coroutine().kind_ty() == args.as_coroutine().kind_ty() {
+                    coroutine_body.coroutine_drop().unwrap()
+                } else {
+                    match args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() {
+                        ty::ClosureKind::Fn => {
+                            unreachable!()
+                        }
+                        ty::ClosureKind::FnMut => coroutine_body
+                            .coroutine_by_mut_body()
+                            .unwrap()
+                            .coroutine_drop()
+                            .unwrap(),
+                        ty::ClosureKind::FnOnce => coroutine_body
+                            .coroutine_by_move_body()
+                            .unwrap()
+                            .coroutine_drop()
+                            .unwrap(),
+                    }
+                };
+
                 let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args);
                 debug!("make_shim({:?}) = {:?}", instance, body);
 
@@ -1076,7 +1114,11 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
         target_kind: ty::ClosureKind::FnOnce,
     });
 
-    new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span)
+    let body =
+        new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span);
+    dump_mir(tcx, false, "coroutine_closure_by_move", &0, &body, |_, _| Ok(()));
+
+    body
 }
 
 fn build_construct_coroutine_by_mut_shim<'tcx>(
@@ -1110,5 +1152,8 @@ fn build_construct_coroutine_by_mut_shim<'tcx>(
         target_kind: ty::ClosureKind::FnMut,
     });
 
+    body.pass_count = 0;
+    dump_mir(tcx, false, "coroutine_closure_by_mut", &0, &body, |_, _| Ok(()));
+
     body
 }