about summary refs log tree commit diff
path: root/compiler/rustc_borrowck
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-02-06 15:04:01 +0000
committerbors <bors@rust-lang.org>2024-02-06 15:04:01 +0000
commit4a2fe4491ea616983a0cf0cbbd145a39768f4e7a (patch)
tree3ee31f8af96390f25ff12e1772aa224ba09a4828 /compiler/rustc_borrowck
parent037f515caf46846d2bffae55883eebc6c1ccb861 (diff)
parented7fca1f8805b4348b801f23f444e0dda42f7aed (diff)
downloadrust-4a2fe4491ea616983a0cf0cbbd145a39768f4e7a.tar.gz
rust-4a2fe4491ea616983a0cf0cbbd145a39768f4e7a.zip
Auto merge of #120361 - compiler-errors:async-closures, r=oli-obk
Rework support for async closures; allow them to return futures that borrow from the closure's captures

This PR implements a new lowering for async closures via `TyKind::CoroutineClosure` which handles the curious relationship between the closure and the coroutine that it returns.

I wrote up a bunch in [this hackmd](https://hackmd.io/`@compiler-errors/S1HvqQxca)` which will be copied to the dev guide after this PR lands, and hopefully left sufficient comments in the source code explaining why this change is as large as it is.

This also necessitates that they begin implementing the `AsyncFn`-family of traits, rather than the `Fn`-family of traits -- if you need `Fn` implementations, you should probably use the non-sugar `|| async {}` syntax instead.

Notably this PR does not yet implement `async Fn()` syntax sugar for bounds, but I expect to add those soon (**edit:** #120392). For now, users must use `AsyncFn()` traits directly, which necessitates adding the `async_fn_traits` feature gate as well. I will add this as a follow-up very soon.

r? oli-obk

This is based on top of #120322, but that PR is minimal.
Diffstat (limited to 'compiler/rustc_borrowck')
-rw-r--r--compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs23
-rw-r--r--compiler/rustc_borrowck/src/diagnostics/mod.rs4
-rw-r--r--compiler/rustc_borrowck/src/diagnostics/mutability_errors.rs2
-rw-r--r--compiler/rustc_borrowck/src/diagnostics/region_name.rs36
-rw-r--r--compiler/rustc_borrowck/src/lib.rs11
-rw-r--r--compiler/rustc_borrowck/src/path_utils.rs2
-rw-r--r--compiler/rustc_borrowck/src/type_check/input_output.rs76
-rw-r--r--compiler/rustc_borrowck/src/type_check/mod.rs32
-rw-r--r--compiler/rustc_borrowck/src/universal_regions.rs69
9 files changed, 220 insertions, 35 deletions
diff --git a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
index b0b7cc076ba..6debb3362b0 100644
--- a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs
@@ -858,7 +858,9 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
             use crate::session_diagnostics::CaptureVarCause::*;
             match kind {
                 hir::ClosureKind::Coroutine(_) => MoveUseInCoroutine { var_span },
-                hir::ClosureKind::Closure => MoveUseInClosure { var_span },
+                hir::ClosureKind::Closure | hir::ClosureKind::CoroutineClosure(_) => {
+                    MoveUseInClosure { var_span }
+                }
             }
         });
 
@@ -905,7 +907,7 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
                 hir::ClosureKind::Coroutine(_) => {
                     BorrowUsePlaceCoroutine { place: desc_place, var_span, is_single_var: true }
                 }
-                hir::ClosureKind::Closure => {
+                hir::ClosureKind::Closure | hir::ClosureKind::CoroutineClosure(_) => {
                     BorrowUsePlaceClosure { place: desc_place, var_span, is_single_var: true }
                 }
             }
@@ -1056,7 +1058,8 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
                                     var_span,
                                     is_single_var: true,
                                 },
