about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-04-02 11:26:57 -0400
committerMichael Goulet <michael@errs.io>2024-04-02 20:07:49 -0400
commitec74a304bb737f24d6e7f2fb8c3a2b3cf3575f0f (patch)
tree0b900530ade29b7b0550bef06042276491d9b631
parenta1a1f41027c16a940c1de9e445064692110637c8 (diff)
downloadrust-ec74a304bb737f24d6e7f2fb8c3a2b3cf3575f0f.tar.gz
rust-ec74a304bb737f24d6e7f2fb8c3a2b3cf3575f0f.zip
Comments, comments, comments
-rw-r--r--compiler/rustc_mir_transform/src/coroutine/by_move_body.rs131
-rw-r--r--src/tools/miri/tests/pass/async-closure-captures.rs2
-rw-r--r--tests/ui/async-await/async-closures/captures.rs2
3 files changed, 96 insertions, 39 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
index 60b52fba219..0866205dfd0 100644
--- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
+++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
@@ -1,7 +1,64 @@
-//! A MIR pass which duplicates a coroutine's body and removes any derefs which
-//! would be present for upvars that are taken by-ref. The result of which will
-//! be a coroutine body that takes all of its upvars by-move, and which we stash
-//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures.
+//! This pass constructs a second coroutine body sufficient for return from
+//! `FnOnce`/`AsyncFnOnce` implementations for coroutine-closures (e.g. async closures).
+//!
+//! Consider an async closure like:
+//! ```rust
+//! #![feature(async_closure)]
+//!
+//! let x = vec![1, 2, 3];
+//!
+//! let closure = async move || {
+//!     println!("{x:#?}");
+//! };
+//! ```
+//!
+//! This desugars to something like:
+//! ```rust,ignore (invalid-borrowck)
+//! let x = vec![1, 2, 3];
+//!
+//! let closure = move || {
+//!     async {
+//!         println!("{x:#?}");
+//!     }
+//! };
+//! ```
+//!
+//! Important to note here is that while the outer closure *moves* `x: Vec<i32>`
+//! into its upvars, the inner `async` coroutine simply captures a ref of `x`.
+//! This is the "magic" of async closures -- the futures that they return are
+//! allowed to borrow from their parent closure's upvars.
+//!
+//! However, what happens when we call `closure` with `AsyncFnOnce` (or `FnOnce`,
+//! since all async closures implement that too)? Well, recall the signature:
+//! ```
+//! use std::future::Future;
+//! pub trait AsyncFnOnce<Args>
+//! {
+//!     type CallOnceFuture: Future<Output = Self::Output>;
+//!     type Output;
+//!     fn async_call_once(
+//!         self,
+//!         args: Args
+//!     ) -> Self::CallOnceFuture;
+//! }
+//! ```
+//!
+//! This signature *consumes* the async closure (`self`) and returns a `CallOnceFuture`.
+//! How do we deal with the fact that the coroutine is supposed to take a reference
+//! to the captured `x` from the parent closure, when that parent closure has been
+//! destroyed?
+//!
+//! This is the second piece of magic of async closures. We can simply create a
+//! *second* `async` coroutine body where that `x` that was previously captured
+//! by reference is now captured by value. This means that we consume the outer
+//! closure and return a new coroutine that will hold onto all of these captures,
+//! and drop them when it is finished (i.e. after it has been `.await`ed).
+//!
+//! We do this with the analysis below, which detects the captures that come from
+//! borrowing from the outer closure, and we simply peel off a `deref` projection
+//! from them. This second body is stored alongside the first body, and optimized
+//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
+//! we use this "by move" body instead.
 
 use itertools::Itertools;
 
