about summary refs log tree commit diff
path: root/compiler/rustc_middle/src
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-01-24 22:27:25 +0000
committerMichael Goulet <michael@errs.io>2024-02-06 02:22:58 +0000
commita82bae2172499864c12a1d0b412931ad884911f7 (patch)
treec7299fdfd83be3818fcffdb86639146c9d29bb69 /compiler/rustc_middle/src
parentc567eddec2c628d4f13707866731e1b2013ad236 (diff)
downloadrust-a82bae2172499864c12a1d0b412931ad884911f7.tar.gz
rust-a82bae2172499864c12a1d0b412931ad884911f7.zip
Teach typeck/borrowck/solvers how to deal with async closures
Diffstat (limited to 'compiler/rustc_middle/src')
-rw-r--r--compiler/rustc_middle/src/middle/lang_items.rs15
-rw-r--r--compiler/rustc_middle/src/query/mod.rs5
-rw-r--r--compiler/rustc_middle/src/traits/select.rs7
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs7
-rw-r--r--compiler/rustc_middle/src/ty/sty.rs128
5 files changed, 156 insertions, 6 deletions
diff --git a/compiler/rustc_middle/src/middle/lang_items.rs b/compiler/rustc_middle/src/middle/lang_items.rs
index f92c72c8a58..a4e193ba2c9 100644
--- a/compiler/rustc_middle/src/middle/lang_items.rs
+++ b/compiler/rustc_middle/src/middle/lang_items.rs
@@ -23,7 +23,7 @@ impl<'tcx> TyCtxt<'tcx> {
         })
     }
 
-    /// Given a [`DefId`] of a [`Fn`], [`FnMut`] or [`FnOnce`] traits,
+    /// Given a [`DefId`] of one of the [`Fn`], [`FnMut`] or [`FnOnce`] traits,
     /// returns a corresponding [`ty::ClosureKind`].
     /// For any other [`DefId`] return `None`.
     pub fn fn_trait_kind_from_def_id(self, id: DefId) -> Option<ty::ClosureKind> {
@@ -36,6 +36,19 @@ impl<'tcx> TyCtxt<'tcx> {
         }
     }
 
