about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-04-01 14:42:00 -0400
committerMichael Goulet <michael@errs.io>2024-04-02 20:07:48 -0400
commita1a1f41027c16a940c1de9e445064692110637c8 (patch)
tree84c3b7cf598b0f8e0f03aee3c476676c4e762c45
parent88c2f4f5f50ace5ddc7655ea311435104d3659bd (diff)
downloadrust-a1a1f41027c16a940c1de9e445064692110637c8.tar.gz
rust-a1a1f41027c16a940c1de9e445064692110637c8.zip
Fix capture analysis for by-move closure bodies
-rw-r--r--compiler/rustc_mir_transform/src/coroutine/by_move_body.rs66
-rw-r--r--src/tools/miri/tests/pass/async-closure-captures.rs89
-rw-r--r--src/tools/miri/tests/pass/async-closure-captures.stdout10
-rw-r--r--tests/ui/async-await/async-closures/captures.rs80
-rw-r--r--tests/ui/async-await/async-closures/captures.run.stdout10
5 files changed, 239 insertions, 16 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
index 99dbb342268..60b52fba219 100644
--- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
+++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
@@ -3,6 +3,8 @@
 //! be a coroutine body that takes all of its upvars by-move, and which we stash
 //! into the `CoroutineInfo` for all coroutines returned by coroutine-closures.
 
+use itertools::Itertools;
+
 use rustc_data_structures::unord::UnordSet;
 use rustc_hir as hir;
 use rustc_middle::mir::visit::MutVisitor;
@@ -26,36 +28,68 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
         if coroutine_ty.references_error() {
             return;
         }
+
         let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
+        let args = args.as_coroutine();
 
-        let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
+        let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap();
         if coroutine_kind == ty::ClosureKind::FnOnce {
             return;
         }
 
-        let mut by_ref_fields = UnordSet::default();
-        let by_move_upvars = Ty::new_tup_from_iter(
-            tcx,
-            tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
-                if capture.is_by_ref() {
-                    by_ref_fields.insert(FieldIdx::from_usize(idx));
-                }
-                capture.place.ty()
-            }),
+        let parent_def_id = tcx.local_parent(coroutine_def_id);
+        let ty::CoroutineClosure(_, parent_args) =
+            *tcx.type_of(parent_def_id).instantiate_identity().kind()
+        else {
+            bug!();
+        };
+        let parent_args = parent_args.as_coroutine_closure();
+        let parent_upvars_ty = parent_args.tupled_upvars_ty();
+        let tupled_inputs_ty = tcx.instantiate_bound_regions_with_erased(
+            parent_args.coroutine_closure_sig().map_bound(|sig| sig.tupled_inputs_ty),
         );
