about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-04-05 15:12:34 -0400
committerMichael Goulet <michael@errs.io>2024-04-05 15:28:13 -0400
commit3674032eb21d145992f1b8374e4ab201d606fbd9 (patch)
treeea778dc713e964bcb7cd1bfc72d6c968af6d0baf
parent1921968cc5403892739b43bdefe793a130badd15 (diff)
downloadrust-3674032eb21d145992f1b8374e4ab201d606fbd9.tar.gz
rust-3674032eb21d145992f1b8374e4ab201d606fbd9.zip
Rework the ByMoveBody shim to actually work correctly
-rw-r--r--compiler/rustc_mir_transform/src/coroutine/by_move_body.rs124
-rw-r--r--tests/ui/async-await/async-closures/precise-captures.call.run.stdout29
-rw-r--r--tests/ui/async-await/async-closures/precise-captures.call_once.run.stdout29
-rw-r--r--tests/ui/async-await/async-closures/precise-captures.force_once.run.stdout29
-rw-r--r--tests/ui/async-await/async-closures/precise-captures.rs157
5 files changed, 334 insertions, 34 deletions
diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
index de43f9faff9..1fb2c80dd40 100644
--- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
+++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
@@ -60,14 +60,13 @@
 //! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
 //! we use this "by move" body instead.
 
-use itertools::Itertools;
-
-use rustc_data_structures::unord::UnordSet;
+use rustc_data_structures::unord::UnordMap;
 use rustc_hir as hir;
+use rustc_middle::hir::place::{Projection, ProjectionKind};
 use rustc_middle::mir::visit::MutVisitor;
 use rustc_middle::mir::{self, dump_mir, MirPass};
 use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
-use rustc_target::abi::FieldIdx;
+use rustc_target::abi::{FieldIdx, VariantIdx};
 
 pub struct ByMoveBody;
 
@@ -116,32 +115,76 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
             .tuple_fields()
             .len();
 
-        let mut by_ref_fields = UnordSet::default();
-        for (idx, (coroutine_capture, parent_capture)) in tcx
+        let mut field_remapping = UnordMap::default();
+
+        let mut parent_captures =
+            tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable();
+
+        for (child_field_idx, child_capture) in tcx
             .closure_captures(coroutine_def_id)
             .iter()
+            .copied()
             // By construction we capture all the args first.
             .skip(num_args)
-            .zip_eq(tcx.closure_captures(parent_def_id))
             .enumerate()
         {
-            // This upvar is captured by-move from the parent closure, but by-ref
-            // from the inner async block. That means that it's being borrowed from
-            // the outer closure body -- we need to change the coroutine to take the
-            // upvar by value.
-            if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
-                assert_ne!(
-                    coroutine_kind,
-                    ty::ClosureKind::FnOnce,
-                    "`FnOnce` coroutine-closures return coroutines that capture from \
-                    their body; it will always result in a borrowck error!"
+            loop {
+                let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else {
+                    bug!("we ran out of parent captures!")
+                };
+
+                if !std::iter::zip(
+                    &child_capture.place.projections,
+                    &parent_capture.place.projections,
+                )
+                .all(|(child, parent)| child.kind == parent.kind)
+                {
+                    // Skip this field.
+                    let _ = parent_captures.next().unwrap();
+                    continue;
+                }
+
+                let child_precise_captures =
+                    &child_capture.place.projections[parent_capture.place.projections.len()..];
+
+                let needs_deref = child_capture.is_by_ref() && !parent_capture.is_by_ref();
+                if needs_deref {
+                    assert_ne!(
+                        coroutine_kind,
+                        ty::ClosureKind::FnOnce,
+                        "`FnOnce` coroutine-closures return coroutines that capture from \
+                        their body; it will always result in a borrowck error!"
+                    );
+                }
+
+                let mut parent_capture_ty = parent_capture.place.ty();
+                parent_capture_ty = match parent_capture.info.capture_kind {
+                    ty::UpvarCapture::ByValue => parent_capture_ty,
+                    ty::UpvarCapture::ByRef(kind) => Ty::new_ref(
+                        tcx,
+                        tcx.lifetimes.re_erased,
+                        parent_capture_ty,
+                        kind.to_mutbl_lossy(),
+                    ),
+                };
+
+                field_remapping.insert(
+                    FieldIdx::from_usize(child_field_idx + num_args),
+                    (
+                        FieldIdx::from_usize(parent_field_idx + num_args),
+                        parent_capture_ty,
+                        needs_deref,
+                        child_precise_captures,
+                    ),
                 );
-                by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
+
+                break;
             }
+        }
 
-            // Make sure we're actually talking about the same capture.
-            // FIXME(async_closures): We could look at the `hir::Upvar` instead?
-            assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
+        if coroutine_kind == ty::ClosureKind::FnOnce {
+            assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
+            return;
         }
 
         let by_move_coroutine_ty = tcx
@@ -157,7 +200,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
             );
 
         let mut by_move_body = body.clone();