+    /// Given a [`DefId`] of one of the `AsyncFn`, `AsyncFnMut` or `AsyncFnOnce` traits,
+    /// returns a corresponding [`ty::ClosureKind`].
+    /// For any other [`DefId`] return `None`.
+    pub fn async_fn_trait_kind_from_def_id(self, id: DefId) -> Option<ty::ClosureKind> {
+        let items = self.lang_items();
+        match Some(id) {
+            x if x == items.async_fn_trait() => Some(ty::ClosureKind::Fn),
+            x if x == items.async_fn_mut_trait() => Some(ty::ClosureKind::FnMut),
+            x if x == items.async_fn_once_trait() => Some(ty::ClosureKind::FnOnce),
+            _ => None,
+        }
+    }
+
     /// Given a [`ty::ClosureKind`], get the [`DefId`] of its corresponding `Fn`-family
     /// trait, if it is defined.
     pub fn fn_trait_kind_to_def_id(self, kind: ty::ClosureKind) -> Option<DefId> {
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index 2438f826441..f9ab32b16f5 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -755,6 +755,11 @@ rustc_queries! {
         separate_provide_extern
     }
 
+    query coroutine_for_closure(def_id: DefId) -> DefId {
+        desc { |_tcx| "TODO" }
+        separate_provide_extern
+    }
+
     /// Gets a map with the variance of every item; use `variances_of` instead.
     query crate_variances(_: ()) -> &'tcx ty::CrateVariancesMap<'tcx> {
         arena_cache
diff --git a/compiler/rustc_middle/src/traits/select.rs b/compiler/rustc_middle/src/traits/select.rs
index 64f4af08e12..4e11575cf98 100644
--- a/compiler/rustc_middle/src/traits/select.rs
+++ b/compiler/rustc_middle/src/traits/select.rs
@@ -134,6 +134,13 @@ pub enum SelectionCandidate<'tcx> {
         is_const: bool,
     },
 
+    /// Implementation of an `AsyncFn`-family trait by one of the anonymous types
+    /// generated for an `async ||` expression.
+    AsyncClosureCandidate,
+
+    // TODO:
+    AsyncFnKindHelperCandidate,
+
     /// Implementation of a `Coroutine` trait by one of the anonymous types
     /// generated for a coroutine.
     CoroutineCandidate,
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 8e51db82f95..4348d01eff5 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -105,9 +105,10 @@ pub use self::region::{
 pub use self::rvalue_scopes::RvalueScopes;
 pub use self::sty::{
     AliasTy, Article, Binder, BoundTy, BoundTyKind, BoundVariableKind, CanonicalPolyFnSig,
-    ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, FnSig, GenSig,
-    InlineConstArgs, InlineConstArgsParts, ParamConst, ParamTy, PolyFnSig, TyKind, TypeAndMut,
-    UpvarArgs, VarianceDiagInfo,
+    ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, CoroutineClosureArgs,
+    CoroutineClosureArgsParts, CoroutineClosureSignature, FnSig, GenSig, InlineConstArgs,
+    InlineConstArgsParts, ParamConst, ParamTy, PolyFnSig, TyKind, TypeAndMut, UpvarArgs,
+    VarianceDiagInfo,
 };
 pub use self::trait_def::TraitDef;
 pub use self::typeck_results::{
diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs
index c047f6a0521..e2a2e24f06d 100644
--- a/compiler/rustc_middle/src/ty/sty.rs
+++ b/compiler/rustc_middle/src/ty/sty.rs
@@ -36,6 +36,7 @@ use rustc_type_ir::TyKind as IrTyKind;
 use rustc_type_ir::TyKind::*;
 use rustc_type_ir::TypeAndMut as IrTypeAndMut;
 
+use super::fold::FnMutDelegate;
 use super::GenericParamDefKind;
 
 // Re-export and re-parameterize some `I = TyCtxt<'tcx>` types here
@@ -351,6 +352,27 @@ impl<'tcx> CoroutineClosureArgs<'tcx> {
         self.split().signature_parts_ty
     }
 
+    pub fn coroutine_closure_sig(self) -> ty::Binder<'tcx, CoroutineClosureSignature<'tcx>> {
+        let interior = self.coroutine_witness_ty();
+        let ty::FnPtr(sig) = self.signature_parts_ty().kind() else { bug!() };
+        sig.map_bound(|sig| {
+            let [resume_ty, tupled_inputs_ty] = *sig.inputs() else {
+                bug!();
+            };
+            let [yield_ty, return_ty] = **sig.output().tuple_fields() else { bug!() };
+            CoroutineClosureSignature {
+                interior,
+                tupled_inputs_ty,
+                resume_ty,
+                yield_ty,
+                return_ty,
+                c_variadic: sig.c_variadic,
+                unsafety: sig.unsafety,
+                abi: sig.abi,
+            }
+        })
+    }
+
     pub fn coroutine_captures_by_ref_ty(self) -> Ty<'tcx> {
         self.split().coroutine_captures_by_ref_ty
     }
@@ -360,6 +382,103 @@ impl<'tcx> CoroutineClosureArgs<'tcx> {
     }
 }
 
+#[derive(Copy, Clone, PartialEq, Eq, Debug, TypeFoldable, TypeVisitable)]
+pub struct CoroutineClosureSignature<'tcx> {
+    pub interior: Ty<'tcx>,
+    pub tupled_inputs_ty: Ty<'tcx>,
+    pub resume_ty: Ty<'tcx>,
+    pub yield_ty: Ty<'tcx>,
+    pub return_ty: Ty<'tcx>,
+    pub c_variadic: bool,
+    pub unsafety: hir::Unsafety,
+    pub abi: abi::Abi,
+}
+
+impl<'tcx> CoroutineClosureSignature<'tcx> {
+    pub fn to_coroutine(
+        self,
+        tcx: TyCtxt<'tcx>,
+        parent_args: &'tcx [GenericArg<'tcx>],
+        coroutine_def_id: DefId,
+        tupled_upvars_ty: Ty<'tcx>,
+    ) -> Ty<'tcx> {
+        let coroutine_args = ty::CoroutineArgs::new(
+            tcx,
+            ty::CoroutineArgsParts {
+                parent_args,
+                resume_ty: self.resume_ty,
+                yield_ty: self.yield_ty,
+                return_ty: self.return_ty,
+                witness: self.interior,
+                tupled_upvars_ty,
+            },
+        );
+
+        Ty::new_coroutine(tcx, coroutine_def_id, coroutine_args.args)
+    }
+
+    pub fn to_coroutine_given_kind_and_upvars(
+        self,
+        tcx: TyCtxt<'tcx>,
+        parent_args: &'tcx [GenericArg<'tcx>],
+        coroutine_def_id: DefId,
+        closure_kind: ty::ClosureKind,
+        env_region: ty::Region<'tcx>,
+        closure_tupled_upvars_ty: Ty<'tcx>,
+        coroutine_captures_by_ref_ty: Ty<'tcx>,
+    ) -> Ty<'tcx> {
+        let tupled_upvars_ty = Self::tupled_upvars_by_closure_kind(
+            tcx,
+            closure_kind,
+            self.tupled_inputs_ty,
+            closure_tupled_upvars_ty,
+            coroutine_captures_by_ref_ty,
+            env_region,
+        );
+
+        self.to_coroutine(tcx, parent_args, coroutine_def_id, tupled_upvars_ty)
+    }
+
+    /// Given a closure kind, compute the tupled upvars that the given coroutine would return.
+    pub fn tupled_upvars_by_closure_kind(
+        tcx: TyCtxt<'tcx>,
+        kind: ty::ClosureKind,
+        tupled_inputs_ty: Ty<'tcx>,
+        closure_tupled_upvars_ty: Ty<'tcx>,
+        coroutine_captures_by_ref_ty: Ty<'tcx>,
+        env_region: ty::Region<'tcx>,
+    ) -> Ty<'tcx> {
+        match kind {
+            ty::ClosureKind::Fn | ty::ClosureKind::FnMut => {
+                let ty::FnPtr(sig) = *coroutine_captures_by_ref_ty.kind() else {
+                    bug!();
+                };
+                let coroutine_captures_by_ref_ty = tcx.replace_escaping_bound_vars_uncached(
+                    sig.output().skip_binder(),
+                    FnMutDelegate {
+                        consts: &mut |c, t| ty::Const::new_bound(tcx, ty::INNERMOST, c, t),
+                        types: &mut |t| Ty::new_bound(tcx, ty::INNERMOST, t),
+                        regions: &mut |_| env_region,
+                    },
+                );
+                Ty::new_tup_from_iter(
+                    tcx,
+                    tupled_inputs_ty
+                        .tuple_fields()
+                        .iter()
+                        .chain(coroutine_captures_by_ref_ty.tuple_fields()),
+                )
+            }
+            ty::ClosureKind::FnOnce => Ty::new_tup_from_iter(
+                tcx,
+                tupled_inputs_ty
+                    .tuple_fields()
+                    .iter()
+                    .chain(closure_tupled_upvars_ty.tuple_fields()),
+            ),
+        }
+    }
+}
 /// Similar to `ClosureArgs`; see the above documentation for more.
 #[derive(Copy, Clone, PartialEq, Eq, Debug, TypeFoldable, TypeVisitable)]
 pub struct CoroutineArgs<'tcx> {
@@ -1495,7 +1614,7 @@ impl<'tcx> Ty<'tcx> {
     ) -> Ty<'tcx> {
         debug_assert_eq!(
             closure_args.len(),
-            tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 3,
+            tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 5,
             "closure constructed with incorrect substitutions"
         );
         Ty::new(tcx, CoroutineClosure(def_id, closure_args))
@@ -1836,6 +1955,11 @@ impl<'tcx> Ty<'tcx> {
     }
 
     #[inline]
+    pub fn is_coroutine_closure(self) -> bool {
+        matches!(self.kind(), CoroutineClosure(..))
+    }
+
+    #[inline]
     pub fn is_integral(self) -> bool {
         matches!(self.kind(), Infer(IntVar(_)) | Int(_) | Uint(_))
     }
@@ -2144,7 +2268,7 @@ impl<'tcx> Ty<'tcx> {
 
             // "Bound" types appear in canonical queries when the
             // closure type is not yet known
-            Bound(..) | Infer(_) => None,
+            Bound(..) | Param(_) | Infer(_) => None,
 
             Error(_) => Some(ty::ClosureKind::Fn),