@@ -16,6 +73,8 @@ pub struct ByMoveBody;
 
 impl<'tcx> MirPass<'tcx> for ByMoveBody {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) {
+        // We only need to generate by-move coroutine bodies for coroutines that come
+        // from coroutine-closures.
         let Some(coroutine_def_id) = body.source.def_id().as_local() else {
             return;
         };
@@ -24,15 +83,19 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
         else {
             return;
         };
+
+        // Also, let's skip processing any bodies with errors, since there's no guarantee
+        // the MIR body will be constructed well.
         let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
         if coroutine_ty.references_error() {
             return;
         }
 
-        let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
-        let args = args.as_coroutine();
-
-        let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap();
+        let ty::Coroutine(_, coroutine_args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
+        // We don't need to generate a by-move coroutine if the kind of the coroutine is
+        // already `FnOnce` -- that means that any upvars that the closure consumes have
+        // already been taken by-value.
+        let coroutine_kind = coroutine_args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
         if coroutine_kind == ty::ClosureKind::FnOnce {
             return;
         }
@@ -43,12 +106,13 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
         else {
             bug!();
         };
-        let parent_args = parent_args.as_coroutine_closure();
-        let parent_upvars_ty = parent_args.tupled_upvars_ty();
-        let tupled_inputs_ty = tcx.instantiate_bound_regions_with_erased(
-            parent_args.coroutine_closure_sig().map_bound(|sig| sig.tupled_inputs_ty),
-        );
-        let num_args = tupled_inputs_ty.tuple_fields().len();
+        let parent_closure_args = parent_args.as_coroutine_closure();
+        let num_args = parent_closure_args
+            .coroutine_closure_sig()
+            .skip_binder()
+            .tupled_inputs_ty
+            .tuple_fields()
+            .len();
 
         let mut by_ref_fields = UnordSet::default();
         for (idx, (coroutine_capture, parent_capture)) in tcx
@@ -59,41 +123,30 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
             .zip_eq(tcx.closure_captures(parent_def_id))
             .enumerate()
         {
-            // This argument is captured by-move from the parent closure, but by-ref
+            // This upvar is captured by-move from the parent closure, but by-ref
             // from the inner async block. That means that it's being borrowed from
-            // the closure body -- we need to change the coroutine take it by move.
+            // the outer closure body -- we need to change the coroutine to take the
+            // upvar by value.
             if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
                 by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
             }
 
             // Make sure we're actually talking about the same capture.
+            // FIXME(async_closures): We could look at the `hir::Upvar` instead?
             assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
         }
 
-        let by_move_coroutine_ty = Ty::new_coroutine(
-            tcx,
-            coroutine_def_id.to_def_id(),
-            ty::CoroutineArgs::new(
+        let by_move_coroutine_ty = tcx
+            .instantiate_bound_regions_with_erased(parent_closure_args.coroutine_closure_sig())
+            .to_coroutine_given_kind_and_upvars(
                 tcx,
-                ty::CoroutineArgsParts {
-                    parent_args: args.parent_args(),
-                    kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
-                    resume_ty: args.resume_ty(),
-                    yield_ty: args.yield_ty(),
-                    return_ty: args.return_ty(),
-                    witness: args.witness(),
-                    // Concatenate the args + closure's captures (since they're all by move).
-                    tupled_upvars_ty: Ty::new_tup_from_iter(
-                        tcx,
-                        tupled_inputs_ty
-                            .tuple_fields()
-                            .iter()
-                            .chain(parent_upvars_ty.tuple_fields()),
-                    ),
-                },
-            )
-            .args,
-        );
+                parent_closure_args.parent_args(),
+                coroutine_def_id.to_def_id(),
+                ty::ClosureKind::FnOnce,
+                tcx.lifetimes.re_erased,
+                parent_closure_args.tupled_upvars_ty(),
+                parent_closure_args.coroutine_captures_by_ref_ty(),
+            );
 
         let mut by_move_body = body.clone();
         MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
diff --git a/src/tools/miri/tests/pass/async-closure-captures.rs b/src/tools/miri/tests/pass/async-closure-captures.rs
index acff4a38338..3e33de32efb 100644
--- a/src/tools/miri/tests/pass/async-closure-captures.rs
+++ b/src/tools/miri/tests/pass/async-closure-captures.rs
@@ -1,3 +1,5 @@
+// Same as rustc's `tests/ui/async-await/async-closures/captures.rs`, keep in sync
+
 #![feature(async_closure, noop_waker)]
 
 use std::future::Future;
diff --git a/tests/ui/async-await/async-closures/captures.rs b/tests/ui/async-await/async-closures/captures.rs
index 46bbf53f0a7..e3ab8713709 100644
--- a/tests/ui/async-await/async-closures/captures.rs
+++ b/tests/ui/async-await/async-closures/captures.rs
@@ -3,6 +3,8 @@
 //@ run-pass
 //@ check-run-results
 
+// Same as miri's `tests/pass/async-closure-captures.rs`, keep in sync
+
 #![feature(async_closure)]
 
 extern crate block_on;