-                                hir::ClosureKind::Closure => BorrowUsePlaceClosure {
+                                hir::ClosureKind::Closure
+                                | hir::ClosureKind::CoroutineClosure(_) => BorrowUsePlaceClosure {
                                     place: desc_place,
                                     var_span,
                                     is_single_var: true,
@@ -1140,7 +1143,7 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
                         var_span,
                         is_single_var: false,
                     },
-                    hir::ClosureKind::Closure => {
+                    hir::ClosureKind::Closure | hir::ClosureKind::CoroutineClosure(_) => {
                         BorrowUsePlaceClosure { place: desc_place, var_span, is_single_var: false }
                     }
                 }
@@ -1158,7 +1161,7 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
                         hir::ClosureKind::Coroutine(_) => {
                             FirstBorrowUsePlaceCoroutine { place: borrow_place_desc, var_span }
                         }
-                        hir::ClosureKind::Closure => {
+                        hir::ClosureKind::Closure | hir::ClosureKind::CoroutineClosure(_) => {
                             FirstBorrowUsePlaceClosure { place: borrow_place_desc, var_span }
                         }
                     }
@@ -1175,7 +1178,7 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
                         hir::ClosureKind::Coroutine(_) => {
                             SecondBorrowUsePlaceCoroutine { place: desc_place, var_span }
                         }
-                        hir::ClosureKind::Closure => {
+                        hir::ClosureKind::Closure | hir::ClosureKind::CoroutineClosure(_) => {
                             SecondBorrowUsePlaceClosure { place: desc_place, var_span }
                         }
                     }
@@ -2942,7 +2945,9 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
                     use crate::session_diagnostics::CaptureVarCause::*;
                     match kind {
                         hir::ClosureKind::Coroutine(_) => BorrowUseInCoroutine { var_span },
-                        hir::ClosureKind::Closure => BorrowUseInClosure { var_span },
+                        hir::ClosureKind::Closure | hir::ClosureKind::CoroutineClosure(_) => {
+                            BorrowUseInClosure { var_span }
+                        }
                     }
                 });
 
@@ -2958,7 +2963,9 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
             use crate::session_diagnostics::CaptureVarCause::*;
             match kind {
                 hir::ClosureKind::Coroutine(_) => BorrowUseInCoroutine { var_span },
-                hir::ClosureKind::Closure => BorrowUseInClosure { var_span },
+                hir::ClosureKind::Closure | hir::ClosureKind::CoroutineClosure(_) => {
+                    BorrowUseInClosure { var_span }
+                }
             }
         });
 
diff --git a/compiler/rustc_borrowck/src/diagnostics/mod.rs b/compiler/rustc_borrowck/src/diagnostics/mod.rs
index b35d4e16ecc..59f3aa706ed 100644
--- a/compiler/rustc_borrowck/src/diagnostics/mod.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/mod.rs
@@ -614,7 +614,7 @@ impl UseSpans<'_> {
                         PartialAssignment => AssignPartInCoroutine { path_span },
                     });
                 }
