about summary refs log tree commit diff
path: root/compiler/rustc_borrowck/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_borrowck/src')
-rw-r--r--compiler/rustc_borrowck/src/diagnostics/region_name.rs21
-rw-r--r--compiler/rustc_borrowck/src/lib.rs1
-rw-r--r--compiler/rustc_borrowck/src/type_check/input_output.rs70
-rw-r--r--compiler/rustc_borrowck/src/type_check/mod.rs15
-rw-r--r--compiler/rustc_borrowck/src/universal_regions.rs58
5 files changed, 143 insertions, 22 deletions
diff --git a/compiler/rustc_borrowck/src/diagnostics/region_name.rs b/compiler/rustc_borrowck/src/diagnostics/region_name.rs
index fc52306e10d..c91b8eaab05 100644
--- a/compiler/rustc_borrowck/src/diagnostics/region_name.rs
+++ b/compiler/rustc_borrowck/src/diagnostics/region_name.rs
@@ -324,9 +324,13 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
                 ty::BoundRegionKind::BrEnv => {
                     let def_ty = self.regioncx.universal_regions().defining_ty;
 
-                    let DefiningTy::Closure(_, args) = def_ty else {
-                        // Can't have BrEnv in functions, constants or coroutines.
-                        bug!("BrEnv outside of closure.");
+                    let closure_kind = match def_ty {
+                        DefiningTy::Closure(_, args) => args.as_closure().kind(),
+                        DefiningTy::CoroutineClosure(_, args) => args.as_coroutine_closure().kind(),
+                        _ => {
+                            // Can't have BrEnv in functions, constants or coroutines.
+                            bug!("BrEnv outside of closure.");
+                        }
                     };
                     let hir::ExprKind::Closure(&hir::Closure { fn_decl_span, .. }) =
                         tcx.hir().expect_expr(self.mir_hir_id()).kind
@@ -334,21 +338,18 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
                         bug!("Closure is not defined by a closure expr");
                     };
                     let region_name = self.synthesize_region_name();
-
-                    let closure_kind_ty = args.as_closure().kind_ty();
-                    let note = match closure_kind_ty.to_opt_closure_kind() {
-                        Some(ty::ClosureKind::Fn) => {
+                    let note = match closure_kind {
+                        ty::ClosureKind::Fn => {
                             "closure implements `Fn`, so references to captured variables \
                              can't escape the closure"
                         }
-                        Some(ty::ClosureKind::FnMut) => {
+                        ty::ClosureKind::FnMut => {
                             "closure implements `FnMut`, so references to captured variables \
                              can't escape the closure"
                         }
-                        Some(ty::ClosureKind::FnOnce) => {
+                        ty::ClosureKind::FnOnce => {
                             bug!("BrEnv in a `FnOnce` closure");
                         }
-                        None => bug!("Closure kind not inferred in borrow check"),
                     };
 
                     Some(RegionName {
diff --git a/compiler/rustc_borrowck/src/lib.rs b/compiler/rustc_borrowck/src/lib.rs
index eaf9fc45837..bb64571889b 100644
--- a/compiler/rustc_borrowck/src/lib.rs
+++ b/compiler/rustc_borrowck/src/lib.rs
@@ -3,6 +3,7 @@
 #![allow(internal_features)]
 #![feature(rustdoc_internals)]
 #![doc(rust_logo)]
+#![feature(assert_matches)]
 #![feature(associated_type_bounds)]
 #![feature(box_patterns)]
 #![feature(let_chains)]
diff --git a/compiler/rustc_borrowck/src/type_check/input_output.rs b/compiler/rustc_borrowck/src/type_check/input_output.rs
index 59518f68ab1..a3e5088ee09 100644
--- a/compiler/rustc_borrowck/src/type_check/input_output.rs
+++ b/compiler/rustc_borrowck/src/type_check/input_output.rs
@@ -7,13 +7,18 @@
 //! `RETURN_PLACE` the MIR arguments) are always fully normalized (and
 //! contain revealed `impl Trait` values).
 
+use std::assert_matches::assert_matches;
+
 use itertools::Itertools;
-use rustc_infer::infer::BoundRegionConversionTime;
+use rustc_hir as hir;
+use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
+use rustc_infer::infer::{BoundRegionConversionTime, RegionVariableOrigin};
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, Ty};
 use rustc_span::Span;
 
-use crate::universal_regions::UniversalRegions;
+use crate::renumber::RegionCtxt;
+use crate::universal_regions::{DefiningTy, UniversalRegions};
 
 use super::{Locations, TypeChecker};
 
@@ -23,9 +28,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
     #[instrument(skip(self, body), level = "debug")]
     pub(super) fn check_signature_annotation(&mut self, body: &Body<'tcx>) {
         let mir_def_id = body.source.def_id().expect_local();
+
         if !self.tcx().is_closure_or_coroutine(mir_def_id.to_def_id()) {
             return;
         }
+
         let user_provided_poly_sig = self.tcx().closure_user_provided_sig(mir_def_id);
 
         // Instantiate the canonicalized variables from user-provided signature
@@ -34,12 +41,69 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
         // so that they represent the view from "inside" the closure.
         let user_provided_sig = self
             .instantiate_canonical_with_fresh_inference_vars(body.span, &user_provided_poly_sig);
-        let user_provided_sig = self.infcx.instantiate_binder_with_fresh_vars(
+        let mut user_provided_sig = self.infcx.instantiate_binder_with_fresh_vars(
             body.span,
             BoundRegionConversionTime::FnCall,
             user_provided_sig,
         );
 
+        // FIXME(async_closures): We must apply the same transformation to our
+        // signature here as we do during closure checking.
+        if let DefiningTy::CoroutineClosure(_, args) =
+            self.borrowck_context.universal_regions.defining_ty
+        {
+            assert_matches!(
+                self.tcx().coroutine_kind(self.tcx().coroutine_for_closure(mir_def_id)),
+                Some(hir::CoroutineKind::Desugared(
+                    hir::CoroutineDesugaring::Async,
+                    hir::CoroutineSource::Closure
+                )),
+                "this needs to be modified if we're lowering non-async closures"
+            );
+            let args = args.as_coroutine_closure();
+            let tupled_upvars_ty = ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind(
+                self.tcx(),
+                args.kind(),
+                Ty::new_tup(self.tcx(), user_provided_sig.inputs()),
+                args.tupled_upvars_ty(),
+                args.coroutine_captures_by_ref_ty(),
+                self.infcx.next_region_var(RegionVariableOrigin::MiscVariable(body.span), || {
+                    RegionCtxt::Unknown
+                }),
+            );
+
+            let next_ty_var = || {
+                self.infcx.next_ty_var(TypeVariableOrigin {
+                    span: body.span,
+                    kind: TypeVariableOriginKind::MiscVariable,
+                })
+            };
+            let output_ty = Ty::new_coroutine(
+                self.tcx(),
+                self.tcx().coroutine_for_closure(mir_def_id),
+                ty::CoroutineArgs::new(
+                    self.tcx(),
+                    ty::CoroutineArgsParts {
+                        parent_args: args.parent_args(),
+                        resume_ty: next_ty_var(),
+                        yield_ty: next_ty_var(),
+                        witness: next_ty_var(),
+                        return_ty: user_provided_sig.output(),
+                        tupled_upvars_ty: tupled_upvars_ty,
+                    },
+                )
+                .args,
+            );
+
+            user_provided_sig = self.tcx().mk_fn_sig(
+                user_provided_sig.inputs().iter().copied(),
+                output_ty,
+                user_provided_sig.c_variadic,
+                user_provided_sig.unsafety,
+                user_provided_sig.abi,
+            );
+        }
+
         let is_coroutine_with_implicit_resume_ty = self.tcx().is_coroutine(mir_def_id.to_def_id())
             && user_provided_sig.inputs().is_empty();
 
diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs
index 90a9844fda9..cdb187e0776 100644
--- a/compiler/rustc_borrowck/src/type_check/mod.rs
+++ b/compiler/rustc_borrowck/src/type_check/mod.rs
@@ -2773,15 +2773,14 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
         let typeck_root_args = ty::GenericArgs::identity_for_item(tcx, typeck_root_def_id);
 
         let parent_args = match tcx.def_kind(def_id) {
+            // We don't want to dispatch on 3 different kind of closures here, so take
+            // advantage of the fact that the `parent_args` is the same length as the
+            // `typeck_root_args`.
             DefKind::Closure => {
-                // FIXME(async_closures): It's kind of icky to access HIR here.
-                match tcx.hir_node_by_def_id(def_id).expect_closure().kind {
-                    hir::ClosureKind::Closure => args.as_closure().parent_args(),
-                    hir::ClosureKind::Coroutine(_) => args.as_coroutine().parent_args(),
-                    hir::ClosureKind::CoroutineClosure(_) => {
-                        args.as_coroutine_closure().parent_args()
-                    }
-                }
+                // FIXME(async_closures): It may be useful to add a debug assert here
+                // to actually call `type_of` and check the `parent_args` are the same
+                // length as the `typeck_root_args`.
+                &args[..typeck_root_args.len()]
             }
             DefKind::InlineConst => args.as_inline_const().parent_args(),
             other => bug!("unexpected item {:?}", other),
diff --git a/compiler/rustc_borrowck/src/universal_regions.rs b/compiler/rustc_borrowck/src/universal_regions.rs
index ae8a135f090..9c266d123b3 100644
--- a/compiler/rustc_borrowck/src/universal_regions.rs
+++ b/compiler/rustc_borrowck/src/universal_regions.rs
@@ -97,6 +97,10 @@ pub enum DefiningTy<'tcx> {
     /// `ClosureArgs::coroutine_return_ty`.
     Coroutine(DefId, GenericArgsRef<'tcx>),
 
+    /// The MIR is a special kind of closure that returns coroutines.
+    /// TODO: describe how to make the sig...
+    CoroutineClosure(DefId, GenericArgsRef<'tcx>),
+
     /// The MIR is a fn item with the given `DefId` and args. The signature
     /// of the function can be bound then with the `fn_sig` query.
     FnDef(DefId, GenericArgsRef<'tcx>),
@@ -119,6 +123,7 @@ impl<'tcx> DefiningTy<'tcx> {
     pub fn upvar_tys(self) -> &'tcx ty::List<Ty<'tcx>> {
         match self {
             DefiningTy::Closure(_, args) => args.as_closure().upvar_tys(),
+            DefiningTy::CoroutineClosure(_, args) => args.as_coroutine_closure().upvar_tys(),
             DefiningTy::Coroutine(_, args) => args.as_coroutine().upvar_tys(),
             DefiningTy::FnDef(..) | DefiningTy::Const(..) | DefiningTy::InlineConst(..) => {
                 ty::List::empty()
@@ -131,7 +136,9 @@ impl<'tcx> DefiningTy<'tcx> {
     /// user's code.
     pub fn implicit_inputs(self) -> usize {
         match self {
-            DefiningTy::Closure(..) | DefiningTy::Coroutine(..) => 1,
+            DefiningTy::Closure(..)
+            | DefiningTy::CoroutineClosure(..)
+            | DefiningTy::Coroutine(..) => 1,
             DefiningTy::FnDef(..) | DefiningTy::Const(..) | DefiningTy::InlineConst(..) => 0,
         }
     }
@@ -147,6 +154,7 @@ impl<'tcx> DefiningTy<'tcx> {
     pub fn def_id(&self) -> DefId {
         match *self {
             DefiningTy::Closure(def_id, ..)
+            | DefiningTy::CoroutineClosure(def_id, ..)
             | DefiningTy::Coroutine(def_id, ..)
             | DefiningTy::FnDef(def_id, ..)
             | DefiningTy::Const(def_id, ..)
@@ -355,6 +363,9 @@ impl<'tcx> UniversalRegions<'tcx> {
                     err.note(format!("late-bound region is {:?}", self.to_region_vid(r)));
                 });
             }
+            DefiningTy::CoroutineClosure(..) => {
+                todo!()
+            }
             DefiningTy::Coroutine(def_id, args) => {
                 let v = with_no_trimmed_paths!(
                     args[tcx.generics_of(def_id).parent_count..]
@@ -568,6 +579,9 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
                 match *defining_ty.kind() {
                     ty::Closure(def_id, args) => DefiningTy::Closure(def_id, args),
                     ty::Coroutine(def_id, args) => DefiningTy::Coroutine(def_id, args),
+                    ty::CoroutineClosure(def_id, args) => {
+                        DefiningTy::CoroutineClosure(def_id, args)
+                    }
                     ty::FnDef(def_id, args) => DefiningTy::FnDef(def_id, args),
                     _ => span_bug!(
                         tcx.def_span(self.mir_def),
@@ -623,6 +637,7 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
         let identity_args = GenericArgs::identity_for_item(tcx, typeck_root_def_id);
         let fr_args = match defining_ty {
             DefiningTy::Closure(_, args)
+            | DefiningTy::CoroutineClosure(_, args)
             | DefiningTy::Coroutine(_, args)
             | DefiningTy::InlineConst(_, args) => {
                 // In the case of closures, we rely on the fact that
@@ -702,6 +717,47 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
                 ty::Binder::dummy(inputs_and_output)
             }
 
+            DefiningTy::CoroutineClosure(def_id, args) => {
+                assert_eq!(self.mir_def.to_def_id(), def_id);
+                let closure_sig = args.as_coroutine_closure().coroutine_closure_sig();
+                let bound_vars = tcx.mk_bound_variable_kinds_from_iter(
+                    closure_sig
+                        .bound_vars()
+                        .iter()
+                        .chain(iter::once(ty::BoundVariableKind::Region(ty::BrEnv))),
+                );
+                let br = ty::BoundRegion {
+                    var: ty::BoundVar::from_usize(bound_vars.len() - 1),
+                    kind: ty::BrEnv,
+                };
+                let env_region = ty::Region::new_bound(tcx, ty::INNERMOST, br);
+                let closure_kind = args.as_coroutine_closure().kind();
+
+                let closure_ty = tcx.closure_env_ty(
+                    Ty::new_coroutine_closure(tcx, def_id, args),
+                    closure_kind,
+                    env_region,
+                );
+
+                let inputs = closure_sig.skip_binder().tupled_inputs_ty.tuple_fields();
+                let output = closure_sig.skip_binder().to_coroutine_given_kind_and_upvars(
+                    tcx,
+                    args.as_coroutine_closure().parent_args(),
+                    tcx.coroutine_for_closure(def_id),
+                    closure_kind,
+                    env_region,
+                    args.as_coroutine_closure().tupled_upvars_ty(),
+                    args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
+                );
+
+                ty::Binder::bind_with_vars(
+                    tcx.mk_type_list_from_iter(
+                        iter::once(closure_ty).chain(inputs).chain(iter::once(output)),
+                    ),
+                    bound_vars,
+                )
+            }
+
             DefiningTy::FnDef(def_id, _) => {
                 let sig = tcx.fn_sig(def_id).instantiate_identity();
                 let sig = indices.fold_to_region_vids(tcx, sig);