+        let num_args = tupled_inputs_ty.tuple_fields().len();
+
+        let mut by_ref_fields = UnordSet::default();
+        for (idx, (coroutine_capture, parent_capture)) in tcx
+            .closure_captures(coroutine_def_id)
+            .iter()
+            // By construction we capture all the args first.
+            .skip(num_args)
+            .zip_eq(tcx.closure_captures(parent_def_id))
+            .enumerate()
+        {
+            // This argument is captured by-move from the parent closure, but by-ref
+            // from the inner async block. That means that it's being borrowed from
+            // the closure body -- we need to change the coroutine take it by move.
+            if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
+                by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
+            }
+
+            // Make sure we're actually talking about the same capture.
+            assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
+        }
+
         let by_move_coroutine_ty = Ty::new_coroutine(
             tcx,
             coroutine_def_id.to_def_id(),
             ty::CoroutineArgs::new(
                 tcx,
                 ty::CoroutineArgsParts {
-                    parent_args: args.as_coroutine().parent_args(),
+                    parent_args: args.parent_args(),
                     kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
-                    resume_ty: args.as_coroutine().resume_ty(),
-                    yield_ty: args.as_coroutine().yield_ty(),
-                    return_ty: args.as_coroutine().return_ty(),
-                    witness: args.as_coroutine().witness(),
-                    tupled_upvars_ty: by_move_upvars,
+                    resume_ty: args.resume_ty(),
+                    yield_ty: args.yield_ty(),
+                    return_ty: args.return_ty(),
+                    witness: args.witness(),
+                    // Concatenate the args + closure's captures (since they're all by move).
+                    tupled_upvars_ty: Ty::new_tup_from_iter(
+                        tcx,
+                        tupled_inputs_ty
+                            .tuple_fields()
+                            .iter()
+                            .chain(parent_upvars_ty.tuple_fields()),
+                    ),
                 },
             )
             .args,
diff --git a/src/tools/miri/tests/pass/async-closure-captures.rs b/src/tools/miri/tests/pass/async-closure-captures.rs
new file mode 100644
index 00000000000..acff4a38338
--- /dev/null
+++ b/src/tools/miri/tests/pass/async-closure-captures.rs
@@ -0,0 +1,89 @@
+#![feature(async_closure, noop_waker)]
+
+use std::future::Future;
+use std::pin::pin;
+use std::task::*;
+
+pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
+    let mut fut = pin!(fut);
+    let ctx = &mut Context::from_waker(Waker::noop());
+
+    loop {
+        match fut.as_mut().poll(ctx) {
+            Poll::Pending => {}
+            Poll::Ready(t) => break t,
+        }
+    }
+}
+
+fn main() {
+    block_on(async_main());
+}
+
+async fn call<T>(f: &impl async Fn() -> T) -> T {
+    f().await
+}
+
+async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
+    f().await
+}
+
+#[derive(Debug)]
+#[allow(unused)]
+struct Hello(i32);
+
+async fn async_main() {
+    // Capture something by-ref
+    {
+        let x = Hello(0);
+        let c = async || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+
+        let x = &Hello(1);
+        let c = async || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+    }
+
+    // Capture something and consume it (force to `AsyncFnOnce`)
+    {
+        let x = Hello(2);
+        let c = async || {
+            println!("{x:?}");
+            drop(x);
+        };
+        call_once(c).await;
+    }
+
+    // Capture something with `move`, don't consume it
+    {
+        let x = Hello(3);
+        let c = async move || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+
+        let x = &Hello(4);
+        let c = async move || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+    }
+
+    // Capture something with `move`, also consume it (so `AsyncFnOnce`)
+    {
+        let x = Hello(5);
+        let c = async move || {
+            println!("{x:?}");
+            drop(x);
+        };
+        call_once(c).await;
+    }
+}
diff --git a/src/tools/miri/tests/pass/async-closure-captures.stdout b/src/tools/miri/tests/pass/async-closure-captures.stdout
new file mode 100644
index 00000000000..a0db6d236fe
--- /dev/null
+++ b/src/tools/miri/tests/pass/async-closure-captures.stdout
@@ -0,0 +1,10 @@
+Hello(0)
+Hello(0)
+Hello(1)
+Hello(1)
+Hello(2)
+Hello(3)
+Hello(3)
+Hello(4)
+Hello(4)
+Hello(5)
diff --git a/tests/ui/async-await/async-closures/captures.rs b/tests/ui/async-await/async-closures/captures.rs
new file mode 100644
index 00000000000..46bbf53f0a7
--- /dev/null
+++ b/tests/ui/async-await/async-closures/captures.rs
@@ -0,0 +1,80 @@
+//@ aux-build:block-on.rs
+//@ edition:2021
+//@ run-pass
+//@ check-run-results
+
+#![feature(async_closure)]
+
+extern crate block_on;
+
+fn main() {
+    block_on::block_on(async_main());
+}
+
+async fn call<T>(f: &impl async Fn() -> T) -> T {
+    f().await
+}
+
+async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
+    f().await
+}
+
+#[derive(Debug)]
+#[allow(unused)]
+struct Hello(i32);
+
+async fn async_main() {
+    // Capture something by-ref
+    {
+        let x = Hello(0);
+        let c = async || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+
+        let x = &Hello(1);
+        let c = async || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+    }
+
+    // Capture something and consume it (force to `AsyncFnOnce`)
+    {
+        let x = Hello(2);
+        let c = async || {
+            println!("{x:?}");
+            drop(x);
+        };
+        call_once(c).await;
+    }
+
+    // Capture something with `move`, don't consume it
+    {
+        let x = Hello(3);
+        let c = async move || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+
+        let x = &Hello(4);
+        let c = async move || {
+            println!("{x:?}");
+        };
+        call(&c).await;
+        call_once(c).await;
+    }
+
+    // Capture something with `move`, also consume it (so `AsyncFnOnce`)
+    {
+        let x = Hello(5);
+        let c = async move || {
+            println!("{x:?}");
+            drop(x);
+        };
+        call_once(c).await;
+    }
+}
diff --git a/tests/ui/async-await/async-closures/captures.run.stdout b/tests/ui/async-await/async-closures/captures.run.stdout
new file mode 100644
index 00000000000..a0db6d236fe
--- /dev/null
+++ b/tests/ui/async-await/async-closures/captures.run.stdout
@@ -0,0 +1,10 @@
+Hello(0)
+Hello(0)
+Hello(1)
+Hello(1)
+Hello(2)
+Hello(3)
+Hello(3)
+Hello(4)
+Hello(4)
+Hello(5)