-                hir::ClosureKind::Closure => {
+                hir::ClosureKind::Closure | hir::ClosureKind::CoroutineClosure(_) => {
                     err.subdiagnostic(match action {
                         Borrow => BorrowInClosure { path_span },
                         MatchOn | Use => UseInClosure { path_span },
@@ -1253,7 +1253,7 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
                     hir::ClosureKind::Coroutine(_) => {
                         CaptureVarCause::PartialMoveUseInCoroutine { var_span, is_partial }
                     }
-                    hir::ClosureKind::Closure => {
+                    hir::ClosureKind::Closure | hir::ClosureKind::CoroutineClosure(_) => {
                         CaptureVarCause::PartialMoveUseInClosure { var_span, is_partial }
                     }
                 })
diff --git a/compiler/rustc_borrowck/src/diagnostics/mutability_errors.rs b/compiler/rustc_borrowck/src/diagnostics/mutability_errors.rs
index 3fddf67f55b..55649ec2f16 100644
--- a/compiler/rustc_borrowck/src/diagnostics/mutability_errors.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/mutability_errors.rs
@@ -1472,7 +1472,7 @@ fn suggest_ampmut<'tcx>(
 }
 
 fn is_closure_or_coroutine(ty: Ty<'_>) -> bool {
-    ty.is_closure() || ty.is_coroutine()
+    ty.is_closure() || ty.is_coroutine() || ty.is_coroutine_closure()
 }
 
 /// Given a field that needs to be mutable, returns a span where the " mut " could go.
diff --git a/compiler/rustc_borrowck/src/diagnostics/region_name.rs b/compiler/rustc_borrowck/src/diagnostics/region_name.rs
index 15e1066e983..c91b8eaab05 100644
--- a/compiler/rustc_borrowck/src/diagnostics/region_name.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/region_name.rs
@@ -324,9 +324,13 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
                 ty::BoundRegionKind::BrEnv => {
                     let def_ty = self.regioncx.universal_regions().defining_ty;
 
-                    let DefiningTy::Closure(_, args) = def_ty else {
-                        // Can't have BrEnv in functions, constants or coroutines.
-                        bug!("BrEnv outside of closure.");
+                    let closure_kind = match def_ty {
+                        DefiningTy::Closure(_, args) => args.as_closure().kind(),
+                        DefiningTy::CoroutineClosure(_, args) => args.as_coroutine_closure().kind(),
+                        _ => {
+                            // Can't have BrEnv in functions, constants or coroutines.
+                            bug!("BrEnv outside of closure.");
+                        }
                     };
                     let hir::ExprKind::Closure(&hir::Closure { fn_decl_span, .. }) =
                         tcx.hir().expect_expr(self.mir_hir_id()).kind
@@ -334,21 +338,18 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
                         bug!("Closure is not defined by a closure expr");
                     };
                     let region_name = self.synthesize_region_name();
-
-                    let closure_kind_ty = args.as_closure().kind_ty();
-                    let note = match closure_kind_ty.to_opt_closure_kind() {
-                        Some(ty::ClosureKind::Fn) => {
+                    let note = match closure_kind {
+                        ty::ClosureKind::Fn => {
                             "closure implements `Fn`, so references to captured variables \
                              can't escape the closure"
                         }
-                        Some(ty::ClosureKind::FnMut) => {
+                        ty::ClosureKind::FnMut => {
                             "closure implements `FnMut`, so references to captured variables \
                              can't escape the closure"
                         }
-                        Some(ty::ClosureKind::FnOnce) => {
+                        ty::ClosureKind::FnOnce => {
                             bug!("BrEnv in a `FnOnce` closure");
                         }
-                        None => bug!("Closure kind not inferred in borrow check"),
                     };
 
                     Some(RegionName {
@@ -692,7 +693,10 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
                     hir::ClosureKind::Coroutine(hir::CoroutineKind::Desugared(
                         hir::CoroutineDesugaring::Async,
                         hir::CoroutineSource::Closure,
-                    )) => " of async closure",
+                    ))
+                    | hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => {
+                        " of async closure"
+                    }
 
                     hir::ClosureKind::Coroutine(hir::CoroutineKind::Desugared(
                         hir::CoroutineDesugaring::Async,
@@ -719,7 +723,10 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
                     hir::ClosureKind::Coroutine(hir::CoroutineKind::Desugared(
                         hir::CoroutineDesugaring::Gen,
                         hir::CoroutineSource::Closure,
-                    )) => " of gen closure",
+                    ))
+                    | hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Gen) => {
+                        " of gen closure"
+                    }
 
                     hir::ClosureKind::Coroutine(hir::CoroutineKind::Desugared(
                         hir::CoroutineDesugaring::Gen,
@@ -743,7 +750,10 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
                     hir::ClosureKind::Coroutine(hir::CoroutineKind::Desugared(
                         hir::CoroutineDesugaring::AsyncGen,
                         hir::CoroutineSource::Closure,
-                    )) => " of async gen closure",
+                    ))
+                    | hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::AsyncGen) => {
+                        " of async gen closure"
+                    }
 
                     hir::ClosureKind::Coroutine(hir::CoroutineKind::Desugared(
                         hir::CoroutineDesugaring::AsyncGen,
diff --git a/compiler/rustc_borrowck/src/lib.rs b/compiler/rustc_borrowck/src/lib.rs
index 8b5e548345c..bb64571889b 100644
--- a/compiler/rustc_borrowck/src/lib.rs
+++ b/compiler/rustc_borrowck/src/lib.rs
@@ -3,6 +3,7 @@
 #![allow(internal_features)]
 #![feature(rustdoc_internals)]
 #![doc(rust_logo)]
+#![feature(assert_matches)]
 #![feature(associated_type_bounds)]
 #![feature(box_patterns)]
 #![feature(let_chains)]
@@ -1303,7 +1304,9 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
                 // moved into the closure and subsequently used by the closure,
                 // in order to populate our used_mut set.
                 match **aggregate_kind {
-                    AggregateKind::Closure(def_id, _) | AggregateKind::Coroutine(def_id, _) => {
+                    AggregateKind::Closure(def_id, _)
+                    | AggregateKind::CoroutineClosure(def_id, _)
+                    | AggregateKind::Coroutine(def_id, _) => {
                         let def_id = def_id.expect_local();
                         let BorrowCheckResult { used_mut_upvars, .. } =
                             self.infcx.tcx.mir_borrowck(def_id);
@@ -1609,6 +1612,7 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
                     | ty::FnPtr(_)
                     | ty::Dynamic(_, _, _)
                     | ty::Closure(_, _)
+                    | ty::CoroutineClosure(_, _)
                     | ty::Coroutine(_, _)
                     | ty::CoroutineWitness(..)
                     | ty::Never
@@ -1633,7 +1637,10 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> {
                             return;
                         }
                     }
-                    ty::Closure(_, _) | ty::Coroutine(_, _) | ty::Tuple(_) => (),
+                    ty::Closure(..)
+                    | ty::CoroutineClosure(..)
+                    | ty::Coroutine(_, _)
+                    | ty::Tuple(_) => (),
                     ty::Bool
                     | ty::Char
                     | ty::Int(_)
diff --git a/compiler/rustc_borrowck/src/path_utils.rs b/compiler/rustc_borrowck/src/path_utils.rs
index 2d997dfadf0..4cfde47664e 100644
--- a/compiler/rustc_borrowck/src/path_utils.rs
+++ b/compiler/rustc_borrowck/src/path_utils.rs
@@ -164,7 +164,7 @@ pub(crate) fn is_upvar_field_projection<'tcx>(
     match place_ref.last_projection() {
         Some((place_base, ProjectionElem::Field(field, _ty))) => {
             let base_ty = place_base.ty(body, tcx).ty;
-            if (base_ty.is_closure() || base_ty.is_coroutine())
+            if (base_ty.is_closure() || base_ty.is_coroutine() || base_ty.is_coroutine_closure())
                 && (!by_ref || upvars[field.index()].is_by_ref())
             {
                 Some(field)
diff --git a/compiler/rustc_borrowck/src/type_check/input_output.rs b/compiler/rustc_borrowck/src/type_check/input_output.rs
index 59518f68ab1..a5a7ce4ea3e 100644
--- a/compiler/rustc_borrowck/src/type_check/input_output.rs
+++ b/compiler/rustc_borrowck/src/type_check/input_output.rs
@@ -7,13 +7,18 @@
 //! `RETURN_PLACE` the MIR arguments) are always fully normalized (and
 //! contain revealed `impl Trait` values).
 
+use std::assert_matches::assert_matches;
+
 use itertools::Itertools;
-use rustc_infer::infer::BoundRegionConversionTime;
+use rustc_hir as hir;
+use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
+use rustc_infer::infer::{BoundRegionConversionTime, RegionVariableOrigin};
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, Ty};
 use rustc_span::Span;
 
-use crate::universal_regions::UniversalRegions;
+use crate::renumber::RegionCtxt;
+use crate::universal_regions::{DefiningTy, UniversalRegions};
 
 use super::{Locations, TypeChecker};
 
@@ -23,9 +28,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
     #[instrument(skip(self, body), level = "debug")]
     pub(super) fn check_signature_annotation(&mut self, body: &Body<'tcx>) {
         let mir_def_id = body.source.def_id().expect_local();
+
         if !self.tcx().is_closure_or_coroutine(mir_def_id.to_def_id()) {
             return;
         }
+
         let user_provided_poly_sig = self.tcx().closure_user_provided_sig(mir_def_id);
 
         // Instantiate the canonicalized variables from user-provided signature
@@ -34,12 +41,75 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
         // so that they represent the view from "inside" the closure.
         let user_provided_sig = self
             .instantiate_canonical_with_fresh_inference_vars(body.span, &user_provided_poly_sig);
-        let user_provided_sig = self.infcx.instantiate_binder_with_fresh_vars(
+        let mut user_provided_sig = self.infcx.instantiate_binder_with_fresh_vars(
             body.span,
             BoundRegionConversionTime::FnCall,
             user_provided_sig,
         );
 
+        // FIXME(async_closures): It's kind of wacky that we must apply this
+        // transformation here, since we do the same thing in HIR typeck.
+        // Maybe we could just fix up the canonicalized signature during HIR typeck?
+        if let DefiningTy::CoroutineClosure(_, args) =
+            self.borrowck_context.universal_regions.defining_ty
+        {
+            assert_matches!(
+                self.tcx().coroutine_kind(self.tcx().coroutine_for_closure(mir_def_id)),
+                Some(hir::CoroutineKind::Desugared(
+                    hir::CoroutineDesugaring::Async,
+                    hir::CoroutineSource::Closure
+                )),
+                "this needs to be modified if we're lowering non-async closures"
+            );
+            // Make sure to use the args from `DefiningTy` so the right NLL region vids are prepopulated
+            // into the type.
+            let args = args.as_coroutine_closure();
+            let tupled_upvars_ty = ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind(
+                self.tcx(),
+                args.kind(),
+                Ty::new_tup(self.tcx(), user_provided_sig.inputs()),
+                args.tupled_upvars_ty(),
+                args.coroutine_captures_by_ref_ty(),
+                self.infcx.next_region_var(RegionVariableOrigin::MiscVariable(body.span), || {
+                    RegionCtxt::Unknown
+                }),
+            );
+
+            let next_ty_var = || {
+                self.infcx.next_ty_var(TypeVariableOrigin {
+                    span: body.span,
+                    kind: TypeVariableOriginKind::MiscVariable,
+                })
+            };
+            let output_ty = Ty::new_coroutine(
+                self.tcx(),
+                self.tcx().coroutine_for_closure(mir_def_id),
+                ty::CoroutineArgs::new(
+                    self.tcx(),
+                    ty::CoroutineArgsParts {
+                        parent_args: args.parent_args(),
+                        kind_ty: Ty::from_closure_kind(self.tcx(), args.kind()),
+                        return_ty: user_provided_sig.output(),
+                        tupled_upvars_ty,
+                        // For async closures, none of these can be annotated, so just fill
+                        // them with fresh ty vars.
+                        resume_ty: next_ty_var(),
+                        yield_ty: next_ty_var(),
+                        witness: next_ty_var(),
+                    },
+                )
+                .args,
+            );
+
+            user_provided_sig = self.tcx().mk_fn_sig(
+                user_provided_sig.inputs().iter().copied(),
+                output_ty,
+                user_provided_sig.c_variadic,
+                user_provided_sig.unsafety,
+                user_provided_sig.abi,
+            );
+        }
+
         let is_coroutine_with_implicit_resume_ty = self.tcx().is_coroutine(mir_def_id.to_def_id())
             && user_provided_sig.inputs().is_empty();
 
diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs
index 59c4d9a6c78..cdb187e0776 100644
--- a/compiler/rustc_borrowck/src/type_check/mod.rs
+++ b/compiler/rustc_borrowck/src/type_check/mod.rs
@@ -808,6 +808,14 @@ impl<'a, 'b, 'tcx> TypeVerifier<'a, 'b, 'tcx> {
                         }),
                     };
                 }
+                ty::CoroutineClosure(_, args) => {
+                    return match args.as_coroutine_closure().upvar_tys().get(field.index()) {
+                        Some(&ty) => Ok(ty),
+                        None => Err(FieldAccessError::OutOfRange {
+                            field_count: args.as_coroutine_closure().upvar_tys().len(),
+                        }),
+                    };
+                }
                 ty::Coroutine(_, args) => {
                     // Only prefix fields (upvars and current state) are
                     // accessible without a variant index.
@@ -1875,6 +1883,14 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                     }),
                 }
             }
+            AggregateKind::CoroutineClosure(_, args) => {
+                match args.as_coroutine_closure().upvar_tys().get(field_index.as_usize()) {
+                    Some(ty) => Ok(*ty),
+                    None => Err(FieldAccessError::OutOfRange {
+                        field_count: args.as_coroutine_closure().upvar_tys().len(),
+                    }),
+                }
+            }
             AggregateKind::Array(ty) => Ok(ty),
             AggregateKind::Tuple => {
                 unreachable!("This should have been covered in check_rvalues");
@@ -2478,6 +2494,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                 AggregateKind::Tuple => None,
                 AggregateKind::Closure(_, _) => None,
                 AggregateKind::Coroutine(_, _) => None,
+                AggregateKind::CoroutineClosure(_, _) => None,
             },
         }
     }
@@ -2705,7 +2722,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
             // desugaring. A closure gets desugared to a struct, and
             // these extra requirements are basically like where
             // clauses on the struct.
-            AggregateKind::Closure(def_id, args) | AggregateKind::Coroutine(def_id, args) => (
+            AggregateKind::Closure(def_id, args)
+            | AggregateKind::CoroutineClosure(def_id, args)
+            | AggregateKind::Coroutine(def_id, args) => (
                 def_id,
                 self.prove_closure_bounds(
                     tcx,
@@ -2754,10 +2773,15 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
         let typeck_root_args = ty::GenericArgs::identity_for_item(tcx, typeck_root_def_id);
 
         let parent_args = match tcx.def_kind(def_id) {
-            DefKind::Closure if tcx.is_coroutine(def_id.to_def_id()) => {
-                args.as_coroutine().parent_args()
+            // We don't want to dispatch on 3 different kind of closures here, so take
+            // advantage of the fact that the `parent_args` is the same length as the
+            // `typeck_root_args`.
+            DefKind::Closure => {
+                // FIXME(async_closures): It may be useful to add a debug assert here
+                // to actually call `type_of` and check the `parent_args` are the same
+                // length as the `typeck_root_args`.
+                &args[..typeck_root_args.len()]
             }
-            DefKind::Closure => args.as_closure().parent_args(),
             DefKind::InlineConst => args.as_inline_const().parent_args(),
             other => bug!("unexpected item {:?}", other),
         };
diff --git a/compiler/rustc_borrowck/src/universal_regions.rs b/compiler/rustc_borrowck/src/universal_regions.rs
index ae8a135f090..111eaaf60f7 100644
--- a/compiler/rustc_borrowck/src/universal_regions.rs
+++ b/compiler/rustc_borrowck/src/universal_regions.rs
@@ -97,6 +97,13 @@ pub enum DefiningTy<'tcx> {
     /// `ClosureArgs::coroutine_return_ty`.
     Coroutine(DefId, GenericArgsRef<'tcx>),
 
+    /// The MIR is a special kind of closure that returns coroutines.
+    ///
+    /// See the documentation on `CoroutineClosureSignature` for details
+    /// on how to construct the callable signature of the coroutine from
+    /// its args.
+    CoroutineClosure(DefId, GenericArgsRef<'tcx>),
+
     /// The MIR is a fn item with the given `DefId` and args. The signature
     /// of the function can be bound then with the `fn_sig` query.
     FnDef(DefId, GenericArgsRef<'tcx>),
@@ -119,6 +126,7 @@ impl<'tcx> DefiningTy<'tcx> {
     pub fn upvar_tys(self) -> &'tcx ty::List<Ty<'tcx>> {
         match self {
             DefiningTy::Closure(_, args) => args.as_closure().upvar_tys(),
+            DefiningTy::CoroutineClosure(_, args) => args.as_coroutine_closure().upvar_tys(),
             DefiningTy::Coroutine(_, args) => args.as_coroutine().upvar_tys(),
             DefiningTy::FnDef(..) | DefiningTy::Const(..) | DefiningTy::InlineConst(..) => {
                 ty::List::empty()
@@ -131,7 +139,9 @@ impl<'tcx> DefiningTy<'tcx> {
     /// user's code.
     pub fn implicit_inputs(self) -> usize {
         match self {
-            DefiningTy::Closure(..) | DefiningTy::Coroutine(..) => 1,
+            DefiningTy::Closure(..)
+            | DefiningTy::CoroutineClosure(..)
+            | DefiningTy::Coroutine(..) => 1,
             DefiningTy::FnDef(..) | DefiningTy::Const(..) | DefiningTy::InlineConst(..) => 0,
         }
     }
@@ -147,6 +157,7 @@ impl<'tcx> DefiningTy<'tcx> {
     pub fn def_id(&self) -> DefId {
         match *self {
             DefiningTy::Closure(def_id, ..)
+            | DefiningTy::CoroutineClosure(def_id, ..)
             | DefiningTy::Coroutine(def_id, ..)
             | DefiningTy::FnDef(def_id, ..)
             | DefiningTy::Const(def_id, ..)
@@ -355,6 +366,9 @@ impl<'tcx> UniversalRegions<'tcx> {
                     err.note(format!("late-bound region is {:?}", self.to_region_vid(r)));
                 });
             }
+            DefiningTy::CoroutineClosure(..) => {
+                todo!()
+            }
             DefiningTy::Coroutine(def_id, args) => {
                 let v = with_no_trimmed_paths!(
                     args[tcx.generics_of(def_id).parent_count..]
@@ -568,6 +582,9 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
                 match *defining_ty.kind() {
                     ty::Closure(def_id, args) => DefiningTy::Closure(def_id, args),
                     ty::Coroutine(def_id, args) => DefiningTy::Coroutine(def_id, args),
+                    ty::CoroutineClosure(def_id, args) => {
+                        DefiningTy::CoroutineClosure(def_id, args)
+                    }
                     ty::FnDef(def_id, args) => DefiningTy::FnDef(def_id, args),
                     _ => span_bug!(
                         tcx.def_span(self.mir_def),
@@ -623,6 +640,7 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
         let identity_args = GenericArgs::identity_for_item(tcx, typeck_root_def_id);
         let fr_args = match defining_ty {
             DefiningTy::Closure(_, args)
+            | DefiningTy::CoroutineClosure(_, args)
             | DefiningTy::Coroutine(_, args)
             | DefiningTy::InlineConst(_, args) => {
                 // In the case of closures, we rely on the fact that
@@ -702,6 +720,55 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
                 ty::Binder::dummy(inputs_and_output)
             }
 
+            // Construct the signature of the CoroutineClosure for the purposes of borrowck.
+            // This is pretty straightforward -- we:
+            // 1. first grab the `coroutine_closure_sig`,
+            // 2. compute the self type (`&`/`&mut`/no borrow),
+            // 3. flatten the tupled_input_tys,
+            // 4. construct the correct generator type to return with
+            //    `CoroutineClosureSignature::to_coroutine_given_kind_and_upvars`.
+            // Then we wrap it all up into a list of inputs and output.
+            DefiningTy::CoroutineClosure(def_id, args) => {
+                assert_eq!(self.mir_def.to_def_id(), def_id);
+                let closure_sig = args.as_coroutine_closure().coroutine_closure_sig();
+                let bound_vars = tcx.mk_bound_variable_kinds_from_iter(
+                    closure_sig
+                        .bound_vars()
+                        .iter()
+                        .chain(iter::once(ty::BoundVariableKind::Region(ty::BrEnv))),
+                );
+                let br = ty::BoundRegion {
+                    var: ty::BoundVar::from_usize(bound_vars.len() - 1),
+                    kind: ty::BrEnv,
+                };
+                let env_region = ty::Region::new_bound(tcx, ty::INNERMOST, br);
+                let closure_kind = args.as_coroutine_closure().kind();
+
+                let closure_ty = tcx.closure_env_ty(
+                    Ty::new_coroutine_closure(tcx, def_id, args),
+                    closure_kind,
+                    env_region,
+                );
+
+                let inputs = closure_sig.skip_binder().tupled_inputs_ty.tuple_fields();
+                let output = closure_sig.skip_binder().to_coroutine_given_kind_and_upvars(
+                    tcx,
+                    args.as_coroutine_closure().parent_args(),
+                    tcx.coroutine_for_closure(def_id),
+                    closure_kind,
+                    env_region,
+                    args.as_coroutine_closure().tupled_upvars_ty(),
+                    args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
+                );
+
+                ty::Binder::bind_with_vars(
+                    tcx.mk_type_list_from_iter(
+                        iter::once(closure_ty).chain(inputs).chain(iter::once(output)),
+                    ),
+                    bound_vars,
+                )
+            }
+
             DefiningTy::FnDef(def_id, _) => {
                 let sig = tcx.fn_sig(def_id).instantiate_identity();
                 let sig = indices.fold_to_region_vids(tcx, sig);