-        MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
+        MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body);
         dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
         by_move_body.source = mir::MirSource::from_instance(InstanceDef::CoroutineKindShim {
             coroutine_def_id: coroutine_def_id.to_def_id(),
@@ -168,7 +211,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
 
 struct MakeByMoveBody<'tcx> {
     tcx: TyCtxt<'tcx>,
-    by_ref_fields: UnordSet<FieldIdx>,
+    field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, &'tcx [Projection<'tcx>])>,
     by_move_coroutine_ty: Ty<'tcx>,
 }
 
@@ -184,23 +227,36 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
         location: mir::Location,
     ) {
         if place.local == ty::CAPTURE_STRUCT_LOCAL
-            && let Some((&mir::ProjectionElem::Field(idx, ty), projection)) =
+            && let Some((&mir::ProjectionElem::Field(idx, _), projection)) =
                 place.projection.split_first()
-            && self.by_ref_fields.contains(&idx)
+            && let Some(&(remapped_idx, remapped_ty, needs_deref, additional_projections)) =
+                self.field_remapping.get(&idx)
         {
-            let (begin, end) = projection.split_first().unwrap();
-            // FIXME(async_closures): I'm actually a bit surprised to see that we always
-            // initially deref the by-ref upvars. If this is not actually true, then we
-            // will at least get an ICE that explains why this isn't true :^)
-            assert_eq!(*begin, mir::ProjectionElem::Deref);
-            // Peel one ref off of the ty.
-            let peeled_ty = ty.builtin_deref(true).unwrap().ty;
+            let final_deref = if needs_deref {
+                let Some((mir::ProjectionElem::Deref, rest)) = projection.split_first() else {
+                    bug!();
+                };
+                rest
+            } else {
+                projection
+            };
+
+            let additional_projections =
+                additional_projections.iter().map(|elem| match elem.kind {
+                    ProjectionKind::Deref => mir::ProjectionElem::Deref,
+                    ProjectionKind::Field(idx, VariantIdx::ZERO) => {
+                        mir::ProjectionElem::Field(idx, elem.ty)
+                    }
+                    _ => unreachable!("precise captures only through fields and derefs"),
+                });
+
             *place = mir::Place {
                 local: place.local,
                 projection: self.tcx.mk_place_elems_from_iter(
-                    [mir::ProjectionElem::Field(idx, peeled_ty)]
+                    [mir::ProjectionElem::Field(remapped_idx, remapped_ty)]
                         .into_iter()
-                        .chain(end.iter().copied()),
+                        .chain(additional_projections)
+                        .chain(final_deref.iter().copied()),
                 ),
             };
         }
diff --git a/tests/ui/async-await/async-closures/precise-captures.call.run.stdout b/tests/ui/async-await/async-closures/precise-captures.call.run.stdout
new file mode 100644
index 00000000000..6062556837c
--- /dev/null
+++ b/tests/ui/async-await/async-closures/precise-captures.call.run.stdout
@@ -0,0 +1,29 @@
+after call
+after await
+fixed
+uncaptured
+
+after call
+after await
+fixed
+uncaptured
+
+after call
+after await
+fixed
+uncaptured
+
+after call
+after await
+fixed
+untouched
+
+after call
+drop first
+after await
+uncaptured
+
+after call
+drop first
+after await
+uncaptured
diff --git a/tests/ui/async-await/async-closures/precise-captures.call_once.run.stdout b/tests/ui/async-await/async-closures/precise-captures.call_once.run.stdout
new file mode 100644
index 00000000000..ddb02d47600
--- /dev/null
+++ b/tests/ui/async-await/async-closures/precise-captures.call_once.run.stdout
@@ -0,0 +1,29 @@
+after call
+after await
+fixed
+uncaptured
+
+after call
+after await
+fixed
+uncaptured
+
+after call
+fixed
+after await
+uncaptured
+
+after call
+after await
+fixed
+untouched
+
+after call
+drop first
+after await
+uncaptured
+
+after call
+drop first
+after await
+uncaptured
diff --git a/tests/ui/async-await/async-closures/precise-captures.force_once.run.stdout b/tests/ui/async-await/async-closures/precise-captures.force_once.run.stdout
new file mode 100644
index 00000000000..ddb02d47600
--- /dev/null
+++ b/tests/ui/async-await/async-closures/precise-captures.force_once.run.stdout
@@ -0,0 +1,29 @@
+after call
+after await
+fixed
+uncaptured
+
+after call
+after await
+fixed
+uncaptured
+
+after call
+fixed
+after await
+uncaptured
+
+after call
+after await
+fixed
+untouched
+
+after call
+drop first
+after await
+uncaptured
+
+after call
+drop first
+after await
+uncaptured
diff --git a/tests/ui/async-await/async-closures/precise-captures.rs b/tests/ui/async-await/async-closures/precise-captures.rs
new file mode 100644
index 00000000000..e82dd1dbaf0
--- /dev/null
+++ b/tests/ui/async-await/async-closures/precise-captures.rs
@@ -0,0 +1,157 @@
+//@ aux-build:block-on.rs
+//@ edition:2021
+//@ run-pass
+//@ check-run-results
+//@ revisions: call call_once force_once
+
+// call - Call the closure regularly.
+// call_once - Call the closure w/ `async FnOnce`, so exercising the by_move shim.
+// force_once - Force the closure mode to `FnOnce`, so exercising what was fixed
+//   in <https://github.com/rust-lang/rust/pull/123350>.
+
+#![feature(async_closure)]
+#![allow(unused_mut)]
+
+extern crate block_on;
+
+#[cfg(any(call, force_once))]
+macro_rules! call {
+    ($c:expr) => { ($c)() }
+}
+
+#[cfg(call_once)]
+async fn call_once(f: impl async FnOnce()) {
+    f().await
+}
+
+#[cfg(call_once)]
+macro_rules! call {
+    ($c:expr) => { call_once($c) }
+}
+
+#[cfg(not(force_once))]
+macro_rules! guidance {
+    ($c:expr) => { $c }
+}
+
+#[cfg(force_once)]
+fn infer_fnonce(c: impl async FnOnce()) -> impl async FnOnce() { c }
+
+#[cfg(force_once)]
+macro_rules! guidance {
+    ($c:expr) => { infer_fnonce($c) }
+}
+
+#[derive(Debug)]
+struct Drop(&'static str);
+
+impl std::ops::Drop for Drop {
+    fn drop(&mut self) {
+        println!("{}", self.0);
+    }
+}
+
+struct S {
+    a: i32,
+    b: Drop,
+    c: Drop,
+}
+
+async fn async_main() {
+    // Precise capture struct
+    {
+        let mut s = S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
+        let mut c = guidance!(async || {
+            s.a = 2;
+            let w = &mut s.b;
+            w.0 = "fixed";
+        });
+        s.c.0 = "uncaptured";
+        let fut = call!(c);
+        println!("after call");
+        fut.await;
+        println!("after await");
+    }
+    println!();
+
+    // Precise capture &mut struct
+    {
+        let s = &mut S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
+        let mut c = guidance!(async || {
+            s.a = 2;
+            let w = &mut s.b;
+            w.0 = "fixed";
+        });
+        s.c.0 = "uncaptured";
+        let fut = call!(c);
+        println!("after call");
+        fut.await;
+        println!("after await");
+    }
+    println!();
+
+    // Precise capture struct by move
+    {
+        let mut s = S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
+        let mut c = guidance!(async move || {
+            s.a = 2;
+            let w = &mut s.b;
+            w.0 = "fixed";
+        });
+        s.c.0 = "uncaptured";
+        let fut = call!(c);
+        println!("after call");
+        fut.await;
+        println!("after await");
+    }
+    println!();
+
+    // Precise capture &mut struct by move
+    {
+        let s = &mut S { a: 1, b: Drop("fix me up"), c: Drop("untouched") };
+        let mut c = guidance!(async move || {
+            s.a = 2;
+            let w = &mut s.b;
+            w.0 = "fixed";
+        });
+        // `s` is still captured fully as `&mut S`.
+        let fut = call!(c);
+        println!("after call");
+        fut.await;
+        println!("after await");
+    }
+    println!();
+
+    // Precise capture struct, consume field
+    {
+        let mut s = S { a: 1, b: Drop("drop first"), c: Drop("untouched") };
+        let c = guidance!(async move || {
+            // s.a = 2; // FIXME(async_closures): Figure out why this fails
+            drop(s.b);
+        });
+        s.c.0 = "uncaptured";
+        let fut = call!(c);
+        println!("after call");
+        fut.await;
+        println!("after await");
+    }
+    println!();
+
+    // Precise capture struct by move, consume field
+    {
+        let mut s = S { a: 1, b: Drop("drop first"), c: Drop("untouched") };
+        let c = guidance!(async move || {
+            // s.a = 2; // FIXME(async_closures): Figure out why this fails
+            drop(s.b);
+        });
+        s.c.0 = "uncaptured";
+        let fut = call!(c);
+        println!("after call");
+        fut.await;
+        println!("after await");
+    }
+}
+
+fn main() {
+    block_on::block_on(async_main());
+}