about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-02-28 20:25:25 +0000
committerMichael Goulet <michael@errs.io>2024-03-19 16:59:24 -0400
commitf1fef64e19909487ff2640bce58ce49fcfb4b85d (patch)
tree7281d36cff4865852e3d9c6193ba6689204ab033
parent05116c5c30dea6895fb65fe31b6f2dd0f1198b51 (diff)
downloadrust-f1fef64e19909487ff2640bce58ce49fcfb4b85d.tar.gz
rust-f1fef64e19909487ff2640bce58ce49fcfb4b85d.zip
Fix ABI for FnMut/Fn impls for async closures
-rw-r--r--compiler/rustc_middle/src/mir/visit.rs1
-rw-r--r--compiler/rustc_middle/src/ty/instance.rs11
-rw-r--r--compiler/rustc_mir_transform/src/shim.rs24
-rw-r--r--compiler/rustc_ty_utils/src/abi.rs15
-rw-r--r--compiler/rustc_ty_utils/src/instance.rs2
-rw-r--r--tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir2
-rw-r--r--tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir2
-rw-r--r--tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir6
-rw-r--r--tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir6
-rw-r--r--tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-abort.mir10
-rw-r--r--tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-unwind.mir10
-rw-r--r--tests/mir-opt/async_closure_shims.rs10
12 files changed, 81 insertions, 18 deletions
diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs
index 562aed5a643..be960669ff4 100644
--- a/compiler/rustc_middle/src/mir/visit.rs
+++ b/compiler/rustc_middle/src/mir/visit.rs
@@ -347,6 +347,7 @@ macro_rules! make_mir_visitor {
                         ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
                         ty::InstanceDef::ConstructCoroutineInClosureShim {
                             coroutine_closure_def_id: _def_id,
+                            receiver_by_ref: _,
                         } |
                         ty::InstanceDef::CoroutineKindShim { coroutine_def_id: _def_id } |
                         ty::InstanceDef::DropGlue(_def_id, None) => {}
diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs
index bbe0915baa2..18ef4ed549b 100644
--- a/compiler/rustc_middle/src/ty/instance.rs
+++ b/compiler/rustc_middle/src/ty/instance.rs
@@ -95,7 +95,15 @@ pub enum InstanceDef<'tcx> {
     /// The body generated here differs significantly from the `ClosureOnceShim`,
     /// since we need to generate a distinct coroutine type that will move the
     /// closure's upvars *out* of the closure.
-    ConstructCoroutineInClosureShim { coroutine_closure_def_id: DefId },
+    ConstructCoroutineInClosureShim {
+        coroutine_closure_def_id: DefId,
+        // Whether the generated MIR body takes the coroutine by-ref. This is
+        // because the signature of `<{async fn} as FnMut>::call_mut` is:
+        // `fn(&mut self, args: A) -> <Self as FnOnce>::Output`, that is to say
+        // that it returns the `FnOnce`-flavored coroutine but takes the closure
+        // by ref (and similarly for `Fn::call`).
+        receiver_by_ref: bool,
+    },
 
     /// `<[coroutine] as Future>::poll`, but for coroutines produced when `AsyncFnOnce`
     /// is called on a coroutine-closure whose closure kind greater than `FnOnce`, or
@@ -188,6 +196,7 @@ impl<'tcx> InstanceDef<'tcx> {
             | InstanceDef::ClosureOnceShim { call_once: def_id, track_caller: _ }
             | ty::InstanceDef::ConstructCoroutineInClosureShim {
                 coroutine_closure_def_id: def_id,
+                receiver_by_ref: _,
             }
             | ty::InstanceDef::CoroutineKindShim { coroutine_def_id: def_id }
             | InstanceDef::DropGlue(def_id, _)
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index 3efaa69a7e7..4b2243598dc 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -70,9 +70,10 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
             build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
         }
 
-        ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id } => {
-            build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
-        }
+        ty::InstanceDef::ConstructCoroutineInClosureShim {
+            coroutine_closure_def_id,
+            receiver_by_ref,
+        } => build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id, receiver_by_ref),
 
         ty::InstanceDef::CoroutineKindShim { coroutine_def_id } => {
             return tcx.optimized_mir(coroutine_def_id).coroutine_by_move_body().unwrap().clone();
@@ -1015,12 +1016,17 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t
 fn build_construct_coroutine_by_move_shim<'tcx>(
     tcx: TyCtxt<'tcx>,
     coroutine_closure_def_id: DefId,
+    receiver_by_ref: bool,
 ) -> Body<'tcx> {
-    let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
+    let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
     let ty::CoroutineClosure(_, args) = *self_ty.kind() else {
         bug!();
     };
 
+    if receiver_by_ref {
+        self_ty = Ty::new_mut_ptr(tcx, self_ty);
+    }
+
     let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
         tcx.mk_fn_sig(
             [self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()),
@@ -1076,11 +1082,19 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
 
     let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
         coroutine_closure_def_id,
+        receiver_by_ref,
     });
 
     let body =
         new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span);
-    dump_mir(tcx, false, "coroutine_closure_by_move", &0, &body, |_, _| Ok(()));
+    dump_mir(
+        tcx,
+        false,
+        if receiver_by_ref { "coroutine_closure_by_ref" } else { "coroutine_closure_by_move" },
+        &0,
+        &body,
+        |_, _| Ok(()),
+    );
 
     body
 }
diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs
index 7d54083fbd5..baf4de768c5 100644
--- a/compiler/rustc_ty_utils/src/abi.rs
+++ b/compiler/rustc_ty_utils/src/abi.rs
@@ -118,11 +118,18 @@ fn fn_sig_for_fn_abi<'tcx>(
             // a separate def-id for these bodies.
             let mut coroutine_kind = args.as_coroutine_closure().kind();
 
-            if let InstanceDef::ConstructCoroutineInClosureShim { .. } = instance.def {
-                coroutine_kind = ty::ClosureKind::FnOnce;
-            }
+            let env_ty =
+                if let InstanceDef::ConstructCoroutineInClosureShim { receiver_by_ref, .. } =
+                    instance.def
+                {
+                    coroutine_kind = ty::ClosureKind::FnOnce;
 
-            let env_ty = tcx.closure_env_ty(coroutine_ty, coroutine_kind, env_region);
+                    // Implementations of `FnMut` and `Fn` for coroutine-closures
+                    // still take their receiver by ref.
+                    if receiver_by_ref { Ty::new_mut_ptr(tcx, coroutine_ty) } else { coroutine_ty }
+                } else {
+                    tcx.closure_env_ty(coroutine_ty, coroutine_kind, env_region)
+                };
 
             let sig = sig.skip_binder();
             ty::Binder::bind_with_vars(
diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs
index c2ea89f4c29..a8f9afb87dd 100644
--- a/compiler/rustc_ty_utils/src/instance.rs
+++ b/compiler/rustc_ty_utils/src/instance.rs
@@ -282,6 +282,7 @@ fn resolve_associated_item<'tcx>(
                             Some(Instance {
                                 def: ty::InstanceDef::ConstructCoroutineInClosureShim {
                                     coroutine_closure_def_id,
+                                    receiver_by_ref: target_kind != ty::ClosureKind::FnOnce,
                                 },
                                 args,
                             })
@@ -304,6 +305,7 @@ fn resolve_associated_item<'tcx>(
                             Some(Instance {
                                 def: ty::InstanceDef::ConstructCoroutineInClosureShim {
                                     coroutine_closure_def_id,
+                                    receiver_by_ref: false,
                                 },
                                 args,
                             })
diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir
index 6ca3dd61005..06028487d01 100644
--- a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir
+++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir
@@ -1,6 +1,6 @@
 // MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move
 
-fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10}, _2: ResumeTy) -> ()
+fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> ()
 yields ()
  {
     debug _task_context => _2;
diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir
index 6ca3dd61005..06028487d01 100644
--- a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir
+++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir
@@ -1,6 +1,6 @@
 // MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move
 
-fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10}, _2: ResumeTy) -> ()
+fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> ()
 yields ()
  {
     debug _task_context => _2;
diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir
index b5768e14452..93447b1388d 100644
--- a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir
+++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir
@@ -1,10 +1,10 @@
 // MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move
 
-fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:37:33: 37:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10} {
-    let mut _0: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10};
+fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} {
+    let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10};
 
     bb0: {
-        _0 = {coroutine@$DIR/async_closure_shims.rs:37:53: 40:10 (#0)} { a: move _2, b: move (_1.0: i32) };
+        _0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) };
         return;
     }
 }
diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir
index b5768e14452..93447b1388d 100644
--- a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir
+++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir
@@ -1,10 +1,10 @@
 // MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move
 
-fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:37:33: 37:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10} {
-    let mut _0: {async closure body@$DIR/async_closure_shims.rs:37:53: 40:10};
+fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} {
+    let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10};
 
     bb0: {
-        _0 = {coroutine@$DIR/async_closure_shims.rs:37:53: 40:10 (#0)} { a: move _2, b: move (_1.0: i32) };
+        _0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) };
         return;
     }
 }
diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-abort.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-abort.mir
new file mode 100644
index 00000000000..f51540bcfff
--- /dev/null
+++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-abort.mir
@@ -0,0 +1,10 @@
+// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref
+
+fn main::{closure#0}::{closure#1}(_1: *mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} {
+    let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10};
+
+    bb0: {
+        _0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 };
+        return;
+    }
+}
diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-unwind.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-unwind.mir
new file mode 100644
index 00000000000..f51540bcfff
--- /dev/null
+++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.panic-unwind.mir
@@ -0,0 +1,10 @@
+// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref
+
+fn main::{closure#0}::{closure#1}(_1: *mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} {
+    let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10};
+
+    bb0: {
+        _0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 };
+        return;
+    }
+}
diff --git a/tests/mir-opt/async_closure_shims.rs b/tests/mir-opt/async_closure_shims.rs
index 47c41ed0500..7d226df6866 100644
--- a/tests/mir-opt/async_closure_shims.rs
+++ b/tests/mir-opt/async_closure_shims.rs
@@ -29,8 +29,13 @@ async fn call_once(f: impl AsyncFnOnce(i32)) {
     f(1).await;
 }
 
+async fn call_normal<F: Future<Output = ()>>(f: &impl Fn(i32) -> F) {
+    f(1).await;
+}
+
 // EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.mir
 // EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.mir
+// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#1}.coroutine_closure_by_ref.0.mir
 pub fn main() {
     block_on(async {
         let b = 2i32;
@@ -40,5 +45,10 @@ pub fn main() {
         };
         call_mut(&mut async_closure).await;
         call_once(async_closure).await;
+
+        let async_closure = async move |a: i32| {
+            let a = &a;
+        };
+        call_normal(&async_closure).await;
     });
 }