about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_hir_typeck/src/upvar.rs66
-rw-r--r--tests/ui/async-await/async-closures/non-copy-arg-does-not-force-inner-move.rs17
2 files changed, 70 insertions, 13 deletions
diff --git a/compiler/rustc_hir_typeck/src/upvar.rs b/compiler/rustc_hir_typeck/src/upvar.rs
index e64f4ebf45d..55f002291f0 100644
--- a/compiler/rustc_hir_typeck/src/upvar.rs
+++ b/compiler/rustc_hir_typeck/src/upvar.rs
@@ -219,7 +219,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         // `async-await/async-closures/force-move-due-to-inferred-kind.rs`.
         //
         // 2. If the coroutine-closure is forced to be `FnOnce` due to the way it
-        // uses its upvars, but not *all* upvars would force the closure to `FnOnce`.
+        // uses its upvars (e.g. it consumes a non-copy value), but not *all* upvars
+        // would force the closure to `FnOnce`.
         // See the test: `async-await/async-closures/force-move-due-to-actually-fnonce.rs`.
         //
         // This would lead to an impossible to satisfy situation, since `AsyncFnOnce`
@@ -227,11 +228,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         // we force the inner coroutine to also be `move`. This only matters for
         // coroutine-closures that are `move` since otherwise they themselves will
         // be borrowing from the outer environment, so there's no self-borrows occuring.
-        //
-        // One *important* note is that we do a call to `process_collected_capture_information`
-        // to eagerly test whether the coroutine would end up `FnOnce`, but we do this
-        // *before* capturing all the closure args by-value below, since that would always
-        // cause the analysis to return `FnOnce`.
         if let UpvarArgs::Coroutine(..) = args
             && let hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure) =
                 self.tcx.coroutine_kind(closure_def_id).expect("coroutine should have kind")
@@ -246,19 +242,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 capture_clause = hir::CaptureBy::Value { move_kw };
             }
             // (2.) The way that the closure uses its upvars means it's `FnOnce`.
-            else if let (_, ty::ClosureKind::FnOnce, _) = self
-                .process_collected_capture_information(
-                    capture_clause,
-                    &delegate.capture_information,
-                )
-            {
+            else if self.coroutine_body_consumes_upvars(closure_def_id, body) {
                 capture_clause = hir::CaptureBy::Value { move_kw };
             }
         }
 
         // As noted in `lower_coroutine_body_with_moved_arguments`, we default the capture mode
         // to `ByRef` for the `async {}` block internal to async fns/closure. This means
-        // that we would *not* be moving all of the parameters into the async block by default.
+        // that we would *not* be moving all of the parameters into the async block in all cases.
+        // For example, when one of the arguments is `Copy`, we turn a consuming use into a copy of
+        // a reference, so for `async fn x(t: i32) {}`, we'd only take a reference to `t`.
         //
         // We force all of these arguments to be captured by move before we do expr use analysis.
         //
@@ -535,6 +528,53 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         }
     }
 
+    /// Determines whether the body of the coroutine uses its upvars in a way that
+    /// consumes (i.e. moves) the value, which would force the coroutine to `FnOnce`.
+    /// In a more detailed comment above, we care whether this happens, since if
+    /// this happens, we want to force the coroutine to move all of the upvars it
+    /// would've borrowed from the parent coroutine-closure.
+    ///
+    /// This only really makes sense to be called on the child coroutine of a
+    /// coroutine-closure.
+    fn coroutine_body_consumes_upvars(
+        &self,
+        coroutine_def_id: LocalDefId,
+        body: &'tcx hir::Body<'tcx>,
+    ) -> bool {
+        // This block contains argument capturing details. Since arguments
+        // aren't upvars, we do not care about them for determining if the
+        // coroutine body actually consumes its upvars.
+        let hir::ExprKind::Block(&hir::Block { expr: Some(body), .. }, None) = body.value.kind
+        else {
+            bug!();
+        };
+        // Specifically, we only care about the *real* body of the coroutine.
+        // We skip out into the drop-temps within the block of the body in order
+        // to skip over the args of the desugaring.
+        let hir::ExprKind::DropTemps(body) = body.kind else {
+            bug!();
+        };
+
+        let mut delegate = InferBorrowKind {
+            closure_def_id: coroutine_def_id,
+            capture_information: Default::default(),
+            fake_reads: Default::default(),
+        };
+
+        let _ = euv::ExprUseVisitor::new(
+            &FnCtxt::new(self, self.tcx.param_env(coroutine_def_id), coroutine_def_id),
+            &mut delegate,
+        )
+        .consume_expr(body);
+
+        let (_, kind, _) = self.process_collected_capture_information(
+            hir::CaptureBy::Ref,
+            &delegate.capture_information,
+        );
+
+        matches!(kind, ty::ClosureKind::FnOnce)
+    }
+
     // Returns a list of `Ty`s for each upvar.
     fn final_upvar_tys(&self, closure_id: LocalDefId) -> Vec<Ty<'tcx>> {
         self.typeck_results
diff --git a/tests/ui/async-await/async-closures/non-copy-arg-does-not-force-inner-move.rs b/tests/ui/async-await/async-closures/non-copy-arg-does-not-force-inner-move.rs
new file mode 100644
index 00000000000..cd9d98d0799
--- /dev/null
+++ b/tests/ui/async-await/async-closures/non-copy-arg-does-not-force-inner-move.rs
@@ -0,0 +1,17 @@
+//@ aux-build:block-on.rs
+//@ edition:2021
+//@ build-pass
+
+#![feature(async_closure)]
+
+extern crate block_on;
+
+fn wrapper(f: impl Fn(String)) -> impl async Fn(String) {
+    async move |s| f(s)
+}
+
+fn main() {
+    block_on::block_on(async {
+        wrapper(|who| println!("Hello, {who}!"))(String::from("world")).await;
+    });
+}