about summary refs log tree commit diff
path: root/compiler/rustc_const_eval/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-02-06 15:04:01 +0000
committerbors <bors@rust-lang.org>2024-02-06 15:04:01 +0000
commit4a2fe4491ea616983a0cf0cbbd145a39768f4e7a (patch)
tree3ee31f8af96390f25ff12e1772aa224ba09a4828 /compiler/rustc_const_eval/src
parent037f515caf46846d2bffae55883eebc6c1ccb861 (diff)
parented7fca1f8805b4348b801f23f444e0dda42f7aed (diff)
downloadrust-4a2fe4491ea616983a0cf0cbbd145a39768f4e7a.tar.gz
rust-4a2fe4491ea616983a0cf0cbbd145a39768f4e7a.zip
Auto merge of #120361 - compiler-errors:async-closures, r=oli-obk
Rework support for async closures; allow them to return futures that borrow from the closure's captures

This PR implements a new lowering for async closures via `TyKind::CoroutineClosure` which handles the curious relationship between the closure and the coroutine that it returns.

I wrote up a bunch in [this hackmd](https://hackmd.io/`@compiler-errors/S1HvqQxca)` which will be copied to the dev guide after this PR lands, and hopefully left sufficient comments in the source code explaining why this change is as large as it is.

This also necessitates that they begin implementing the `AsyncFn`-family of traits, rather than the `Fn`-family of traits -- if you need `Fn` implementations, you should probably use the non-sugar `|| async {}` syntax instead.

Notably this PR does not yet implement `async Fn()` syntax sugar for bounds, but I expect to add those soon (**edit:** #120392). For now, users must use `AsyncFn()` traits directly, which necessitates adding the `async_fn_traits` feature gate as well. I will add this as a follow-up very soon.

r? oli-obk

This is based on top of #120322, but that PR is minimal.
Diffstat (limited to 'compiler/rustc_const_eval/src')
-rw-r--r--compiler/rustc_const_eval/src/const_eval/valtrees.rs2
-rw-r--r--compiler/rustc_const_eval/src/interpret/eval_context.rs1
-rw-r--r--compiler/rustc_const_eval/src/interpret/intrinsics.rs1
-rw-r--r--compiler/rustc_const_eval/src/interpret/terminator.rs2
-rw-r--r--compiler/rustc_const_eval/src/interpret/validity.rs1
-rw-r--r--compiler/rustc_const_eval/src/transform/validate.rs23
-rw-r--r--compiler/rustc_const_eval/src/util/type_name.rs1
7 files changed, 31 insertions, 0 deletions
diff --git a/compiler/rustc_const_eval/src/const_eval/valtrees.rs b/compiler/rustc_const_eval/src/const_eval/valtrees.rs
index 12544f5b029..5c2bf4626c4 100644
--- a/compiler/rustc_const_eval/src/const_eval/valtrees.rs
+++ b/compiler/rustc_const_eval/src/const_eval/valtrees.rs
@@ -172,6 +172,7 @@ pub(crate) fn const_to_valtree_inner<'tcx>(
         | ty::Infer(_)
         // FIXME(oli-obk): we can probably encode closures just like structs
         | ty::Closure(..)
+        | ty::CoroutineClosure(..)
         | ty::Coroutine(..)
         | ty::CoroutineWitness(..) => Err(ValTreeCreationError::NonSupportedType),
     }
@@ -301,6 +302,7 @@ pub fn valtree_to_const_value<'tcx>(
         | ty::Placeholder(..)
         | ty::Infer(_)
         | ty::Closure(..)
+        | ty::CoroutineClosure(..)
         | ty::Coroutine(..)
         | ty::CoroutineWitness(..)
         | ty::FnPtr(_)
diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs
index c14bd142efa..dd989ab80fd 100644
--- a/compiler/rustc_const_eval/src/interpret/eval_context.rs
+++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs
@@ -1007,6 +1007,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                 | ty::CoroutineWitness(..)
                 | ty::Array(..)
                 | ty::Closure(..)
+                | ty::CoroutineClosure(..)
                 | ty::Never
                 | ty::Error(_) => true,
 
diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs
index 1e9e7d94596..7991f90b815 100644
--- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs
+++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs
@@ -85,6 +85,7 @@ pub(crate) fn eval_nullary_intrinsic<'tcx>(
             | ty::FnPtr(_)
             | ty::Dynamic(_, _, _)
             | ty::Closure(_, _)
+            | ty::CoroutineClosure(_, _)
             | ty::Coroutine(_, _)
             | ty::CoroutineWitness(..)
             | ty::Never
diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs
index b7ffb4a16fc..85a2e4778d2 100644
--- a/compiler/rustc_const_eval/src/interpret/terminator.rs
+++ b/compiler/rustc_const_eval/src/interpret/terminator.rs
@@ -545,6 +545,8 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
             ty::InstanceDef::VTableShim(..)
             | ty::InstanceDef::ReifyShim(..)
             | ty::InstanceDef::ClosureOnceShim { .. }
+            | ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
+            | ty::InstanceDef::CoroutineKindShim { .. }
             | ty::InstanceDef::FnPtrShim(..)
             | ty::InstanceDef::DropGlue(..)
             | ty::InstanceDef::CloneShim(..)
diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs
index b5cd3259520..811c2c3c208 100644
--- a/compiler/rustc_const_eval/src/interpret/validity.rs
+++ b/compiler/rustc_const_eval/src/interpret/validity.rs
@@ -644,6 +644,7 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, '
             | ty::Str
             | ty::Dynamic(..)
             | ty::Closure(..)
+            | ty::CoroutineClosure(..)
             | ty::Coroutine(..) => Ok(false),
             // Some types only occur during typechecking, they have no layout.
             // We should not see them here and we could not check them anyway.
diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs
index 21bdb66a276..c4542aaa7b2 100644
--- a/compiler/rustc_const_eval/src/transform/validate.rs
+++ b/compiler/rustc_const_eval/src/transform/validate.rs
@@ -58,6 +58,7 @@ impl<'tcx> MirPass<'tcx> for Validator {
             let body_abi = match body_ty.kind() {
                 ty::FnDef(..) => body_ty.fn_sig(tcx).abi(),
                 ty::Closure(..) => Abi::RustCall,
+                ty::CoroutineClosure(..) => Abi::RustCall,
                 ty::Coroutine(..) => Abi::Rust,
                 _ => {
                     span_bug!(body.span, "unexpected body ty: {:?} phase {:?}", body_ty, mir_phase)
@@ -665,6 +666,14 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
                         };
                         check_equal(self, location, f_ty);
                     }
+                    ty::CoroutineClosure(_, args) => {
+                        let args = args.as_coroutine_closure();
+                        let Some(&f_ty) = args.upvar_tys().get(f.as_usize()) else {
+                            fail_out_of_bounds(self, location);
+                            return;
+                        };
+                        check_equal(self, location, f_ty);
+                    }
                     &ty::Coroutine(def_id, args) => {
                         let f_ty = if let Some(var) = parent_ty.variant_index {
                             let gen_body = if def_id == self.body.source.def_id() {
@@ -861,6 +870,20 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
                         }
                     }
                 }
+                AggregateKind::CoroutineClosure(_, args) => {
+                    let upvars = args.as_coroutine_closure().upvar_tys();
+                    if upvars.len() != fields.len() {
+                        self.fail(
+                            location,
+                            "coroutine-closure has the wrong number of initialized fields",
+                        );
+                    }
+                    for (src, dest) in std::iter::zip(fields, upvars) {
+                        if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) {
+                            self.fail(location, "coroutine-closure field has the wrong type");
+                        }
+                    }
+                }
             },
             Rvalue::Ref(_, BorrowKind::Fake, _) => {
                 if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) {
diff --git a/compiler/rustc_const_eval/src/util/type_name.rs b/compiler/rustc_const_eval/src/util/type_name.rs
index 976e42ad768..2b80623ab45 100644
--- a/compiler/rustc_const_eval/src/util/type_name.rs
+++ b/compiler/rustc_const_eval/src/util/type_name.rs
@@ -51,6 +51,7 @@ impl<'tcx> Printer<'tcx> for AbsolutePathPrinter<'tcx> {
             | ty::FnDef(def_id, args)
             | ty::Alias(ty::Projection | ty::Opaque, ty::AliasTy { def_id, args, .. })
             | ty::Closure(def_id, args)
+            | ty::CoroutineClosure(def_id, args)
             | ty::Coroutine(def_id, args) => self.print_def_path(def_id, args),
             ty::Foreign(def_id) => self.print_def_path(def_id, &[]),