From c567eddec2c628d4f13707866731e1b2013ad236 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Wed, 24 Jan 2024 18:01:56 +0000 Subject: Add CoroutineClosure to TyKind, AggregateKind, UpvarArgs --- .../rustc_const_eval/src/const_eval/valtrees.rs | 2 ++ .../rustc_const_eval/src/interpret/eval_context.rs | 1 + .../rustc_const_eval/src/interpret/intrinsics.rs | 1 + .../rustc_const_eval/src/interpret/validity.rs | 1 + .../rustc_const_eval/src/transform/validate.rs | 22 ++++++++++++++++++++++ compiler/rustc_const_eval/src/util/type_name.rs | 1 + 6 files changed, 28 insertions(+) (limited to 'compiler/rustc_const_eval/src') diff --git a/compiler/rustc_const_eval/src/const_eval/valtrees.rs b/compiler/rustc_const_eval/src/const_eval/valtrees.rs index 12544f5b029..5c2bf4626c4 100644 --- a/compiler/rustc_const_eval/src/const_eval/valtrees.rs +++ b/compiler/rustc_const_eval/src/const_eval/valtrees.rs @@ -172,6 +172,7 @@ pub(crate) fn const_to_valtree_inner<'tcx>( | ty::Infer(_) // FIXME(oli-obk): we can probably encode closures just like structs | ty::Closure(..) + | ty::CoroutineClosure(..) | ty::Coroutine(..) | ty::CoroutineWitness(..) => Err(ValTreeCreationError::NonSupportedType), } @@ -301,6 +302,7 @@ pub fn valtree_to_const_value<'tcx>( | ty::Placeholder(..) | ty::Infer(_) | ty::Closure(..) + | ty::CoroutineClosure(..) | ty::Coroutine(..) | ty::CoroutineWitness(..) | ty::FnPtr(_) diff --git a/compiler/rustc_const_eval/src/interpret/eval_context.rs b/compiler/rustc_const_eval/src/interpret/eval_context.rs index c14bd142efa..dd989ab80fd 100644 --- a/compiler/rustc_const_eval/src/interpret/eval_context.rs +++ b/compiler/rustc_const_eval/src/interpret/eval_context.rs @@ -1007,6 +1007,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { | ty::CoroutineWitness(..) | ty::Array(..) | ty::Closure(..) + | ty::CoroutineClosure(..) | ty::Never | ty::Error(_) => true, diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs index 1e9e7d94596..7991f90b815 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs @@ -85,6 +85,7 @@ pub(crate) fn eval_nullary_intrinsic<'tcx>( | ty::FnPtr(_) | ty::Dynamic(_, _, _) | ty::Closure(_, _) + | ty::CoroutineClosure(_, _) | ty::Coroutine(_, _) | ty::CoroutineWitness(..) | ty::Never diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs index b5cd3259520..811c2c3c208 100644 --- a/compiler/rustc_const_eval/src/interpret/validity.rs +++ b/compiler/rustc_const_eval/src/interpret/validity.rs @@ -644,6 +644,7 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, ' | ty::Str | ty::Dynamic(..) | ty::Closure(..) + | ty::CoroutineClosure(..) | ty::Coroutine(..) => Ok(false), // Some types only occur during typechecking, they have no layout. // We should not see them here and we could not check them anyway. diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs index 21bdb66a276..d3230a2455d 100644 --- a/compiler/rustc_const_eval/src/transform/validate.rs +++ b/compiler/rustc_const_eval/src/transform/validate.rs @@ -665,6 +665,14 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> { }; check_equal(self, location, f_ty); } + ty::CoroutineClosure(_, args) => { + let args = args.as_coroutine_closure(); + let Some(&f_ty) = args.upvar_tys().get(f.as_usize()) else { + fail_out_of_bounds(self, location); + return; + }; + check_equal(self, location, f_ty); + } &ty::Coroutine(def_id, args) => { let f_ty = if let Some(var) = parent_ty.variant_index { let gen_body = if def_id == self.body.source.def_id() { @@ -861,6 +869,20 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> { } } } + AggregateKind::CoroutineClosure(_, args) => { + let upvars = args.as_coroutine_closure().upvar_tys(); + if upvars.len() != fields.len() { + self.fail( + location, + "coroutine-closure has the wrong number of initialized fields", + ); + } + for (src, dest) in std::iter::zip(fields, upvars) { + if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest) { + self.fail(location, "coroutine-closure field has the wrong type"); + } + } + } }, Rvalue::Ref(_, BorrowKind::Fake, _) => { if self.mir_phase >= MirPhase::Runtime(RuntimePhase::Initial) { diff --git a/compiler/rustc_const_eval/src/util/type_name.rs b/compiler/rustc_const_eval/src/util/type_name.rs index 976e42ad768..2b80623ab45 100644 --- a/compiler/rustc_const_eval/src/util/type_name.rs +++ b/compiler/rustc_const_eval/src/util/type_name.rs @@ -51,6 +51,7 @@ impl<'tcx> Printer<'tcx> for AbsolutePathPrinter<'tcx> { | ty::FnDef(def_id, args) | ty::Alias(ty::Projection | ty::Opaque, ty::AliasTy { def_id, args, .. }) | ty::Closure(def_id, args) + | ty::CoroutineClosure(def_id, args) | ty::Coroutine(def_id, args) => self.print_def_path(def_id, args), ty::Foreign(def_id) => self.print_def_path(def_id, &[]), -- cgit 1.4.1-3-g733a5 From a82bae2172499864c12a1d0b412931ad884911f7 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Wed, 24 Jan 2024 22:27:25 +0000 Subject: Teach typeck/borrowck/solvers how to deal with async closures --- compiler/rustc_ast_lowering/src/expr.rs | 8 + .../rustc_borrowck/src/diagnostics/region_name.rs | 21 +-- compiler/rustc_borrowck/src/lib.rs | 1 + .../rustc_borrowck/src/type_check/input_output.rs | 70 +++++++- compiler/rustc_borrowck/src/type_check/mod.rs | 15 +- compiler/rustc_borrowck/src/universal_regions.rs | 58 ++++++- .../rustc_const_eval/src/transform/validate.rs | 1 + compiler/rustc_hir/src/lang_items.rs | 1 + compiler/rustc_hir_analysis/src/collect.rs | 31 ++++ compiler/rustc_hir_typeck/src/callee.rs | 85 +++++++--- compiler/rustc_hir_typeck/src/closure.rs | 87 +++++++++- compiler/rustc_hir_typeck/src/upvar.rs | 64 +++++++- compiler/rustc_middle/src/middle/lang_items.rs | 15 +- compiler/rustc_middle/src/query/mod.rs | 5 + compiler/rustc_middle/src/traits/select.rs | 7 + compiler/rustc_middle/src/ty/mod.rs | 7 +- compiler/rustc_middle/src/ty/sty.rs | 128 ++++++++++++++- compiler/rustc_mir_build/src/build/mod.rs | 1 + compiler/rustc_mir_build/src/thir/cx/expr.rs | 3 + compiler/rustc_mir_build/src/thir/cx/mod.rs | 18 +- compiler/rustc_mir_dataflow/src/value_analysis.rs | 3 + .../rustc_mir_transform/src/const_prop_lint.rs | 3 +- .../src/coverage/spans/from_mir.rs | 4 +- compiler/rustc_span/src/symbol.rs | 4 + .../src/solve/assembly/mod.rs | 18 ++ .../src/solve/assembly/structural_traits.rs | 111 ++++++++++++- .../src/solve/normalizes_to/mod.rs | 113 +++++++++++++ .../rustc_trait_selection/src/solve/trait_goals.rs | 60 +++++++ .../rustc_trait_selection/src/traits/project.rs | 182 ++++++++++++++++++++- .../src/traits/select/candidate_assembly.rs | 46 ++++++ .../src/traits/select/confirmation.rs | 50 ++++++ .../rustc_trait_selection/src/traits/select/mod.rs | 12 ++ compiler/rustc_ty_utils/src/abi.rs | 36 ++++ compiler/rustc_ty_utils/src/instance.rs | 11 ++ library/core/src/ops/async_function.rs | 8 + 35 files changed, 1221 insertions(+), 66 deletions(-) (limited to 'compiler/rustc_const_eval/src') diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs index 2241b4cd8e5..9990d526bf7 100644 --- a/compiler/rustc_ast_lowering/src/expr.rs +++ b/compiler/rustc_ast_lowering/src/expr.rs @@ -1,3 +1,5 @@ +use std::assert_matches::assert_matches; + use super::errors::{ AsyncCoroutinesNotSupported, AwaitOnlyInAsyncFnAndBlocks, BaseExpressionDoubleDot, ClosureCannotBeStatic, CoroutineTooManyParameters, @@ -1027,6 +1029,12 @@ impl<'hir> LoweringContext<'_, 'hir> { ) -> hir::ExprKind<'hir> { let (binder_clause, generic_params) = self.lower_closure_binder(binder); + assert_matches!( + coroutine_kind, + CoroutineKind::Async { .. }, + "only async closures are supported currently" + ); + let body = self.with_new_scopes(fn_decl_span, |this| { let inner_decl = FnDecl { inputs: decl.inputs.clone(), output: FnRetTy::Default(fn_decl_span) }; diff --git a/compiler/rustc_borrowck/src/diagnostics/region_name.rs b/compiler/rustc_borrowck/src/diagnostics/region_name.rs index fc52306e10d..c91b8eaab05 100644 --- a/compiler/rustc_borrowck/src/diagnostics/region_name.rs +++ b/compiler/rustc_borrowck/src/diagnostics/region_name.rs @@ -324,9 +324,13 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> { ty::BoundRegionKind::BrEnv => { let def_ty = self.regioncx.universal_regions().defining_ty; - let DefiningTy::Closure(_, args) = def_ty else { - // Can't have BrEnv in functions, constants or coroutines. - bug!("BrEnv outside of closure."); + let closure_kind = match def_ty { + DefiningTy::Closure(_, args) => args.as_closure().kind(), + DefiningTy::CoroutineClosure(_, args) => args.as_coroutine_closure().kind(), + _ => { + // Can't have BrEnv in functions, constants or coroutines. + bug!("BrEnv outside of closure."); + } }; let hir::ExprKind::Closure(&hir::Closure { fn_decl_span, .. }) = tcx.hir().expect_expr(self.mir_hir_id()).kind @@ -334,21 +338,18 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> { bug!("Closure is not defined by a closure expr"); }; let region_name = self.synthesize_region_name(); - - let closure_kind_ty = args.as_closure().kind_ty(); - let note = match closure_kind_ty.to_opt_closure_kind() { - Some(ty::ClosureKind::Fn) => { + let note = match closure_kind { + ty::ClosureKind::Fn => { "closure implements `Fn`, so references to captured variables \ can't escape the closure" } - Some(ty::ClosureKind::FnMut) => { + ty::ClosureKind::FnMut => { "closure implements `FnMut`, so references to captured variables \ can't escape the closure" } - Some(ty::ClosureKind::FnOnce) => { + ty::ClosureKind::FnOnce => { bug!("BrEnv in a `FnOnce` closure"); } - None => bug!("Closure kind not inferred in borrow check"), }; Some(RegionName { diff --git a/compiler/rustc_borrowck/src/lib.rs b/compiler/rustc_borrowck/src/lib.rs index eaf9fc45837..bb64571889b 100644 --- a/compiler/rustc_borrowck/src/lib.rs +++ b/compiler/rustc_borrowck/src/lib.rs @@ -3,6 +3,7 @@ #![allow(internal_features)] #![feature(rustdoc_internals)] #![doc(rust_logo)] +#![feature(assert_matches)] #![feature(associated_type_bounds)] #![feature(box_patterns)] #![feature(let_chains)] diff --git a/compiler/rustc_borrowck/src/type_check/input_output.rs b/compiler/rustc_borrowck/src/type_check/input_output.rs index 59518f68ab1..a3e5088ee09 100644 --- a/compiler/rustc_borrowck/src/type_check/input_output.rs +++ b/compiler/rustc_borrowck/src/type_check/input_output.rs @@ -7,13 +7,18 @@ //! `RETURN_PLACE` the MIR arguments) are always fully normalized (and //! contain revealed `impl Trait` values). +use std::assert_matches::assert_matches; + use itertools::Itertools; -use rustc_infer::infer::BoundRegionConversionTime; +use rustc_hir as hir; +use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; +use rustc_infer::infer::{BoundRegionConversionTime, RegionVariableOrigin}; use rustc_middle::mir::*; use rustc_middle::ty::{self, Ty}; use rustc_span::Span; -use crate::universal_regions::UniversalRegions; +use crate::renumber::RegionCtxt; +use crate::universal_regions::{DefiningTy, UniversalRegions}; use super::{Locations, TypeChecker}; @@ -23,9 +28,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { #[instrument(skip(self, body), level = "debug")] pub(super) fn check_signature_annotation(&mut self, body: &Body<'tcx>) { let mir_def_id = body.source.def_id().expect_local(); + if !self.tcx().is_closure_or_coroutine(mir_def_id.to_def_id()) { return; } + let user_provided_poly_sig = self.tcx().closure_user_provided_sig(mir_def_id); // Instantiate the canonicalized variables from user-provided signature @@ -34,12 +41,69 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { // so that they represent the view from "inside" the closure. let user_provided_sig = self .instantiate_canonical_with_fresh_inference_vars(body.span, &user_provided_poly_sig); - let user_provided_sig = self.infcx.instantiate_binder_with_fresh_vars( + let mut user_provided_sig = self.infcx.instantiate_binder_with_fresh_vars( body.span, BoundRegionConversionTime::FnCall, user_provided_sig, ); + // FIXME(async_closures): We must apply the same transformation to our + // signature here as we do during closure checking. + if let DefiningTy::CoroutineClosure(_, args) = + self.borrowck_context.universal_regions.defining_ty + { + assert_matches!( + self.tcx().coroutine_kind(self.tcx().coroutine_for_closure(mir_def_id)), + Some(hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Async, + hir::CoroutineSource::Closure + )), + "this needs to be modified if we're lowering non-async closures" + ); + let args = args.as_coroutine_closure(); + let tupled_upvars_ty = ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind( + self.tcx(), + args.kind(), + Ty::new_tup(self.tcx(), user_provided_sig.inputs()), + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + self.infcx.next_region_var(RegionVariableOrigin::MiscVariable(body.span), || { + RegionCtxt::Unknown + }), + ); + + let next_ty_var = || { + self.infcx.next_ty_var(TypeVariableOrigin { + span: body.span, + kind: TypeVariableOriginKind::MiscVariable, + }) + }; + let output_ty = Ty::new_coroutine( + self.tcx(), + self.tcx().coroutine_for_closure(mir_def_id), + ty::CoroutineArgs::new( + self.tcx(), + ty::CoroutineArgsParts { + parent_args: args.parent_args(), + resume_ty: next_ty_var(), + yield_ty: next_ty_var(), + witness: next_ty_var(), + return_ty: user_provided_sig.output(), + tupled_upvars_ty: tupled_upvars_ty, + }, + ) + .args, + ); + + user_provided_sig = self.tcx().mk_fn_sig( + user_provided_sig.inputs().iter().copied(), + output_ty, + user_provided_sig.c_variadic, + user_provided_sig.unsafety, + user_provided_sig.abi, + ); + } + let is_coroutine_with_implicit_resume_ty = self.tcx().is_coroutine(mir_def_id.to_def_id()) && user_provided_sig.inputs().is_empty(); diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs index 90a9844fda9..cdb187e0776 100644 --- a/compiler/rustc_borrowck/src/type_check/mod.rs +++ b/compiler/rustc_borrowck/src/type_check/mod.rs @@ -2773,15 +2773,14 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { let typeck_root_args = ty::GenericArgs::identity_for_item(tcx, typeck_root_def_id); let parent_args = match tcx.def_kind(def_id) { + // We don't want to dispatch on 3 different kind of closures here, so take + // advantage of the fact that the `parent_args` is the same length as the + // `typeck_root_args`. DefKind::Closure => { - // FIXME(async_closures): It's kind of icky to access HIR here. - match tcx.hir_node_by_def_id(def_id).expect_closure().kind { - hir::ClosureKind::Closure => args.as_closure().parent_args(), - hir::ClosureKind::Coroutine(_) => args.as_coroutine().parent_args(), - hir::ClosureKind::CoroutineClosure(_) => { - args.as_coroutine_closure().parent_args() - } - } + // FIXME(async_closures): It may be useful to add a debug assert here + // to actually call `type_of` and check the `parent_args` are the same + // length as the `typeck_root_args`. + &args[..typeck_root_args.len()] } DefKind::InlineConst => args.as_inline_const().parent_args(), other => bug!("unexpected item {:?}", other), diff --git a/compiler/rustc_borrowck/src/universal_regions.rs b/compiler/rustc_borrowck/src/universal_regions.rs index ae8a135f090..9c266d123b3 100644 --- a/compiler/rustc_borrowck/src/universal_regions.rs +++ b/compiler/rustc_borrowck/src/universal_regions.rs @@ -97,6 +97,10 @@ pub enum DefiningTy<'tcx> { /// `ClosureArgs::coroutine_return_ty`. Coroutine(DefId, GenericArgsRef<'tcx>), + /// The MIR is a special kind of closure that returns coroutines. + /// TODO: describe how to make the sig... + CoroutineClosure(DefId, GenericArgsRef<'tcx>), + /// The MIR is a fn item with the given `DefId` and args. The signature /// of the function can be bound then with the `fn_sig` query. FnDef(DefId, GenericArgsRef<'tcx>), @@ -119,6 +123,7 @@ impl<'tcx> DefiningTy<'tcx> { pub fn upvar_tys(self) -> &'tcx ty::List> { match self { DefiningTy::Closure(_, args) => args.as_closure().upvar_tys(), + DefiningTy::CoroutineClosure(_, args) => args.as_coroutine_closure().upvar_tys(), DefiningTy::Coroutine(_, args) => args.as_coroutine().upvar_tys(), DefiningTy::FnDef(..) | DefiningTy::Const(..) | DefiningTy::InlineConst(..) => { ty::List::empty() @@ -131,7 +136,9 @@ impl<'tcx> DefiningTy<'tcx> { /// user's code. pub fn implicit_inputs(self) -> usize { match self { - DefiningTy::Closure(..) | DefiningTy::Coroutine(..) => 1, + DefiningTy::Closure(..) + | DefiningTy::CoroutineClosure(..) + | DefiningTy::Coroutine(..) => 1, DefiningTy::FnDef(..) | DefiningTy::Const(..) | DefiningTy::InlineConst(..) => 0, } } @@ -147,6 +154,7 @@ impl<'tcx> DefiningTy<'tcx> { pub fn def_id(&self) -> DefId { match *self { DefiningTy::Closure(def_id, ..) + | DefiningTy::CoroutineClosure(def_id, ..) | DefiningTy::Coroutine(def_id, ..) | DefiningTy::FnDef(def_id, ..) | DefiningTy::Const(def_id, ..) @@ -355,6 +363,9 @@ impl<'tcx> UniversalRegions<'tcx> { err.note(format!("late-bound region is {:?}", self.to_region_vid(r))); }); } + DefiningTy::CoroutineClosure(..) => { + todo!() + } DefiningTy::Coroutine(def_id, args) => { let v = with_no_trimmed_paths!( args[tcx.generics_of(def_id).parent_count..] @@ -568,6 +579,9 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> { match *defining_ty.kind() { ty::Closure(def_id, args) => DefiningTy::Closure(def_id, args), ty::Coroutine(def_id, args) => DefiningTy::Coroutine(def_id, args), + ty::CoroutineClosure(def_id, args) => { + DefiningTy::CoroutineClosure(def_id, args) + } ty::FnDef(def_id, args) => DefiningTy::FnDef(def_id, args), _ => span_bug!( tcx.def_span(self.mir_def), @@ -623,6 +637,7 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> { let identity_args = GenericArgs::identity_for_item(tcx, typeck_root_def_id); let fr_args = match defining_ty { DefiningTy::Closure(_, args) + | DefiningTy::CoroutineClosure(_, args) | DefiningTy::Coroutine(_, args) | DefiningTy::InlineConst(_, args) => { // In the case of closures, we rely on the fact that @@ -702,6 +717,47 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> { ty::Binder::dummy(inputs_and_output) } + DefiningTy::CoroutineClosure(def_id, args) => { + assert_eq!(self.mir_def.to_def_id(), def_id); + let closure_sig = args.as_coroutine_closure().coroutine_closure_sig(); + let bound_vars = tcx.mk_bound_variable_kinds_from_iter( + closure_sig + .bound_vars() + .iter() + .chain(iter::once(ty::BoundVariableKind::Region(ty::BrEnv))), + ); + let br = ty::BoundRegion { + var: ty::BoundVar::from_usize(bound_vars.len() - 1), + kind: ty::BrEnv, + }; + let env_region = ty::Region::new_bound(tcx, ty::INNERMOST, br); + let closure_kind = args.as_coroutine_closure().kind(); + + let closure_ty = tcx.closure_env_ty( + Ty::new_coroutine_closure(tcx, def_id, args), + closure_kind, + env_region, + ); + + let inputs = closure_sig.skip_binder().tupled_inputs_ty.tuple_fields(); + let output = closure_sig.skip_binder().to_coroutine_given_kind_and_upvars( + tcx, + args.as_coroutine_closure().parent_args(), + tcx.coroutine_for_closure(def_id), + closure_kind, + env_region, + args.as_coroutine_closure().tupled_upvars_ty(), + args.as_coroutine_closure().coroutine_captures_by_ref_ty(), + ); + + ty::Binder::bind_with_vars( + tcx.mk_type_list_from_iter( + iter::once(closure_ty).chain(inputs).chain(iter::once(output)), + ), + bound_vars, + ) + } + DefiningTy::FnDef(def_id, _) => { let sig = tcx.fn_sig(def_id).instantiate_identity(); let sig = indices.fold_to_region_vids(tcx, sig); diff --git a/compiler/rustc_const_eval/src/transform/validate.rs b/compiler/rustc_const_eval/src/transform/validate.rs index d3230a2455d..c4542aaa7b2 100644 --- a/compiler/rustc_const_eval/src/transform/validate.rs +++ b/compiler/rustc_const_eval/src/transform/validate.rs @@ -58,6 +58,7 @@ impl<'tcx> MirPass<'tcx> for Validator { let body_abi = match body_ty.kind() { ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), ty::Closure(..) => Abi::RustCall, + ty::CoroutineClosure(..) => Abi::RustCall, ty::Coroutine(..) => Abi::Rust, _ => { span_bug!(body.span, "unexpected body ty: {:?} phase {:?}", body_ty, mir_phase) diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index 8e327c029df..9c9e72d6e65 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -209,6 +209,7 @@ language_item_table! { AsyncFn, sym::async_fn, async_fn_trait, Target::Trait, GenericRequirement::Exact(1); AsyncFnMut, sym::async_fn_mut, async_fn_mut_trait, Target::Trait, GenericRequirement::Exact(1); AsyncFnOnce, sym::async_fn_once, async_fn_once_trait, Target::Trait, GenericRequirement::Exact(1); + AsyncFnKindHelper, sym::async_fn_kind_helper,async_fn_kind_helper, Target::Trait, GenericRequirement::Exact(1); FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None; diff --git a/compiler/rustc_hir_analysis/src/collect.rs b/compiler/rustc_hir_analysis/src/collect.rs index 8d862d5eb71..8fdd3cfe3cf 100644 --- a/compiler/rustc_hir_analysis/src/collect.rs +++ b/compiler/rustc_hir_analysis/src/collect.rs @@ -81,6 +81,7 @@ pub fn provide(providers: &mut Providers) { impl_trait_ref, impl_polarity, coroutine_kind, + coroutine_for_closure, collect_mod_item_types, is_type_alias_impl_trait, ..*providers @@ -1531,6 +1532,36 @@ fn coroutine_kind(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Option, def_id: LocalDefId) -> DefId { + let Node::Expr(&hir::Expr { + kind: + hir::ExprKind::Closure(&rustc_hir::Closure { + kind: hir::ClosureKind::CoroutineClosure(_), + body, + .. + }), + .. + }) = tcx.hir_node_by_def_id(def_id) + else { + bug!() + }; + + let &hir::Expr { + kind: + hir::ExprKind::Closure(&rustc_hir::Closure { + def_id, + kind: hir::ClosureKind::Coroutine(_), + .. + }), + .. + } = tcx.hir().body(body).value + else { + bug!() + }; + + def_id.to_def_id() +} + fn is_type_alias_impl_trait<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> bool { match tcx.hir_node_by_def_id(def_id) { Node::Item(hir::Item { kind: hir::ItemKind::OpaqueTy(opaque), .. }) => { diff --git a/compiler/rustc_hir_typeck/src/callee.rs b/compiler/rustc_hir_typeck/src/callee.rs index 80650dcb53c..1858b2770cd 100644 --- a/compiler/rustc_hir_typeck/src/callee.rs +++ b/compiler/rustc_hir_typeck/src/callee.rs @@ -141,32 +141,67 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { return Some(CallStep::Builtin(adjusted_ty)); } - ty::Closure(def_id, args) => { + // Check whether this is a call to a closure where we + // haven't yet decided on whether the closure is fn vs + // fnmut vs fnonce. If so, we have to defer further processing. + ty::Closure(def_id, args) if self.closure_kind(adjusted_ty).is_none() => { let def_id = def_id.expect_local(); + let closure_sig = args.as_closure().sig(); + let closure_sig = self.instantiate_binder_with_fresh_vars( + call_expr.span, + infer::FnCall, + closure_sig, + ); + let adjustments = self.adjust_steps(autoderef); + self.record_deferred_call_resolution( + def_id, + DeferredCallResolution { + call_expr, + callee_expr, + closure_ty: adjusted_ty, + adjustments, + fn_sig: closure_sig, + }, + ); + return Some(CallStep::DeferredClosure(def_id, closure_sig)); + } - // Check whether this is a call to a closure where we - // haven't yet decided on whether the closure is fn vs - // fnmut vs fnonce. If so, we have to defer further processing. - if self.closure_kind(adjusted_ty).is_none() { - let closure_sig = args.as_closure().sig(); - let closure_sig = self.instantiate_binder_with_fresh_vars( - call_expr.span, - infer::FnCall, - closure_sig, - ); - let adjustments = self.adjust_steps(autoderef); - self.record_deferred_call_resolution( - def_id, - DeferredCallResolution { - call_expr, - callee_expr, - closure_ty: adjusted_ty, - adjustments, - fn_sig: closure_sig, - }, - ); - return Some(CallStep::DeferredClosure(def_id, closure_sig)); - } + ty::CoroutineClosure(def_id, args) if self.closure_kind(adjusted_ty).is_none() => { + let def_id = def_id.expect_local(); + let closure_args = args.as_coroutine_closure(); + let coroutine_closure_sig = self.instantiate_binder_with_fresh_vars( + call_expr.span, + infer::FnCall, + closure_args.coroutine_closure_sig(), + ); + let tupled_upvars_ty = self.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::TypeInference, + span: callee_expr.span, + }); + let call_sig = self.tcx.mk_fn_sig( + [coroutine_closure_sig.tupled_inputs_ty], + coroutine_closure_sig.to_coroutine( + self.tcx, + closure_args.parent_args(), + self.tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ), + coroutine_closure_sig.c_variadic, + coroutine_closure_sig.unsafety, + coroutine_closure_sig.abi, + ); + let adjustments = self.adjust_steps(autoderef); + self.record_deferred_call_resolution( + def_id, + DeferredCallResolution { + call_expr, + callee_expr, + closure_ty: adjusted_ty, + adjustments, + fn_sig: call_sig, + }, + ); + return Some(CallStep::DeferredClosure(def_id, call_sig)); } // Hack: we know that there are traits implementing Fn for &F @@ -935,7 +970,7 @@ impl<'a, 'tcx> DeferredCallResolution<'tcx> { span_bug!( self.call_expr.span, "Expected to find a suitable `Fn`/`FnMut`/`FnOnce` implementation for `{}`", - self.adjusted_ty + self.closure_ty ) } } diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index 18dff5188af..1d024efdd49 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -63,7 +63,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { None => (None, None), }; - let ClosureSignatures { bound_sig, liberated_sig } = + let ClosureSignatures { bound_sig, mut liberated_sig } = self.sig_of_closure(expr_def_id, closure.fn_decl, closure.kind, expected_sig); debug!(?bound_sig, ?liberated_sig); @@ -125,7 +125,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) | hir::CoroutineKind::Coroutine(_) => { let yield_ty = self.next_ty_var(TypeVariableOrigin { - kind: TypeVariableOriginKind::TypeInference, + kind: TypeVariableOriginKind::ClosureSynthetic, span: expr_span, }); self.require_type_is_sized(yield_ty, expr_span, traits::SizedYieldType); @@ -137,7 +137,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // not a problem. hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => { let yield_ty = self.next_ty_var(TypeVariableOrigin { - kind: TypeVariableOriginKind::TypeInference, + kind: TypeVariableOriginKind::ClosureSynthetic, span: expr_span, }); self.require_type_is_sized(yield_ty, expr_span, traits::SizedYieldType); @@ -166,8 +166,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let resume_ty = liberated_sig.inputs().get(0).copied().unwrap_or(tcx.types.unit); let interior = self.next_ty_var(TypeVariableOrigin { - kind: TypeVariableOriginKind::MiscVariable, - span: body.value.span, + kind: TypeVariableOriginKind::ClosureSynthetic, + span: expr_span, }); self.deferred_coroutine_interiors.borrow_mut().push(( expr_def_id, @@ -192,7 +192,82 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { Some(CoroutineTypes { resume_ty, yield_ty }), ) } - hir::ClosureKind::CoroutineClosure(_) => todo!(), + hir::ClosureKind::CoroutineClosure(kind) => { + let (bound_return_ty, bound_yield_ty) = match kind { + hir::CoroutineDesugaring::Async => { + (bound_sig.skip_binder().output(), tcx.types.unit) + } + hir::CoroutineDesugaring::Gen | hir::CoroutineDesugaring::AsyncGen => { + todo!("`gen` and `async gen` closures not supported yet") + } + }; + let resume_ty = self.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::ClosureSynthetic, + span: expr_span, + }); + let interior = self.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::ClosureSynthetic, + span: expr_span, + }); + let closure_kind_ty = self.next_ty_var(TypeVariableOrigin { + // FIXME(eddyb) distinguish closure kind inference variables from the rest. + kind: TypeVariableOriginKind::ClosureSynthetic, + span: expr_span, + }); + let coroutine_captures_by_ref_ty = self.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::ClosureSynthetic, + span: expr_span, + }); + let closure_args = ty::CoroutineClosureArgs::new( + tcx, + ty::CoroutineClosureArgsParts { + parent_args, + closure_kind_ty, + signature_parts_ty: Ty::new_fn_ptr( + tcx, + bound_sig.map_bound(|sig| { + tcx.mk_fn_sig( + [ + resume_ty, + Ty::new_tup_from_iter(tcx, sig.inputs().iter().copied()), + ], + Ty::new_tup(tcx, &[bound_yield_ty, bound_return_ty]), + sig.c_variadic, + sig.unsafety, + sig.abi, + ) + }), + ), + tupled_upvars_ty, + coroutine_captures_by_ref_ty, + coroutine_witness_ty: interior, + }, + ); + + let coroutine_upvars_ty = self.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::ClosureSynthetic, + span: expr_span, + }); + liberated_sig = tcx.mk_fn_sig( + liberated_sig.inputs().iter().copied(), + tcx.liberate_late_bound_regions( + expr_def_id.to_def_id(), + closure_args.coroutine_closure_sig().map_bound(|sig| { + sig.to_coroutine( + tcx, + parent_args, + tcx.coroutine_for_closure(expr_def_id), + coroutine_upvars_ty, + ) + }), + ), + liberated_sig.c_variadic, + liberated_sig.unsafety, + liberated_sig.abi, + ); + + (Ty::new_coroutine_closure(tcx, expr_def_id.to_def_id(), closure_args.args), None) + } }; check_fn( diff --git a/compiler/rustc_hir_typeck/src/upvar.rs b/compiler/rustc_hir_typeck/src/upvar.rs index 55ffac8d92f..b087d6d9e57 100644 --- a/compiler/rustc_hir_typeck/src/upvar.rs +++ b/compiler/rustc_hir_typeck/src/upvar.rs @@ -174,6 +174,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { ty::Closure(def_id, args) => { (def_id, UpvarArgs::Closure(args), self.closure_kind(ty).is_none()) } + ty::CoroutineClosure(def_id, args) => { + (def_id, UpvarArgs::CoroutineClosure(args), self.closure_kind(ty).is_none()) + } ty::Coroutine(def_id, args) => (def_id, UpvarArgs::Coroutine(args), false), ty::Error(_) => { // #51714: skip analysis when we have already encountered type errors @@ -333,6 +336,65 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { } } + if let UpvarArgs::CoroutineClosure(args) = args { + let closure_env_region: ty::Region<'_> = ty::Region::new_bound( + self.tcx, + ty::INNERMOST, + ty::BoundRegion { + var: ty::BoundVar::from_usize(0), + kind: ty::BoundRegionKind::BrEnv, + }, + ); + let tupled_upvars_ty_for_borrow = Ty::new_tup_from_iter( + self.tcx, + self.typeck_results + .borrow() + .closure_min_captures_flattened( + self.tcx.coroutine_for_closure(closure_def_id).expect_local(), + ) + // Skip the captures that are just moving the closure's args + // into the coroutine. These are always by move. + .skip( + args.as_coroutine_closure() + .coroutine_closure_sig() + .skip_binder() + .tupled_inputs_ty + .tuple_fields() + .len(), + ) + .map(|captured_place| { + let upvar_ty = captured_place.place.ty(); + let capture = captured_place.info.capture_kind; + apply_capture_kind_on_capture_ty( + self.tcx, + upvar_ty, + capture, + Some(closure_env_region), + ) + }), + ); + let coroutine_captures_by_ref_ty = Ty::new_fn_ptr( + self.tcx, + ty::Binder::bind_with_vars( + self.tcx.mk_fn_sig( + [], + tupled_upvars_ty_for_borrow, + false, + hir::Unsafety::Normal, + rustc_target::spec::abi::Abi::Rust, + ), + self.tcx.mk_bound_variable_kinds(&[ty::BoundVariableKind::Region( + ty::BoundRegionKind::BrEnv, + )]), + ), + ); + self.demand_eqtype( + span, + args.as_coroutine_closure().coroutine_captures_by_ref_ty(), + coroutine_captures_by_ref_ty, + ); + } + self.log_closure_min_capture_info(closure_def_id, span); // Now that we've analyzed the closure, we know how each @@ -602,7 +664,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { }; // Go through each entry in the current list of min_captures - // - if ancestor is found, update it's capture kind to account for current place's + // - if ancestor is found, update its capture kind to account for current place's // capture information. // // - if descendant is found, remove it from the list, and update the current place's diff --git a/compiler/rustc_middle/src/middle/lang_items.rs b/compiler/rustc_middle/src/middle/lang_items.rs index f92c72c8a58..a4e193ba2c9 100644 --- a/compiler/rustc_middle/src/middle/lang_items.rs +++ b/compiler/rustc_middle/src/middle/lang_items.rs @@ -23,7 +23,7 @@ impl<'tcx> TyCtxt<'tcx> { }) } - /// Given a [`DefId`] of a [`Fn`], [`FnMut`] or [`FnOnce`] traits, + /// Given a [`DefId`] of one of the [`Fn`], [`FnMut`] or [`FnOnce`] traits, /// returns a corresponding [`ty::ClosureKind`]. /// For any other [`DefId`] return `None`. pub fn fn_trait_kind_from_def_id(self, id: DefId) -> Option { @@ -36,6 +36,19 @@ impl<'tcx> TyCtxt<'tcx> { } } + /// Given a [`DefId`] of one of the `AsyncFn`, `AsyncFnMut` or `AsyncFnOnce` traits, + /// returns a corresponding [`ty::ClosureKind`]. + /// For any other [`DefId`] return `None`. + pub fn async_fn_trait_kind_from_def_id(self, id: DefId) -> Option { + let items = self.lang_items(); + match Some(id) { + x if x == items.async_fn_trait() => Some(ty::ClosureKind::Fn), + x if x == items.async_fn_mut_trait() => Some(ty::ClosureKind::FnMut), + x if x == items.async_fn_once_trait() => Some(ty::ClosureKind::FnOnce), + _ => None, + } + } + /// Given a [`ty::ClosureKind`], get the [`DefId`] of its corresponding `Fn`-family /// trait, if it is defined. pub fn fn_trait_kind_to_def_id(self, kind: ty::ClosureKind) -> Option { diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 2438f826441..f9ab32b16f5 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -755,6 +755,11 @@ rustc_queries! { separate_provide_extern } + query coroutine_for_closure(def_id: DefId) -> DefId { + desc { |_tcx| "TODO" } + separate_provide_extern + } + /// Gets a map with the variance of every item; use `variances_of` instead. query crate_variances(_: ()) -> &'tcx ty::CrateVariancesMap<'tcx> { arena_cache diff --git a/compiler/rustc_middle/src/traits/select.rs b/compiler/rustc_middle/src/traits/select.rs index 64f4af08e12..4e11575cf98 100644 --- a/compiler/rustc_middle/src/traits/select.rs +++ b/compiler/rustc_middle/src/traits/select.rs @@ -134,6 +134,13 @@ pub enum SelectionCandidate<'tcx> { is_const: bool, }, + /// Implementation of an `AsyncFn`-family trait by one of the anonymous types + /// generated for an `async ||` expression. + AsyncClosureCandidate, + + // TODO: + AsyncFnKindHelperCandidate, + /// Implementation of a `Coroutine` trait by one of the anonymous types /// generated for a coroutine. CoroutineCandidate, diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 8e51db82f95..4348d01eff5 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -105,9 +105,10 @@ pub use self::region::{ pub use self::rvalue_scopes::RvalueScopes; pub use self::sty::{ AliasTy, Article, Binder, BoundTy, BoundTyKind, BoundVariableKind, CanonicalPolyFnSig, - ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, FnSig, GenSig, - InlineConstArgs, InlineConstArgsParts, ParamConst, ParamTy, PolyFnSig, TyKind, TypeAndMut, - UpvarArgs, VarianceDiagInfo, + ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, CoroutineClosureArgs, + CoroutineClosureArgsParts, CoroutineClosureSignature, FnSig, GenSig, InlineConstArgs, + InlineConstArgsParts, ParamConst, ParamTy, PolyFnSig, TyKind, TypeAndMut, UpvarArgs, + VarianceDiagInfo, }; pub use self::trait_def::TraitDef; pub use self::typeck_results::{ diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index c047f6a0521..e2a2e24f06d 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -36,6 +36,7 @@ use rustc_type_ir::TyKind as IrTyKind; use rustc_type_ir::TyKind::*; use rustc_type_ir::TypeAndMut as IrTypeAndMut; +use super::fold::FnMutDelegate; use super::GenericParamDefKind; // Re-export and re-parameterize some `I = TyCtxt<'tcx>` types here @@ -351,6 +352,27 @@ impl<'tcx> CoroutineClosureArgs<'tcx> { self.split().signature_parts_ty } + pub fn coroutine_closure_sig(self) -> ty::Binder<'tcx, CoroutineClosureSignature<'tcx>> { + let interior = self.coroutine_witness_ty(); + let ty::FnPtr(sig) = self.signature_parts_ty().kind() else { bug!() }; + sig.map_bound(|sig| { + let [resume_ty, tupled_inputs_ty] = *sig.inputs() else { + bug!(); + }; + let [yield_ty, return_ty] = **sig.output().tuple_fields() else { bug!() }; + CoroutineClosureSignature { + interior, + tupled_inputs_ty, + resume_ty, + yield_ty, + return_ty, + c_variadic: sig.c_variadic, + unsafety: sig.unsafety, + abi: sig.abi, + } + }) + } + pub fn coroutine_captures_by_ref_ty(self) -> Ty<'tcx> { self.split().coroutine_captures_by_ref_ty } @@ -360,6 +382,103 @@ impl<'tcx> CoroutineClosureArgs<'tcx> { } } +#[derive(Copy, Clone, PartialEq, Eq, Debug, TypeFoldable, TypeVisitable)] +pub struct CoroutineClosureSignature<'tcx> { + pub interior: Ty<'tcx>, + pub tupled_inputs_ty: Ty<'tcx>, + pub resume_ty: Ty<'tcx>, + pub yield_ty: Ty<'tcx>, + pub return_ty: Ty<'tcx>, + pub c_variadic: bool, + pub unsafety: hir::Unsafety, + pub abi: abi::Abi, +} + +impl<'tcx> CoroutineClosureSignature<'tcx> { + pub fn to_coroutine( + self, + tcx: TyCtxt<'tcx>, + parent_args: &'tcx [GenericArg<'tcx>], + coroutine_def_id: DefId, + tupled_upvars_ty: Ty<'tcx>, + ) -> Ty<'tcx> { + let coroutine_args = ty::CoroutineArgs::new( + tcx, + ty::CoroutineArgsParts { + parent_args, + resume_ty: self.resume_ty, + yield_ty: self.yield_ty, + return_ty: self.return_ty, + witness: self.interior, + tupled_upvars_ty, + }, + ); + + Ty::new_coroutine(tcx, coroutine_def_id, coroutine_args.args) + } + + pub fn to_coroutine_given_kind_and_upvars( + self, + tcx: TyCtxt<'tcx>, + parent_args: &'tcx [GenericArg<'tcx>], + coroutine_def_id: DefId, + closure_kind: ty::ClosureKind, + env_region: ty::Region<'tcx>, + closure_tupled_upvars_ty: Ty<'tcx>, + coroutine_captures_by_ref_ty: Ty<'tcx>, + ) -> Ty<'tcx> { + let tupled_upvars_ty = Self::tupled_upvars_by_closure_kind( + tcx, + closure_kind, + self.tupled_inputs_ty, + closure_tupled_upvars_ty, + coroutine_captures_by_ref_ty, + env_region, + ); + + self.to_coroutine(tcx, parent_args, coroutine_def_id, tupled_upvars_ty) + } + + /// Given a closure kind, compute the tupled upvars that the given coroutine would return. + pub fn tupled_upvars_by_closure_kind( + tcx: TyCtxt<'tcx>, + kind: ty::ClosureKind, + tupled_inputs_ty: Ty<'tcx>, + closure_tupled_upvars_ty: Ty<'tcx>, + coroutine_captures_by_ref_ty: Ty<'tcx>, + env_region: ty::Region<'tcx>, + ) -> Ty<'tcx> { + match kind { + ty::ClosureKind::Fn | ty::ClosureKind::FnMut => { + let ty::FnPtr(sig) = *coroutine_captures_by_ref_ty.kind() else { + bug!(); + }; + let coroutine_captures_by_ref_ty = tcx.replace_escaping_bound_vars_uncached( + sig.output().skip_binder(), + FnMutDelegate { + consts: &mut |c, t| ty::Const::new_bound(tcx, ty::INNERMOST, c, t), + types: &mut |t| Ty::new_bound(tcx, ty::INNERMOST, t), + regions: &mut |_| env_region, + }, + ); + Ty::new_tup_from_iter( + tcx, + tupled_inputs_ty + .tuple_fields() + .iter() + .chain(coroutine_captures_by_ref_ty.tuple_fields()), + ) + } + ty::ClosureKind::FnOnce => Ty::new_tup_from_iter( + tcx, + tupled_inputs_ty + .tuple_fields() + .iter() + .chain(closure_tupled_upvars_ty.tuple_fields()), + ), + } + } +} /// Similar to `ClosureArgs`; see the above documentation for more. #[derive(Copy, Clone, PartialEq, Eq, Debug, TypeFoldable, TypeVisitable)] pub struct CoroutineArgs<'tcx> { @@ -1495,7 +1614,7 @@ impl<'tcx> Ty<'tcx> { ) -> Ty<'tcx> { debug_assert_eq!( closure_args.len(), - tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 3, + tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 5, "closure constructed with incorrect substitutions" ); Ty::new(tcx, CoroutineClosure(def_id, closure_args)) @@ -1835,6 +1954,11 @@ impl<'tcx> Ty<'tcx> { matches!(self.kind(), Coroutine(..)) } + #[inline] + pub fn is_coroutine_closure(self) -> bool { + matches!(self.kind(), CoroutineClosure(..)) + } + #[inline] pub fn is_integral(self) -> bool { matches!(self.kind(), Infer(IntVar(_)) | Int(_) | Uint(_)) @@ -2144,7 +2268,7 @@ impl<'tcx> Ty<'tcx> { // "Bound" types appear in canonical queries when the // closure type is not yet known - Bound(..) | Infer(_) => None, + Bound(..) | Param(_) | Infer(_) => None, Error(_) => Some(ty::ClosureKind::Fn), diff --git a/compiler/rustc_mir_build/src/build/mod.rs b/compiler/rustc_mir_build/src/build/mod.rs index 9bcfe9fbc33..9a2f2ceced1 100644 --- a/compiler/rustc_mir_build/src/build/mod.rs +++ b/compiler/rustc_mir_build/src/build/mod.rs @@ -822,6 +822,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let upvar_args = match closure_ty.kind() { ty::Closure(_, args) => ty::UpvarArgs::Closure(args), ty::Coroutine(_, args) => ty::UpvarArgs::Coroutine(args), + ty::CoroutineClosure(_, args) => ty::UpvarArgs::CoroutineClosure(args), _ => return, }; diff --git a/compiler/rustc_mir_build/src/thir/cx/expr.rs b/compiler/rustc_mir_build/src/thir/cx/expr.rs index 22094c112fc..35adcc33480 100644 --- a/compiler/rustc_mir_build/src/thir/cx/expr.rs +++ b/compiler/rustc_mir_build/src/thir/cx/expr.rs @@ -556,6 +556,9 @@ impl<'tcx> Cx<'tcx> { ty::Coroutine(def_id, args) => { (def_id, UpvarArgs::Coroutine(args), Some(tcx.coroutine_movability(def_id))) } + ty::CoroutineClosure(def_id, args) => { + (def_id, UpvarArgs::CoroutineClosure(args), None) + } _ => { span_bug!(expr.span, "closure expr w/o closure type: {:?}", closure_ty); } diff --git a/compiler/rustc_mir_build/src/thir/cx/mod.rs b/compiler/rustc_mir_build/src/thir/cx/mod.rs index 5d0bb3954cc..f4f591d15b9 100644 --- a/compiler/rustc_mir_build/src/thir/cx/mod.rs +++ b/compiler/rustc_mir_build/src/thir/cx/mod.rs @@ -127,10 +127,24 @@ impl<'tcx> Cx<'tcx> { ty::Coroutine(..) => { Param { ty: closure_ty, pat: None, ty_span: None, self_kind: None, hir_id: None } } - ty::Closure(_, closure_args) => { + ty::Closure(_, args) => { let closure_env_ty = self.tcx.closure_env_ty( closure_ty, - closure_args.as_closure().kind(), + args.as_closure().kind(), + self.tcx.lifetimes.re_erased, + ); + Param { + ty: closure_env_ty, + pat: None, + ty_span: None, + self_kind: None, + hir_id: None, + } + } + ty::CoroutineClosure(_, args) => { + let closure_env_ty = self.tcx.closure_env_ty( + closure_ty, + args.as_coroutine_closure().kind(), self.tcx.lifetimes.re_erased, ); Param { diff --git a/compiler/rustc_mir_dataflow/src/value_analysis.rs b/compiler/rustc_mir_dataflow/src/value_analysis.rs index 2802f5491be..a85e53ff241 100644 --- a/compiler/rustc_mir_dataflow/src/value_analysis.rs +++ b/compiler/rustc_mir_dataflow/src/value_analysis.rs @@ -1149,6 +1149,9 @@ pub fn iter_fields<'tcx>( ty::Closure(_, args) => { iter_fields(args.as_closure().tupled_upvars_ty(), tcx, param_env, f); } + ty::CoroutineClosure(_, args) => { + iter_fields(args.as_coroutine_closure().tupled_upvars_ty(), tcx, param_env, f); + } _ => (), } } diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs index aa22b8c7c58..f2448ee3d44 100644 --- a/compiler/rustc_mir_transform/src/const_prop_lint.rs +++ b/compiler/rustc_mir_transform/src/const_prop_lint.rs @@ -607,7 +607,8 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { AggregateKind::Array(_) | AggregateKind::Tuple | AggregateKind::Closure(_, _) - | AggregateKind::Coroutine(_, _) => VariantIdx::new(0), + | AggregateKind::Coroutine(_, _) + | AggregateKind::CoroutineClosure(_, _) => VariantIdx::new(0), }, }, diff --git a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs index 5b4d58836b4..01fae7c0bec 100644 --- a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs +++ b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs @@ -156,7 +156,9 @@ fn bcb_to_initial_coverage_spans<'a, 'tcx>( fn is_closure_or_coroutine(statement: &Statement<'_>) -> bool { match statement.kind { StatementKind::Assign(box (_, Rvalue::Aggregate(box ref agg_kind, _))) => match agg_kind { - AggregateKind::Closure(_, _) | AggregateKind::Coroutine(_, _) => true, + AggregateKind::Closure(_, _) + | AggregateKind::Coroutine(_, _) + | AggregateKind::CoroutineClosure(..) => true, _ => false, }, _ => false, diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index dbfc89c2d49..587eb1c7cc5 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -167,6 +167,9 @@ symbols! { Break, C, CStr, + CallFuture, + CallMutFuture, + CallOnceFuture, Capture, Center, Cleanup, @@ -420,6 +423,7 @@ symbols! { async_closure, async_fn, async_fn_in_trait, + async_fn_kind_helper, async_fn_mut, async_fn_once, async_fn_track_caller, diff --git a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs index e1c68039e79..8451fbcc434 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs @@ -182,6 +182,20 @@ pub(super) trait GoalKind<'tcx>: kind: ty::ClosureKind, ) -> QueryResult<'tcx>; + /// An async closure is known to implement the `AsyncFn` family of traits + /// where `A` is given by the signature of the type. + fn consider_builtin_async_fn_trait_candidates( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + kind: ty::ClosureKind, + ) -> QueryResult<'tcx>; + + /// TODO: + fn consider_builtin_async_fn_kind_helper_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx>; + /// `Tuple` is implemented if the `Self` type is a tuple. fn consider_builtin_tuple_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, @@ -461,6 +475,10 @@ impl<'tcx> EvalCtxt<'_, 'tcx> { G::consider_builtin_fn_ptr_trait_candidate(self, goal) } else if let Some(kind) = self.tcx().fn_trait_kind_from_def_id(trait_def_id) { G::consider_builtin_fn_trait_candidates(self, goal, kind) + } else if let Some(kind) = self.tcx().async_fn_trait_kind_from_def_id(trait_def_id) { + G::consider_builtin_async_fn_trait_candidates(self, goal, kind) + } else if lang_items.async_fn_kind_helper() == Some(trait_def_id) { + G::consider_builtin_async_fn_kind_helper_candidate(self, goal) } else if lang_items.tuple_trait() == Some(trait_def_id) { G::consider_builtin_tuple_candidate(self, goal) } else if lang_items.pointee_trait() == Some(trait_def_id) { diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs index 3c571e1d96f..c35134c78eb 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs @@ -1,11 +1,12 @@ //! Code which is used by built-in goals that match "structurally", such a auto //! traits, `Copy`/`Clone`. use rustc_data_structures::fx::FxHashMap; +use rustc_hir::LangItem; use rustc_hir::{def_id::DefId, Movability, Mutability}; use rustc_infer::traits::query::NoSolution; use rustc_middle::traits::solve::Goal; use rustc_middle::ty::{ - self, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt, + self, ToPredicate, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt, }; use crate::solve::EvalCtxt; @@ -306,6 +307,114 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_callable<'tcx>( } } +// Returns a binder of the tupled inputs types, output type, and coroutine type +// from a builtin async closure type. +pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tcx>( + tcx: TyCtxt<'tcx>, + self_ty: Ty<'tcx>, + goal_kind: ty::ClosureKind, + env_region: ty::Region<'tcx>, +) -> Result<(ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Vec>), NoSolution> +{ + match *self_ty.kind() { + ty::CoroutineClosure(def_id, args) => { + let args = args.as_coroutine_closure(); + let kind_ty = args.kind_ty(); + + if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { + if !closure_kind.extends(goal_kind) { + return Err(NoSolution); + } + Ok(( + args.coroutine_closure_sig().map_bound(|sig| { + let coroutine_ty = sig.to_coroutine_given_kind_and_upvars( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + goal_kind, + env_region, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + ); + (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty) + }), + vec![], + )) + } else { + let helper_trait_def_id = tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); + // FIXME(async_closures): Make this into a lang item. + let upvars_projection_def_id = tcx + .associated_items(helper_trait_def_id) + .in_definition_order() + .next() + .unwrap() + .def_id; + Ok(( + args.coroutine_closure_sig().map_bound(|sig| { + let tupled_upvars_ty = Ty::new_projection( + tcx, + upvars_projection_def_id, + [ + ty::GenericArg::from(kind_ty), + Ty::from_closure_kind(tcx, goal_kind).into(), + env_region.into(), + sig.tupled_inputs_ty.into(), + args.tupled_upvars_ty().into(), + args.coroutine_captures_by_ref_ty().into(), + ], + ); + let coroutine_ty = sig.to_coroutine( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ); + (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty) + }), + vec![ + ty::TraitRef::new( + tcx, + helper_trait_def_id, + [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], + ) + .to_predicate(tcx), + ], + )) + } + } + + ty::FnDef(..) | ty::FnPtr(..) | ty::Closure(..) => Err(NoSolution), + + ty::Bool + | ty::Char + | ty::Int(_) + | ty::Uint(_) + | ty::Float(_) + | ty::Adt(_, _) + | ty::Foreign(_) + | ty::Str + | ty::Array(_, _) + | ty::Slice(_) + | ty::RawPtr(_) + | ty::Ref(_, _, _) + | ty::Dynamic(_, _, _) + | ty::Coroutine(_, _) + | ty::CoroutineWitness(..) + | ty::Never + | ty::Tuple(_) + | ty::Alias(_, _) + | ty::Param(_) + | ty::Placeholder(..) + | ty::Infer(ty::IntVar(_) | ty::FloatVar(_)) + | ty::Error(_) => Err(NoSolution), + + ty::Bound(..) + | ty::Infer(ty::TyVar(_) | ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_)) => { + bug!("unexpected type `{self_ty}`") + } + } +} + /// Assemble a list of predicates that would be present on a theoretical /// user impl for an object type. These predicates must be checked any time /// we assemble a built-in object candidate for an object type, since they diff --git a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs index e1c6f67f05e..47ba549022d 100644 --- a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs +++ b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs @@ -366,6 +366,119 @@ impl<'tcx> assembly::GoalKind<'tcx> for NormalizesTo<'tcx> { Self::consider_implied_clause(ecx, goal, pred, [goal.with(tcx, output_is_sized_pred)]) } + fn consider_builtin_async_fn_trait_candidates( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + goal_kind: ty::ClosureKind, + ) -> QueryResult<'tcx> { + let tcx = ecx.tcx(); + + let env_region = match goal_kind { + ty::ClosureKind::Fn | ty::ClosureKind::FnMut => goal.predicate.alias.args.region_at(2), + // Doesn't matter what this region is + ty::ClosureKind::FnOnce => tcx.lifetimes.re_static, + }; + let (tupled_inputs_and_output_and_coroutine, nested_preds) = + structural_traits::extract_tupled_inputs_and_output_from_async_callable( + tcx, + goal.predicate.self_ty(), + goal_kind, + env_region, + )?; + let output_is_sized_pred = + tupled_inputs_and_output_and_coroutine.map_bound(|(_, output, _)| { + ty::TraitRef::from_lang_item(tcx, LangItem::Sized, DUMMY_SP, [output]) + }); + + let pred = tupled_inputs_and_output_and_coroutine + .map_bound(|(inputs, output, coroutine)| { + let (projection_ty, term) = match tcx.item_name(goal.predicate.def_id()) { + sym::CallOnceFuture => ( + ty::AliasTy::new( + tcx, + goal.predicate.def_id(), + [goal.predicate.self_ty(), inputs], + ), + coroutine.into(), + ), + sym::CallMutFuture | sym::CallFuture => ( + ty::AliasTy::new( + tcx, + goal.predicate.def_id(), + [ + ty::GenericArg::from(goal.predicate.self_ty()), + inputs.into(), + env_region.into(), + ], + ), + coroutine.into(), + ), + sym::Output => ( + ty::AliasTy::new( + tcx, + goal.predicate.def_id(), + [ty::GenericArg::from(goal.predicate.self_ty()), inputs.into()], + ), + output.into(), + ), + name => bug!("no such associated type: {name}"), + }; + ty::ProjectionPredicate { projection_ty, term } + }) + .to_predicate(tcx); + + // A built-in `AsyncFn` impl only holds if the output is sized. + // (FIXME: technically we only need to check this if the type is a fn ptr...) + Self::consider_implied_clause( + ecx, + goal, + pred, + [goal.with(tcx, output_is_sized_pred)] + .into_iter() + .chain(nested_preds.into_iter().map(|pred| goal.with(tcx, pred))), + ) + } + + fn consider_builtin_async_fn_kind_helper_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx> { + let [ + closure_fn_kind_ty, + goal_kind_ty, + borrow_region, + tupled_inputs_ty, + tupled_upvars_ty, + coroutine_captures_by_ref_ty, + ] = **goal.predicate.alias.args + else { + bug!(); + }; + + let Some(closure_kind) = closure_fn_kind_ty.expect_ty().to_opt_closure_kind() else { + // We don't need to worry about the self type being an infer var. + return Err(NoSolution); + }; + let Some(goal_kind) = goal_kind_ty.expect_ty().to_opt_closure_kind() else { + return Err(NoSolution); + }; + if !closure_kind.extends(goal_kind) { + return Err(NoSolution); + } + + let upvars_ty = ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind( + ecx.tcx(), + goal_kind, + tupled_inputs_ty.expect_ty(), + tupled_upvars_ty.expect_ty(), + coroutine_captures_by_ref_ty.expect_ty(), + borrow_region.expect_region(), + ); + + ecx.eq(goal.param_env, goal.predicate.term.ty().unwrap(), upvars_ty)?; + ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) + } + fn consider_builtin_tuple_candidate( _ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs index efaad47b6dd..fd09a6b671d 100644 --- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs +++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs @@ -303,6 +303,66 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> { Self::consider_implied_clause(ecx, goal, pred, [goal.with(tcx, output_is_sized_pred)]) } + fn consider_builtin_async_fn_trait_candidates( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + goal_kind: ty::ClosureKind, + ) -> QueryResult<'tcx> { + if goal.predicate.polarity != ty::ImplPolarity::Positive { + return Err(NoSolution); + } + + let tcx = ecx.tcx(); + let (tupled_inputs_and_output_and_coroutine, nested_preds) = + structural_traits::extract_tupled_inputs_and_output_from_async_callable( + tcx, + goal.predicate.self_ty(), + goal_kind, + // This region doesn't matter because we're throwing away the coroutine type + tcx.lifetimes.re_static, + )?; + let output_is_sized_pred = + tupled_inputs_and_output_and_coroutine.map_bound(|(_, output, _)| { + ty::TraitRef::from_lang_item(tcx, LangItem::Sized, DUMMY_SP, [output]) + }); + + let pred = tupled_inputs_and_output_and_coroutine + .map_bound(|(inputs, _, _)| { + ty::TraitRef::new(tcx, goal.predicate.def_id(), [goal.predicate.self_ty(), inputs]) + }) + .to_predicate(tcx); + // A built-in `AsyncFn` impl only holds if the output is sized. + // (FIXME: technically we only need to check this if the type is a fn ptr...) + Self::consider_implied_clause( + ecx, + goal, + pred, + [goal.with(tcx, output_is_sized_pred)] + .into_iter() + .chain(nested_preds.into_iter().map(|pred| goal.with(tcx, pred))), + ) + } + + fn consider_builtin_async_fn_kind_helper_candidate( + ecx: &mut EvalCtxt<'_, 'tcx>, + goal: Goal<'tcx, Self>, + ) -> QueryResult<'tcx> { + let [closure_fn_kind_ty, goal_kind_ty] = **goal.predicate.trait_ref.args else { + bug!(); + }; + + let Some(closure_kind) = closure_fn_kind_ty.expect_ty().to_opt_closure_kind() else { + // We don't need to worry about the self type being an infer var. + return Err(NoSolution); + }; + let goal_kind = goal_kind_ty.expect_ty().to_opt_closure_kind().unwrap(); + if closure_kind.extends(goal_kind) { + ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes) + } else { + Err(NoSolution) + } + } + fn consider_builtin_tuple_candidate( ecx: &mut EvalCtxt<'_, 'tcx>, goal: Goal<'tcx, Self>, diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index a960befcd4b..648f14beaa7 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -1833,10 +1833,28 @@ fn assemble_candidates_from_impls<'cx, 'tcx>( lang_items.fn_trait(), lang_items.fn_mut_trait(), lang_items.fn_once_trait(), + lang_items.async_fn_trait(), + lang_items.async_fn_mut_trait(), + lang_items.async_fn_once_trait(), ].contains(&Some(trait_ref.def_id)) { true - }else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) { + } else if lang_items.async_fn_kind_helper() == Some(trait_ref.def_id) { + // FIXME(async_closures): Validity constraints here could be cleaned up. + if obligation.predicate.args.type_at(0).is_ty_var() + || obligation.predicate.args.type_at(4).is_ty_var() + || obligation.predicate.args.type_at(5).is_ty_var() + { + candidate_set.mark_ambiguous(); + true + } else if obligation.predicate.args.type_at(0).to_opt_closure_kind().is_some() + && obligation.predicate.args.type_at(1).to_opt_closure_kind().is_some() + { + true + } else { + false + } + } else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) { match self_ty.kind() { ty::Bool | ty::Char @@ -2061,6 +2079,10 @@ fn confirm_select_candidate<'cx, 'tcx>( } else { confirm_fn_pointer_candidate(selcx, obligation, data) } + } else if selcx.tcx().async_fn_trait_kind_from_def_id(trait_def_id).is_some() { + confirm_async_closure_candidate(selcx, obligation, data) + } else if lang_items.async_fn_kind_helper() == Some(trait_def_id) { + confirm_async_fn_kind_helper_candidate(selcx, obligation, data) } else { confirm_builtin_candidate(selcx, obligation, data) } @@ -2421,6 +2443,164 @@ fn confirm_callable_candidate<'cx, 'tcx>( confirm_param_env_candidate(selcx, obligation, predicate, true) } +fn confirm_async_closure_candidate<'cx, 'tcx>( + selcx: &mut SelectionContext<'cx, 'tcx>, + obligation: &ProjectionTyObligation<'tcx>, + mut nested: Vec>, +) -> Progress<'tcx> { + let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty()); + let ty::CoroutineClosure(def_id, args) = *self_ty.kind() else { + unreachable!( + "expected coroutine-closure self type for coroutine-closure candidate, found {self_ty}" + ) + }; + let args = args.as_coroutine_closure(); + let kind_ty = args.kind_ty(); + + let tcx = selcx.tcx(); + let goal_kind = + tcx.async_fn_trait_kind_from_def_id(obligation.predicate.trait_def_id(tcx)).unwrap(); + + let helper_trait_def_id = tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); + nested.push(obligation.with( + tcx, + ty::TraitRef::new( + tcx, + helper_trait_def_id, + [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], + ), + )); + + let env_region = match goal_kind { + ty::ClosureKind::Fn | ty::ClosureKind::FnMut => obligation.predicate.args.region_at(2), + ty::ClosureKind::FnOnce => tcx.lifetimes.re_static, + }; + + // FIXME(async_closures): Make this into a lang item. + let upvars_projection_def_id = + tcx.associated_items(helper_trait_def_id).in_definition_order().next().unwrap().def_id; + + // FIXME(async_closures): Confirmation is kind of a mess here. Ideally, + // we'd short-circuit when we know that the goal_kind >= closure_kind, and not + // register a nested predicate or create a new projection ty here. But I'm too + // lazy to make this more efficient atm, and we can always tweak it later, + // since all this does is make the solver do more work. + // + // The code duplication due to the different length args is kind of weird, too. + let poly_cache_entry = args.coroutine_closure_sig().map_bound(|sig| { + let (projection_ty, term) = match tcx.item_name(obligation.predicate.def_id) { + sym::CallOnceFuture => { + let tupled_upvars_ty = Ty::new_projection( + tcx, + upvars_projection_def_id, + [ + ty::GenericArg::from(kind_ty), + Ty::from_closure_kind(tcx, goal_kind).into(), + env_region.into(), + sig.tupled_inputs_ty.into(), + args.tupled_upvars_ty().into(), + args.coroutine_captures_by_ref_ty().into(), + ], + ); + let coroutine_ty = sig.to_coroutine( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ); + ( + ty::AliasTy::new( + tcx, + obligation.predicate.def_id, + [self_ty, sig.tupled_inputs_ty], + ), + coroutine_ty.into(), + ) + } + sym::CallMutFuture | sym::CallFuture => { + let tupled_upvars_ty = Ty::new_projection( + tcx, + upvars_projection_def_id, + [ + ty::GenericArg::from(kind_ty), + Ty::from_closure_kind(tcx, goal_kind).into(), + env_region.into(), + sig.tupled_inputs_ty.into(), + args.tupled_upvars_ty().into(), + args.coroutine_captures_by_ref_ty().into(), + ], + ); + let coroutine_ty = sig.to_coroutine( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ); + ( + ty::AliasTy::new( + tcx, + obligation.predicate.def_id, + [ + ty::GenericArg::from(self_ty), + sig.tupled_inputs_ty.into(), + env_region.into(), + ], + ), + coroutine_ty.into(), + ) + } + sym::Output => ( + ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.tupled_inputs_ty]), + sig.return_ty.into(), + ), + name => bug!("no such associated type: {name}"), + }; + ty::ProjectionPredicate { projection_ty, term } + }); + + confirm_param_env_candidate(selcx, obligation, poly_cache_entry, true) + .with_addl_obligations(nested) +} + +fn confirm_async_fn_kind_helper_candidate<'cx, 'tcx>( + selcx: &mut SelectionContext<'cx, 'tcx>, + obligation: &ProjectionTyObligation<'tcx>, + nested: Vec>, +) -> Progress<'tcx> { + let [ + // We already checked that the goal_kind >= closure_kind + _closure_kind_ty, + goal_kind_ty, + borrow_region, + tupled_inputs_ty, + tupled_upvars_ty, + coroutine_captures_by_ref_ty, + ] = **obligation.predicate.args + else { + bug!(); + }; + + let predicate = ty::ProjectionPredicate { + projection_ty: ty::AliasTy::new( + selcx.tcx(), + obligation.predicate.def_id, + obligation.predicate.args, + ), + term: ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind( + selcx.tcx(), + goal_kind_ty.expect_ty().to_opt_closure_kind().unwrap(), + tupled_inputs_ty.expect_ty(), + tupled_upvars_ty.expect_ty(), + coroutine_captures_by_ref_ty.expect_ty(), + borrow_region.expect_region(), + ) + .into(), + }; + + confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false) + .with_addl_obligations(nested) +} + fn confirm_param_env_candidate<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTyObligation<'tcx>, diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index b354ebf111f..75aedd5cd22 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -117,9 +117,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { self.assemble_iterator_candidates(obligation, &mut candidates); } else if lang_items.async_iterator_trait() == Some(def_id) { self.assemble_async_iterator_candidates(obligation, &mut candidates); + } else if lang_items.async_fn_kind_helper() == Some(def_id) { + self.assemble_async_fn_kind_helper_candidates(obligation, &mut candidates); } self.assemble_closure_candidates(obligation, &mut candidates); + self.assemble_async_closure_candidates(obligation, &mut candidates); self.assemble_fn_pointer_candidates(obligation, &mut candidates); self.assemble_candidates_from_impls(obligation, &mut candidates); self.assemble_candidates_from_object_ty(obligation, &mut candidates); @@ -335,6 +338,49 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { } } + fn assemble_async_closure_candidates( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + candidates: &mut SelectionCandidateSet<'tcx>, + ) { + let Some(goal_kind) = + self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()) + else { + return; + }; + + match *obligation.self_ty().skip_binder().kind() { + ty::CoroutineClosure(_, args) => { + if let Some(closure_kind) = + args.as_coroutine_closure().kind_ty().to_opt_closure_kind() + && !closure_kind.extends(goal_kind) + { + return; + } + candidates.vec.push(AsyncClosureCandidate); + } + ty::Infer(ty::TyVar(_)) => { + candidates.ambiguous = true; + } + _ => {} + } + } + + fn assemble_async_fn_kind_helper_candidates( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + candidates: &mut SelectionCandidateSet<'tcx>, + ) { + if let Some(closure_kind) = obligation.self_ty().skip_binder().to_opt_closure_kind() + && let Some(goal_kind) = + obligation.predicate.skip_binder().trait_ref.args.type_at(1).to_opt_closure_kind() + { + if closure_kind.extends(goal_kind) { + candidates.vec.push(AsyncFnKindHelperCandidate); + } + } + } + /// Implements one of the `Fn()` family for a fn pointer. fn assemble_fn_pointer_candidates( &mut self, diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs index 74f388e53a3..79336a9d4e8 100644 --- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs +++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs @@ -83,6 +83,13 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { ImplSource::Builtin(BuiltinImplSource::Misc, vtable_closure) } + AsyncClosureCandidate => { + let vtable_closure = self.confirm_async_closure_candidate(obligation)?; + ImplSource::Builtin(BuiltinImplSource::Misc, vtable_closure) + } + + AsyncFnKindHelperCandidate => ImplSource::Builtin(BuiltinImplSource::Misc, vec![]), + CoroutineCandidate => { let vtable_coroutine = self.confirm_coroutine_candidate(obligation)?; ImplSource::Builtin(BuiltinImplSource::Misc, vtable_coroutine) @@ -869,6 +876,49 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { Ok(nested) } + #[instrument(skip(self), level = "debug")] + fn confirm_async_closure_candidate( + &mut self, + obligation: &PolyTraitObligation<'tcx>, + ) -> Result>, SelectionError<'tcx>> { + // Okay to skip binder because the args on closure types never + // touch bound regions, they just capture the in-scope + // type/region parameters. + let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); + let ty::CoroutineClosure(closure_def_id, args) = *self_ty.kind() else { + bug!("async closure candidate for non-coroutine-closure {:?}", obligation); + }; + + let trait_ref = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { + ty::TraitRef::new( + self.tcx(), + obligation.predicate.def_id(), + [self_ty, sig.tupled_inputs_ty], + ) + }); + + let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?; + + let goal_kind = + self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap(); + nested.push(obligation.with( + self.tcx(), + ty::TraitRef::from_lang_item( + self.tcx(), + LangItem::AsyncFnKindHelper, + obligation.cause.span, + [ + args.as_coroutine_closure().kind_ty(), + Ty::from_closure_kind(self.tcx(), goal_kind), + ], + ), + )); + + debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations"); + + Ok(nested) + } + /// In the case of closure types and fn pointers, /// we currently treat the input type parameters on the trait as /// outputs. This means that when we have a match we have only diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs index e79277b89c4..7f41c73b72f 100644 --- a/compiler/rustc_trait_selection/src/traits/select/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs @@ -1864,6 +1864,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> { ImplCandidate(..) | AutoImplCandidate | ClosureCandidate { .. } + | AsyncClosureCandidate + | AsyncFnKindHelperCandidate | CoroutineCandidate | FutureCandidate | IteratorCandidate @@ -1894,6 +1896,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> { ImplCandidate(_) | AutoImplCandidate | ClosureCandidate { .. } + | AsyncClosureCandidate + | AsyncFnKindHelperCandidate | CoroutineCandidate | FutureCandidate | IteratorCandidate @@ -1930,6 +1934,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> { ImplCandidate(..) | AutoImplCandidate | ClosureCandidate { .. } + | AsyncClosureCandidate + | AsyncFnKindHelperCandidate | CoroutineCandidate | FutureCandidate | IteratorCandidate @@ -1946,6 +1952,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> { ImplCandidate(..) | AutoImplCandidate | ClosureCandidate { .. } + | AsyncClosureCandidate + | AsyncFnKindHelperCandidate | CoroutineCandidate | FutureCandidate | IteratorCandidate @@ -2054,6 +2062,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> { ( ImplCandidate(_) | ClosureCandidate { .. } + | AsyncClosureCandidate + | AsyncFnKindHelperCandidate | CoroutineCandidate | FutureCandidate | IteratorCandidate @@ -2066,6 +2076,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> { | TraitAliasCandidate, ImplCandidate(_) | ClosureCandidate { .. } + | AsyncClosureCandidate + | AsyncFnKindHelperCandidate | CoroutineCandidate | FutureCandidate | IteratorCandidate diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs index af26616831f..8d1bdec6684 100644 --- a/compiler/rustc_ty_utils/src/abi.rs +++ b/compiler/rustc_ty_utils/src/abi.rs @@ -101,6 +101,42 @@ fn fn_sig_for_fn_abi<'tcx>( bound_vars, ) } + ty::CoroutineClosure(def_id, args) => { + let sig = args.as_coroutine_closure().coroutine_closure_sig(); + let bound_vars = tcx.mk_bound_variable_kinds_from_iter( + sig.bound_vars().iter().chain(iter::once(ty::BoundVariableKind::Region(ty::BrEnv))), + ); + let br = ty::BoundRegion { + var: ty::BoundVar::from_usize(bound_vars.len() - 1), + kind: ty::BoundRegionKind::BrEnv, + }; + let env_region = ty::Region::new_bound(tcx, ty::INNERMOST, br); + let env_ty = tcx.closure_env_ty( + Ty::new_coroutine_closure(tcx, def_id, args), + args.as_coroutine_closure().kind(), + env_region, + ); + + let sig = sig.skip_binder(); + ty::Binder::bind_with_vars( + tcx.mk_fn_sig( + iter::once(env_ty).chain([sig.tupled_inputs_ty]), + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.as_coroutine_closure().parent_args(), + tcx.coroutine_for_closure(def_id), + args.as_coroutine_closure().kind(), + env_region, + args.as_coroutine_closure().tupled_upvars_ty(), + args.as_coroutine_closure().coroutine_captures_by_ref_ty(), + ), + sig.c_variadic, + sig.unsafety, + sig.abi, + ), + bound_vars, + ) + } ty::Coroutine(did, args) => { let coroutine_kind = tcx.coroutine_kind(did).unwrap(); let sig = args.as_coroutine().sig(); diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs index a5a2c53c732..50d27ed4ced 100644 --- a/compiler/rustc_ty_utils/src/instance.rs +++ b/compiler/rustc_ty_utils/src/instance.rs @@ -38,6 +38,7 @@ fn resolve_instance<'tcx>( debug!(" => nontrivial drop glue"); match *ty.kind() { ty::Closure(..) + | ty::CoroutineClosure(..) | ty::Coroutine(..) | ty::Tuple(..) | ty::Adt(..) @@ -282,6 +283,16 @@ fn resolve_associated_item<'tcx>( tcx.item_name(trait_item_id) ), } + } else if tcx.async_fn_trait_kind_from_def_id(trait_ref.def_id).is_some() { + match *rcvr_args.type_at(0).kind() { + ty::CoroutineClosure(closure_def_id, args) => { + Some(Instance::new(closure_def_id, args)) + } + _ => bug!( + "no built-in definition for `{trait_ref}::{}` for non-lending-closure type", + tcx.item_name(trait_item_id) + ), + } } else { Instance::try_resolve_item_for_coroutine(tcx, trait_item_id, trait_id, rcvr_args) } diff --git a/library/core/src/ops/async_function.rs b/library/core/src/ops/async_function.rs index 965873f163e..b11d5643990 100644 --- a/library/core/src/ops/async_function.rs +++ b/library/core/src/ops/async_function.rs @@ -106,3 +106,11 @@ mod impls { } } } + +mod internal_implementation_detail { + // TODO: needs a detailed explanation + #[cfg_attr(not(bootstrap), lang = "async_fn_kind_helper")] + trait AsyncFnKindHelper { + type Assoc<'closure_env, Inputs, Upvars, BorrowedUpvarsAsFnPtr>; + } +} -- cgit 1.4.1-3-g733a5 From fc4fff40385252212b9921928927568f233ba02f Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Thu, 25 Jan 2024 00:30:55 +0000 Subject: Build a shim to call async closures with different AsyncFn trait kinds --- .../rustc_const_eval/src/interpret/terminator.rs | 1 + compiler/rustc_middle/src/mir/mono.rs | 1 + compiler/rustc_middle/src/mir/visit.rs | 1 + compiler/rustc_middle/src/ty/instance.rs | 23 +++- compiler/rustc_middle/src/ty/mod.rs | 1 + compiler/rustc_mir_transform/src/inline.rs | 1 + compiler/rustc_mir_transform/src/inline/cycle.rs | 1 + compiler/rustc_mir_transform/src/shim.rs | 121 ++++++++++++++++++++- compiler/rustc_monomorphize/src/collector.rs | 1 + compiler/rustc_monomorphize/src/partitioning.rs | 2 + compiler/rustc_smir/src/rustc_smir/convert/ty.rs | 1 + compiler/rustc_ty_utils/src/abi.rs | 15 ++- compiler/rustc_ty_utils/src/instance.rs | 17 ++- 13 files changed, 175 insertions(+), 11 deletions(-) (limited to 'compiler/rustc_const_eval/src') diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs index b7ffb4a16fc..4c8f68b25b5 100644 --- a/compiler/rustc_const_eval/src/interpret/terminator.rs +++ b/compiler/rustc_const_eval/src/interpret/terminator.rs @@ -545,6 +545,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { ty::InstanceDef::VTableShim(..) | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::ClosureOnceShim { .. } + | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_middle/src/mir/mono.rs b/compiler/rustc_middle/src/mir/mono.rs index 91fdf0b3129..4a29171d8bf 100644 --- a/compiler/rustc_middle/src/mir/mono.rs +++ b/compiler/rustc_middle/src/mir/mono.rs @@ -402,6 +402,7 @@ impl<'tcx> CodegenUnit<'tcx> { | InstanceDef::FnPtrShim(..) | InstanceDef::Virtual(..) | InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs index 591104ecc66..6bc58adea0f 100644 --- a/compiler/rustc_middle/src/mir/visit.rs +++ b/compiler/rustc_middle/src/mir/visit.rs @@ -345,6 +345,7 @@ macro_rules! make_mir_visitor { ty::InstanceDef::Virtual(_def_id, _) | ty::InstanceDef::ThreadLocalShim(_def_id) | ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } | + ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } | ty::InstanceDef::DropGlue(_def_id, None) => {} ty::InstanceDef::FnPtrShim(_def_id, ty) | diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs index 293fdb026b6..41ae136851e 100644 --- a/compiler/rustc_middle/src/ty/instance.rs +++ b/compiler/rustc_middle/src/ty/instance.rs @@ -82,11 +82,25 @@ pub enum InstanceDef<'tcx> { /// details on that). Virtual(DefId, usize), - /// `<[FnMut closure] as FnOnce>::call_once`. + /// `<[FnMut/Fn closure] as FnOnce>::call_once`. /// /// The `DefId` is the ID of the `call_once` method in `FnOnce`. + /// + /// This generates a body that will just borrow the (owned) self type, + /// and dispatch to the `FnMut::call_mut` instance for the closure. ClosureOnceShim { call_once: DefId, track_caller: bool }, + /// `<[FnMut/Fn coroutine-closure] as FnOnce>::call_once` or + /// `<[Fn coroutine-closure] as FnMut>::call_mut`. + /// + /// The body generated here differs significantly from the `ClosureOnceShim`, + /// since we need to generate a distinct coroutine type that will move the + /// closure's upvars *out* of the closure. + ConstructCoroutineInClosureShim { + coroutine_closure_def_id: DefId, + target_kind: ty::ClosureKind, + }, + /// Compiler-generated accessor for thread locals which returns a reference to the thread local /// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking /// native support. @@ -168,6 +182,10 @@ impl<'tcx> InstanceDef<'tcx> { | InstanceDef::Intrinsic(def_id) | InstanceDef::ThreadLocalShim(def_id) | InstanceDef::ClosureOnceShim { call_once: def_id, track_caller: _ } + | ty::InstanceDef::ConstructCoroutineInClosureShim { + coroutine_closure_def_id: def_id, + target_kind: _, + } | InstanceDef::DropGlue(def_id, _) | InstanceDef::CloneShim(def_id, _) | InstanceDef::FnPtrAddrShim(def_id, _) => def_id, @@ -187,6 +205,7 @@ impl<'tcx> InstanceDef<'tcx> { | InstanceDef::Virtual(..) | InstanceDef::Intrinsic(..) | InstanceDef::ClosureOnceShim { .. } + | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::FnPtrAddrShim(..) => None, @@ -282,6 +301,7 @@ impl<'tcx> InstanceDef<'tcx> { | InstanceDef::FnPtrShim(..) | InstanceDef::DropGlue(_, Some(_)) => false, InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::Item(_) | InstanceDef::Intrinsic(..) @@ -319,6 +339,7 @@ fn fmt_instance( InstanceDef::Virtual(_, num) => write!(f, " - virtual#{num}"), InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"), InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"), + InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"), InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"), InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"), InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"), diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 4348d01eff5..05875a9798b 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -1680,6 +1680,7 @@ impl<'tcx> TyCtxt<'tcx> { | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::Virtual(..) | ty::InstanceDef::ClosureOnceShim { .. } + | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) | ty::InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 67668a216de..7c731b070a7 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -317,6 +317,7 @@ impl<'tcx> Inliner<'tcx> { | InstanceDef::ReifyShim(_) | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index d30e0bad813..3f3dc9145b6 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -87,6 +87,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( | InstanceDef::ReifyShim(_) | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } | InstanceDef::ThreadLocalShim { .. } | InstanceDef::CloneShim(..) => {} diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 89414ce940e..29b83f58ef5 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -3,8 +3,8 @@ use rustc_hir::def_id::DefId; use rustc_hir::lang_items::LangItem; use rustc_middle::mir::*; use rustc_middle::query::Providers; -use rustc_middle::ty::GenericArgs; use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt}; +use rustc_middle::ty::{GenericArgs, CAPTURE_STRUCT_LOCAL}; use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT}; use rustc_index::{Idx, IndexVec}; @@ -66,6 +66,21 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut)) } + ty::InstanceDef::ConstructCoroutineInClosureShim { + coroutine_closure_def_id, + target_kind, + } => match target_kind { + ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"), + ty::ClosureKind::FnMut => { + let body = build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id); + // No need to optimize the body, it has already been optimized. + return body; + } + ty::ClosureKind::FnOnce => { + build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id) + } + }, + ty::InstanceDef::DropGlue(def_id, ty) => { // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end // of this function. Is this intentional? @@ -981,3 +996,107 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t let source = MirSource::from_instance(ty::InstanceDef::FnPtrAddrShim(def_id, self_ty)); new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span) } + +fn build_construct_coroutine_by_move_shim<'tcx>( + tcx: TyCtxt<'tcx>, + coroutine_closure_def_id: DefId, +) -> Body<'tcx> { + let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); + let ty::CoroutineClosure(_, args) = *self_ty.kind() else { + bug!(); + }; + + let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { + tcx.mk_fn_sig( + [self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()), + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.as_coroutine_closure().parent_args(), + tcx.coroutine_for_closure(coroutine_closure_def_id), + ty::ClosureKind::FnOnce, + tcx.lifetimes.re_erased, + args.as_coroutine_closure().tupled_upvars_ty(), + args.as_coroutine_closure().coroutine_captures_by_ref_ty(), + ), + sig.c_variadic, + sig.unsafety, + sig.abi, + ) + }); + let sig = tcx.liberate_late_bound_regions(coroutine_closure_def_id, poly_sig); + let ty::Coroutine(coroutine_def_id, coroutine_args) = *sig.output().kind() else { + bug!(); + }; + + let span = tcx.def_span(coroutine_closure_def_id); + let locals = local_decls_for_sig(&sig, span); + + let mut fields = vec![]; + for idx in 1..sig.inputs().len() { + fields.push(Operand::Move(Local::from_usize(idx + 1).into())); + } + for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() { + fields.push(Operand::Move(tcx.mk_place_field( + Local::from_usize(1).into(), + FieldIdx::from_usize(idx), + ty, + ))); + } + + let source_info = SourceInfo::outermost(span); + let rvalue = Rvalue::Aggregate( + Box::new(AggregateKind::Coroutine(coroutine_def_id, coroutine_args)), + IndexVec::from_raw(fields), + ); + let stmt = Statement { + source_info, + kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))), + }; + let statements = vec![stmt]; + let start_block = BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }; + + let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim { + coroutine_closure_def_id, + target_kind: ty::ClosureKind::FnOnce, + }); + + new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span) +} + +fn build_construct_coroutine_by_mut_shim<'tcx>( + tcx: TyCtxt<'tcx>, + coroutine_closure_def_id: DefId, +) -> Body<'tcx> { + let mut body = tcx.optimized_mir(coroutine_closure_def_id).clone(); + let coroutine_closure_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity(); + let ty::CoroutineClosure(_, args) = *coroutine_closure_ty.kind() else { + bug!(); + }; + let args = args.as_coroutine_closure(); + + body.local_decls[RETURN_PLACE].ty = + tcx.instantiate_bound_regions_with_erased(args.coroutine_closure_sig().map_bound(|sig| { + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(coroutine_closure_def_id), + ty::ClosureKind::FnMut, + tcx.lifetimes.re_erased, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + ) + })); + body.local_decls[CAPTURE_STRUCT_LOCAL].ty = + Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_closure_ty); + + body.source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim { + coroutine_closure_def_id, + target_kind: ty::ClosureKind::FnMut, + }); + + body +} diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index b86244e5a4b..698fd634114 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -983,6 +983,7 @@ fn visit_instance_use<'tcx>( | ty::InstanceDef::VTableShim(..) | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } | ty::InstanceDef::Item(..) | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 7ff182381b8..46bd33c89e7 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -620,6 +620,7 @@ fn characteristic_def_id_of_mono_item<'tcx>( | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } | ty::InstanceDef::Intrinsic(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::Virtual(..) @@ -783,6 +784,7 @@ fn mono_item_visibility<'tcx>( | InstanceDef::Virtual(..) | InstanceDef::Intrinsic(..) | InstanceDef::ClosureOnceShim { .. } + | InstanceDef::ConstructCoroutineInClosureShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::FnPtrAddrShim(..) => return Visibility::Hidden, diff --git a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs index e3fd64eb189..e0e9815cf40 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs @@ -799,6 +799,7 @@ impl<'tcx> Stable<'tcx> for ty::Instance<'tcx> { | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::FnPtrAddrShim(..) | ty::InstanceDef::ClosureOnceShim { .. } + | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } | ty::InstanceDef::ThreadLocalShim(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs index 8d1bdec6684..3ea48ee824d 100644 --- a/compiler/rustc_ty_utils/src/abi.rs +++ b/compiler/rustc_ty_utils/src/abi.rs @@ -111,11 +111,14 @@ fn fn_sig_for_fn_abi<'tcx>( kind: ty::BoundRegionKind::BrEnv, }; let env_region = ty::Region::new_bound(tcx, ty::INNERMOST, br); - let env_ty = tcx.closure_env_ty( - Ty::new_coroutine_closure(tcx, def_id, args), - args.as_coroutine_closure().kind(), - env_region, - ); + + let mut kind = args.as_coroutine_closure().kind(); + if let InstanceDef::ConstructCoroutineInClosureShim { target_kind, .. } = instance.def { + kind = target_kind; + } + + let env_ty = + tcx.closure_env_ty(Ty::new_coroutine_closure(tcx, def_id, args), kind, env_region); let sig = sig.skip_binder(); ty::Binder::bind_with_vars( @@ -125,7 +128,7 @@ fn fn_sig_for_fn_abi<'tcx>( tcx, args.as_coroutine_closure().parent_args(), tcx.coroutine_for_closure(def_id), - args.as_coroutine_closure().kind(), + kind, env_region, args.as_coroutine_closure().tupled_upvars_ty(), args.as_coroutine_closure().coroutine_captures_by_ref_ty(), diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs index 50d27ed4ced..fc224d242db 100644 --- a/compiler/rustc_ty_utils/src/instance.rs +++ b/compiler/rustc_ty_utils/src/instance.rs @@ -283,10 +283,21 @@ fn resolve_associated_item<'tcx>( tcx.item_name(trait_item_id) ), } - } else if tcx.async_fn_trait_kind_from_def_id(trait_ref.def_id).is_some() { + } else if let Some(target_kind) = tcx.async_fn_trait_kind_from_def_id(trait_ref.def_id) + { match *rcvr_args.type_at(0).kind() { - ty::CoroutineClosure(closure_def_id, args) => { - Some(Instance::new(closure_def_id, args)) + ty::CoroutineClosure(coroutine_closure_def_id, args) => { + if target_kind > args.as_coroutine_closure().kind() { + Some(Instance { + def: ty::InstanceDef::ConstructCoroutineInClosureShim { + coroutine_closure_def_id, + target_kind, + }, + args, + }) + } else { + Some(Instance::new(coroutine_closure_def_id, args)) + } } _ => bug!( "no built-in definition for `{trait_ref}::{}` for non-lending-closure type", -- cgit 1.4.1-3-g733a5 From 427896dd7e39f1aaf3e3cbc15e5ddf77d45a6aec Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Wed, 24 Jan 2024 23:38:33 +0000 Subject: Construct body for by-move coroutine closure output --- .../rustc_borrowck/src/type_check/input_output.rs | 1 + .../rustc_const_eval/src/interpret/terminator.rs | 1 + compiler/rustc_hir_typeck/src/callee.rs | 1 + compiler/rustc_hir_typeck/src/closure.rs | 11 +++ compiler/rustc_hir_typeck/src/upvar.rs | 10 ++ compiler/rustc_middle/src/mir/mod.rs | 5 + compiler/rustc_middle/src/mir/mono.rs | 1 + compiler/rustc_middle/src/mir/visit.rs | 1 + compiler/rustc_middle/src/ty/instance.rs | 21 +++- compiler/rustc_middle/src/ty/mod.rs | 1 + compiler/rustc_middle/src/ty/sty.rs | 47 ++++++--- compiler/rustc_mir_transform/src/coroutine.rs | 3 + .../src/coroutine/by_move_body.rs | 108 +++++++++++++++++++++ compiler/rustc_mir_transform/src/inline.rs | 1 + compiler/rustc_mir_transform/src/inline/cycle.rs | 1 + compiler/rustc_mir_transform/src/lib.rs | 4 + compiler/rustc_mir_transform/src/pass_manager.rs | 6 ++ compiler/rustc_mir_transform/src/shim.rs | 12 +++ compiler/rustc_monomorphize/src/collector.rs | 3 +- compiler/rustc_monomorphize/src/partitioning.rs | 4 +- compiler/rustc_smir/src/rustc_smir/convert/ty.rs | 1 + .../src/solve/assembly/structural_traits.rs | 1 + .../rustc_trait_selection/src/traits/project.rs | 2 + ...sync_await.b-{closure#0}.coroutine_resume.0.mir | 2 + 24 files changed, 233 insertions(+), 15 deletions(-) create mode 100644 compiler/rustc_mir_transform/src/coroutine/by_move_body.rs (limited to 'compiler/rustc_const_eval/src') diff --git a/compiler/rustc_borrowck/src/type_check/input_output.rs b/compiler/rustc_borrowck/src/type_check/input_output.rs index a3e5088ee09..ace9c5ae71d 100644 --- a/compiler/rustc_borrowck/src/type_check/input_output.rs +++ b/compiler/rustc_borrowck/src/type_check/input_output.rs @@ -85,6 +85,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { self.tcx(), ty::CoroutineArgsParts { parent_args: args.parent_args(), + kind_ty: Ty::from_closure_kind(self.tcx(), args.kind()), resume_ty: next_ty_var(), yield_ty: next_ty_var(), witness: next_ty_var(), diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs index 4c8f68b25b5..b8d6836da14 100644 --- a/compiler/rustc_const_eval/src/interpret/terminator.rs +++ b/compiler/rustc_const_eval/src/interpret/terminator.rs @@ -546,6 +546,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_hir_typeck/src/callee.rs b/compiler/rustc_hir_typeck/src/callee.rs index 1858b2770cd..730a475f630 100644 --- a/compiler/rustc_hir_typeck/src/callee.rs +++ b/compiler/rustc_hir_typeck/src/callee.rs @@ -183,6 +183,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { coroutine_closure_sig.to_coroutine( self.tcx, closure_args.parent_args(), + closure_args.kind_ty(), self.tcx.coroutine_for_closure(def_id), tupled_upvars_ty, ), diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index 1d024efdd49..014293c1f83 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -175,10 +175,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { interior, )); + let kind_ty = match kind { + hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure) => self + .next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::ClosureSynthetic, + span: expr_span, + }), + _ => tcx.types.unit, + }; + let coroutine_args = ty::CoroutineArgs::new( tcx, ty::CoroutineArgsParts { parent_args, + kind_ty, resume_ty, yield_ty, return_ty: liberated_sig.output(), @@ -256,6 +266,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { sig.to_coroutine( tcx, parent_args, + closure_kind_ty, tcx.coroutine_for_closure(expr_def_id), coroutine_upvars_ty, ) diff --git a/compiler/rustc_hir_typeck/src/upvar.rs b/compiler/rustc_hir_typeck/src/upvar.rs index b087d6d9e57..d4e072976fa 100644 --- a/compiler/rustc_hir_typeck/src/upvar.rs +++ b/compiler/rustc_hir_typeck/src/upvar.rs @@ -393,6 +393,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { args.as_coroutine_closure().coroutine_captures_by_ref_ty(), coroutine_captures_by_ref_ty, ); + + let ty::Coroutine(_, args) = *self.typeck_results.borrow().expr_ty(body.value).kind() + else { + bug!(); + }; + self.demand_eqtype( + span, + args.as_coroutine().kind_ty(), + Ty::from_closure_kind(self.tcx, closure_kind), + ); } self.log_closure_min_capture_info(closure_def_id, span); diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index c9e69253701..d88e9261e5a 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -262,6 +262,10 @@ pub struct CoroutineInfo<'tcx> { /// Coroutine drop glue. This field is populated after the state transform pass. pub coroutine_drop: Option>, + /// The body of the coroutine, modified to take its upvars by move. + /// TODO: + pub by_move_body: Option>, + /// The layout of a coroutine. This field is populated after the state transform pass. pub coroutine_layout: Option>, @@ -281,6 +285,7 @@ impl<'tcx> CoroutineInfo<'tcx> { coroutine_kind, yield_ty: Some(yield_ty), resume_ty: Some(resume_ty), + by_move_body: None, coroutine_drop: None, coroutine_layout: None, } diff --git a/compiler/rustc_middle/src/mir/mono.rs b/compiler/rustc_middle/src/mir/mono.rs index 4a29171d8bf..e6d1535fdf2 100644 --- a/compiler/rustc_middle/src/mir/mono.rs +++ b/compiler/rustc_middle/src/mir/mono.rs @@ -403,6 +403,7 @@ impl<'tcx> CodegenUnit<'tcx> { | InstanceDef::Virtual(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs index 6bc58adea0f..ce1859d6ada 100644 --- a/compiler/rustc_middle/src/mir/visit.rs +++ b/compiler/rustc_middle/src/mir/visit.rs @@ -346,6 +346,7 @@ macro_rules! make_mir_visitor { ty::InstanceDef::ThreadLocalShim(_def_id) | ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } | ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } | + ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: _def_id } | ty::InstanceDef::DropGlue(_def_id, None) => {} ty::InstanceDef::FnPtrShim(_def_id, ty) | diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs index 41ae136851e..44bf3c32b48 100644 --- a/compiler/rustc_middle/src/ty/instance.rs +++ b/compiler/rustc_middle/src/ty/instance.rs @@ -101,6 +101,9 @@ pub enum InstanceDef<'tcx> { target_kind: ty::ClosureKind, }, + /// TODO: + CoroutineByMoveShim { coroutine_def_id: DefId }, + /// Compiler-generated accessor for thread locals which returns a reference to the thread local /// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking /// native support. @@ -186,6 +189,7 @@ impl<'tcx> InstanceDef<'tcx> { coroutine_closure_def_id: def_id, target_kind: _, } + | ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: def_id } | InstanceDef::DropGlue(def_id, _) | InstanceDef::CloneShim(def_id, _) | InstanceDef::FnPtrAddrShim(def_id, _) => def_id, @@ -206,6 +210,7 @@ impl<'tcx> InstanceDef<'tcx> { | InstanceDef::Intrinsic(..) | InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::FnPtrAddrShim(..) => None, @@ -302,6 +307,7 @@ impl<'tcx> InstanceDef<'tcx> { | InstanceDef::DropGlue(_, Some(_)) => false, InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::Item(_) | InstanceDef::Intrinsic(..) @@ -340,6 +346,7 @@ fn fmt_instance( InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"), InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"), InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"), + InstanceDef::CoroutineByMoveShim { .. } => write!(f, " - shim"), InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"), InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"), InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"), @@ -631,7 +638,19 @@ impl<'tcx> Instance<'tcx> { }; if tcx.lang_items().get(coroutine_callable_item) == Some(trait_item_id) { - Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args }) + let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind() + else { + bug!() + }; + + if args.as_coroutine().kind_ty() == id_args.as_coroutine().kind_ty() { + Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args }) + } else { + Some(Instance { + def: ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id }, + args, + }) + } } else { // All other methods should be defaulted methods of the built-in trait. // This is important for `Iterator`'s combinators, but also useful for diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 05875a9798b..9ceb3ec3f61 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -1681,6 +1681,7 @@ impl<'tcx> TyCtxt<'tcx> { | ty::InstanceDef::Virtual(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) | ty::InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index e2a2e24f06d..8918a3735d6 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -399,6 +399,7 @@ impl<'tcx> CoroutineClosureSignature<'tcx> { self, tcx: TyCtxt<'tcx>, parent_args: &'tcx [GenericArg<'tcx>], + kind_ty: Ty<'tcx>, coroutine_def_id: DefId, tupled_upvars_ty: Ty<'tcx>, ) -> Ty<'tcx> { @@ -406,6 +407,7 @@ impl<'tcx> CoroutineClosureSignature<'tcx> { tcx, ty::CoroutineArgsParts { parent_args, + kind_ty, resume_ty: self.resume_ty, yield_ty: self.yield_ty, return_ty: self.return_ty, @@ -436,7 +438,13 @@ impl<'tcx> CoroutineClosureSignature<'tcx> { env_region, ); - self.to_coroutine(tcx, parent_args, coroutine_def_id, tupled_upvars_ty) + self.to_coroutine( + tcx, + parent_args, + Ty::from_closure_kind(tcx, closure_kind), + coroutine_def_id, + tupled_upvars_ty, + ) } /// Given a closure kind, compute the tupled upvars that the given coroutine would return. @@ -488,6 +496,8 @@ pub struct CoroutineArgs<'tcx> { pub struct CoroutineArgsParts<'tcx> { /// This is the args of the typeck root. pub parent_args: &'tcx [GenericArg<'tcx>], + // TODO: why + pub kind_ty: Ty<'tcx>, pub resume_ty: Ty<'tcx>, pub yield_ty: Ty<'tcx>, pub return_ty: Ty<'tcx>, @@ -506,6 +516,7 @@ impl<'tcx> CoroutineArgs<'tcx> { pub fn new(tcx: TyCtxt<'tcx>, parts: CoroutineArgsParts<'tcx>) -> CoroutineArgs<'tcx> { CoroutineArgs { args: tcx.mk_args_from_iter(parts.parent_args.iter().copied().chain([ + parts.kind_ty.into(), parts.resume_ty.into(), parts.yield_ty.into(), parts.return_ty.into(), @@ -519,16 +530,23 @@ impl<'tcx> CoroutineArgs<'tcx> { /// The ordering assumed here must match that used by `CoroutineArgs::new` above. fn split(self) -> CoroutineArgsParts<'tcx> { match self.args[..] { - [ref parent_args @ .., resume_ty, yield_ty, return_ty, witness, tupled_upvars_ty] => { - CoroutineArgsParts { - parent_args, - resume_ty: resume_ty.expect_ty(), - yield_ty: yield_ty.expect_ty(), - return_ty: return_ty.expect_ty(), - witness: witness.expect_ty(), - tupled_upvars_ty: tupled_upvars_ty.expect_ty(), - } - } + [ + ref parent_args @ .., + kind_ty, + resume_ty, + yield_ty, + return_ty, + witness, + tupled_upvars_ty, + ] => CoroutineArgsParts { + parent_args, + kind_ty: kind_ty.expect_ty(), + resume_ty: resume_ty.expect_ty(), + yield_ty: yield_ty.expect_ty(), + return_ty: return_ty.expect_ty(), + witness: witness.expect_ty(), + tupled_upvars_ty: tupled_upvars_ty.expect_ty(), + }, _ => bug!("coroutine args missing synthetics"), } } @@ -538,6 +556,11 @@ impl<'tcx> CoroutineArgs<'tcx> { self.split().parent_args } + // TODO: + pub fn kind_ty(self) -> Ty<'tcx> { + self.split().kind_ty + } + /// This describes the types that can be contained in a coroutine. /// It will be a type variable initially and unified in the last stages of typeck of a body. /// It contains a tuple of all the types that could end up on a coroutine frame. @@ -1628,7 +1651,7 @@ impl<'tcx> Ty<'tcx> { ) -> Ty<'tcx> { debug_assert_eq!( coroutine_args.len(), - tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 5, + tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 6, "coroutine constructed with incorrect number of substitutions" ); Ty::new(tcx, Coroutine(def_id, coroutine_args)) diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index bde879f6067..297b2fa143d 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -50,6 +50,9 @@ //! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing. //! Otherwise it drops all the values in scope at the last suspension point. +mod by_move_body; +pub use by_move_body::ByMoveBody; + use crate::abort_unwinding_calls; use crate::deref_separator::deref_finder; use crate::errors; diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs new file mode 100644 index 00000000000..4e3e70bdafe --- /dev/null +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -0,0 +1,108 @@ +use rustc_data_structures::fx::FxIndexSet; +use rustc_hir as hir; +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::{self, MirPass}; +use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt}; +use rustc_target::abi::FieldIdx; + +pub struct ByMoveBody; + +impl<'tcx> MirPass<'tcx> for ByMoveBody { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { + let Some(coroutine_def_id) = body.source.def_id().as_local() else { + return; + }; + let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) = + tcx.coroutine_kind(coroutine_def_id) + else { + return; + }; + let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() }; + if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce { + return; + } + + let mut by_ref_fields = FxIndexSet::default(); + let by_move_upvars = Ty::new_tup_from_iter( + tcx, + tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| { + if capture.is_by_ref() { + by_ref_fields.insert(FieldIdx::from_usize(idx)); + } + capture.place.ty() + }), + ); + let by_move_coroutine_ty = Ty::new_coroutine( + tcx, + coroutine_def_id.to_def_id(), + ty::CoroutineArgs::new( + tcx, + ty::CoroutineArgsParts { + parent_args: args.as_coroutine().parent_args(), + kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce), + resume_ty: args.as_coroutine().resume_ty(), + yield_ty: args.as_coroutine().yield_ty(), + return_ty: args.as_coroutine().return_ty(), + witness: args.as_coroutine().witness(), + tupled_upvars_ty: by_move_upvars, + }, + ) + .args, + ); + + let mut by_move_body = body.clone(); + MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body); + by_move_body.source = mir::MirSource { + instance: InstanceDef::CoroutineByMoveShim { + coroutine_def_id: coroutine_def_id.to_def_id(), + }, + promoted: None, + }; + + body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body); + } +} + +struct MakeByMoveBody<'tcx> { + tcx: TyCtxt<'tcx>, + by_ref_fields: FxIndexSet, + by_move_coroutine_ty: Ty<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place( + &mut self, + place: &mut mir::Place<'tcx>, + context: mir::visit::PlaceContext, + location: mir::Location, + ) { + if place.local == ty::CAPTURE_STRUCT_LOCAL + && !place.projection.is_empty() + && let mir::ProjectionElem::Field(idx, ty) = place.projection[0] + && self.by_ref_fields.contains(&idx) + { + let (begin, end) = place.projection[1..].split_first().unwrap(); + assert_eq!(*begin, mir::ProjectionElem::Deref); + *place = mir::Place { + local: place.local, + projection: self.tcx.mk_place_elems_from_iter( + [mir::ProjectionElem::Field(idx, ty.builtin_deref(true).unwrap().ty)] + .into_iter() + .chain(end.iter().copied()), + ), + }; + } + self.super_place(place, context, location); + } + + fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) { + if local == ty::CAPTURE_STRUCT_LOCAL { + local_decl.ty = self.by_move_coroutine_ty; + } + } +} diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 7c731b070a7..24bc84a235c 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -318,6 +318,7 @@ impl<'tcx> Inliner<'tcx> { | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index 3f3dc9145b6..77ff780393e 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -88,6 +88,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::ThreadLocalShim { .. } | InstanceDef::CloneShim(..) => {} diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 69f93fa3a0e..031515ea958 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -307,6 +307,10 @@ fn mir_const(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal> { &Lint(check_packed_ref::CheckPackedRef), &Lint(check_const_item_mutation::CheckConstItemMutation), &Lint(function_item_references::FunctionItemReferences), + // If this is an async closure's output coroutine, generate + // by-move and by-mut bodies if needed. We do this first so + // they can be optimized in lockstep with their parent bodies. + &coroutine::ByMoveBody, // What we need to do constant evaluation. &simplify::SimplifyCfg::Initial, &rustc_peek::SanityCheck, // Just a lint diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs index c1ef2b9f887..c7e770904fb 100644 --- a/compiler/rustc_mir_transform/src/pass_manager.rs +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -189,6 +189,12 @@ fn run_passes_inner<'tcx>( body.pass_count = 1; } + + if let Some(coroutine) = body.coroutine.as_mut() + && let Some(by_move_body) = coroutine.by_move_body.as_mut() + { + run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each); + } } pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: String) { diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 29b83f58ef5..668ccdd8735 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -81,6 +81,18 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' } }, + ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id } => { + return tcx + .optimized_mir(coroutine_def_id) + .coroutine + .as_ref() + .unwrap() + .by_move_body + .as_ref() + .unwrap() + .clone(); + } + ty::InstanceDef::DropGlue(def_id, ty) => { // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end // of this function. Is this intentional? diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 698fd634114..cf3c8e1fdd3 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -983,7 +983,8 @@ fn visit_instance_use<'tcx>( | ty::InstanceDef::VTableShim(..) | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::ClosureOnceShim { .. } - | InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | ty::InstanceDef::Item(..) | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 46bd33c89e7..22b35c4344b 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -620,7 +620,8 @@ fn characteristic_def_id_of_mono_item<'tcx>( | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::ClosureOnceShim { .. } - | InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | ty::InstanceDef::Intrinsic(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::Virtual(..) @@ -785,6 +786,7 @@ fn mono_item_visibility<'tcx>( | InstanceDef::Intrinsic(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::FnPtrAddrShim(..) => return Visibility::Hidden, diff --git a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs index e0e9815cf40..3c1858e920b 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs @@ -800,6 +800,7 @@ impl<'tcx> Stable<'tcx> for ty::Instance<'tcx> { | ty::InstanceDef::FnPtrAddrShim(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | ty::InstanceDef::ThreadLocalShim(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs index c35134c78eb..0699026117d 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs @@ -366,6 +366,7 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc let coroutine_ty = sig.to_coroutine( tcx, args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), tcx.coroutine_for_closure(def_id), tupled_upvars_ty, ); diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 648f14beaa7..db1e89ae72f 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -2505,6 +2505,7 @@ fn confirm_async_closure_candidate<'cx, 'tcx>( let coroutine_ty = sig.to_coroutine( tcx, args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), tcx.coroutine_for_closure(def_id), tupled_upvars_ty, ); @@ -2533,6 +2534,7 @@ fn confirm_async_closure_candidate<'cx, 'tcx>( let coroutine_ty = sig.to_coroutine( tcx, args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), tcx.coroutine_for_closure(def_id), tupled_upvars_ty, ); diff --git a/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir b/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir index 3c0d4008c90..9c8cf8763fd 100644 --- a/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir +++ b/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir @@ -5,6 +5,7 @@ ty: Coroutine( DefId(0:4 ~ async_await[ccf8]::a::{closure#0}), [ + (), std::future::ResumeTy, (), (), @@ -22,6 +23,7 @@ ty: Coroutine( DefId(0:4 ~ async_await[ccf8]::a::{closure#0}), [ + (), std::future::ResumeTy, (), (), -- cgit 1.4.1-3-g733a5 From ca444160232a1bf1914da906da9e061a7636955c Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Mon, 29 Jan 2024 17:41:51 +0000 Subject: Fix drop shim for AsyncFnOnce closure, AsyncFnMut shim for AsyncFn closure --- .../rustc_const_eval/src/interpret/terminator.rs | 2 +- compiler/rustc_middle/src/mir/mod.rs | 18 ++++- compiler/rustc_middle/src/mir/mono.rs | 2 +- compiler/rustc_middle/src/mir/visit.rs | 2 +- compiler/rustc_middle/src/ty/instance.rs | 26 ++++---- compiler/rustc_middle/src/ty/mod.rs | 2 +- compiler/rustc_middle/src/ty/print/mod.rs | 21 +++++- .../src/coroutine/by_move_body.rs | 45 +++++++++++-- compiler/rustc_mir_transform/src/inline.rs | 2 +- compiler/rustc_mir_transform/src/inline/cycle.rs | 2 +- compiler/rustc_mir_transform/src/pass_manager.rs | 11 ++-- compiler/rustc_mir_transform/src/shim.rs | 77 +++++++++++++++++----- compiler/rustc_monomorphize/src/collector.rs | 2 +- compiler/rustc_monomorphize/src/partitioning.rs | 4 +- compiler/rustc_smir/src/rustc_smir/convert/ty.rs | 2 +- compiler/rustc_symbol_mangling/src/legacy.rs | 33 +++++++--- compiler/rustc_symbol_mangling/src/v0.rs | 7 ++ compiler/rustc_ty_utils/src/abi.rs | 2 +- ...{closure#0}.coroutine_by_move.0.panic-abort.mir | 47 +++++++++++++ ...closure#0}.coroutine_by_move.0.panic-unwind.mir | 47 +++++++++++++ ...-{closure#0}.coroutine_by_mut.0.panic-abort.mir | 47 +++++++++++++ ...{closure#0}.coroutine_by_mut.0.panic-unwind.mir | 47 +++++++++++++ ...#0}.coroutine_closure_by_move.0.panic-abort.mir | 10 +++ ...0}.coroutine_closure_by_move.0.panic-unwind.mir | 10 +++ ...e#0}.coroutine_closure_by_mut.0.panic-abort.mir | 16 +++++ ...#0}.coroutine_closure_by_mut.0.panic-unwind.mir | 16 +++++ tests/mir-opt/async_closure_shims.rs | 46 +++++++++++++ .../async-closures/async-fn-mut-for-async-fn.rs | 23 +++++++ .../async-fn-mut-for-async-fn.stderr | 1 + .../async-fn-once-for-async-fn.stderr | 3 +- tests/ui/async-await/async-closures/def-path.rs | 4 +- .../ui/async-await/async-closures/def-path.stderr | 4 +- tests/ui/async-await/async-closures/drop.rs | 40 +++++++++++ .../ui/async-await/async-closures/drop.run.stdout | 5 ++ tests/ui/async-await/async-closures/mangle.rs | 36 ++++++++++ 35 files changed, 595 insertions(+), 67 deletions(-) create mode 100644 tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir create mode 100644 tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir create mode 100644 tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_mut.0.panic-abort.mir create mode 100644 tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_mut.0.panic-unwind.mir create mode 100644 tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir create mode 100644 tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir create mode 100644 tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_mut.0.panic-abort.mir create mode 100644 tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_mut.0.panic-unwind.mir create mode 100644 tests/mir-opt/async_closure_shims.rs create mode 100644 tests/ui/async-await/async-closures/async-fn-mut-for-async-fn.rs create mode 100644 tests/ui/async-await/async-closures/async-fn-mut-for-async-fn.stderr create mode 100644 tests/ui/async-await/async-closures/drop.rs create mode 100644 tests/ui/async-await/async-closures/drop.run.stdout create mode 100644 tests/ui/async-await/async-closures/mangle.rs (limited to 'compiler/rustc_const_eval/src') diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs index b8d6836da14..85a2e4778d2 100644 --- a/compiler/rustc_const_eval/src/interpret/terminator.rs +++ b/compiler/rustc_const_eval/src/interpret/terminator.rs @@ -546,7 +546,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } - | ty::InstanceDef::CoroutineByMoveShim { .. } + | ty::InstanceDef::CoroutineKindShim { .. } | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index 3d6c28088ad..9475b89aa15 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -265,7 +265,7 @@ pub struct CoroutineInfo<'tcx> { /// The body of the coroutine, modified to take its upvars by move rather than by ref. /// /// This is used by coroutine-closures, which must return a different flavor of coroutine - /// when called using `AsyncFnOnce::call_once`. It is produced by the `ByMoveBody` which + /// when called using `AsyncFnOnce::call_once`. It is produced by the `ByMoveBody` pass which /// is run right after building the initial MIR, and will only be populated for coroutines /// which come out of the async closure desugaring. /// @@ -274,6 +274,13 @@ pub struct CoroutineInfo<'tcx> { /// using `run_passes`. pub by_move_body: Option>, + /// The body of the coroutine, modified to take its upvars by mutable ref rather than by + /// immutable ref. + /// + /// FIXME(async_closures): This is literally the same body as the parent body. Find a better + /// way to represent the by-mut signature (or cap the closure-kind of the coroutine). + pub by_mut_body: Option>, + /// The layout of a coroutine. This field is populated after the state transform pass. pub coroutine_layout: Option>, @@ -294,6 +301,7 @@ impl<'tcx> CoroutineInfo<'tcx> { yield_ty: Some(yield_ty), resume_ty: Some(resume_ty), by_move_body: None, + by_mut_body: None, coroutine_drop: None, coroutine_layout: None, } @@ -604,6 +612,14 @@ impl<'tcx> Body<'tcx> { self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_drop.as_ref()) } + pub fn coroutine_by_move_body(&self) -> Option<&Body<'tcx>> { + self.coroutine.as_ref()?.by_move_body.as_ref() + } + + pub fn coroutine_by_mut_body(&self) -> Option<&Body<'tcx>> { + self.coroutine.as_ref()?.by_mut_body.as_ref() + } + #[inline] pub fn coroutine_kind(&self) -> Option { self.coroutine.as_ref().map(|coroutine| coroutine.coroutine_kind) diff --git a/compiler/rustc_middle/src/mir/mono.rs b/compiler/rustc_middle/src/mir/mono.rs index e6d1535fdf2..6937df7bb18 100644 --- a/compiler/rustc_middle/src/mir/mono.rs +++ b/compiler/rustc_middle/src/mir/mono.rs @@ -403,7 +403,7 @@ impl<'tcx> CodegenUnit<'tcx> { | InstanceDef::Virtual(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } - | InstanceDef::CoroutineByMoveShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs index ce1859d6ada..2c5ca82a4cd 100644 --- a/compiler/rustc_middle/src/mir/visit.rs +++ b/compiler/rustc_middle/src/mir/visit.rs @@ -346,7 +346,7 @@ macro_rules! make_mir_visitor { ty::InstanceDef::ThreadLocalShim(_def_id) | ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } | ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } | - ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: _def_id } | + ty::InstanceDef::CoroutineKindShim { coroutine_def_id: _def_id, target_kind: _ } | ty::InstanceDef::DropGlue(_def_id, None) => {} ty::InstanceDef::FnPtrShim(_def_id, ty) | diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs index 1c9415ef3b0..9c1f4b20d2c 100644 --- a/compiler/rustc_middle/src/ty/instance.rs +++ b/compiler/rustc_middle/src/ty/instance.rs @@ -102,10 +102,12 @@ pub enum InstanceDef<'tcx> { }, /// `<[coroutine] as Future>::poll`, but for coroutines produced when `AsyncFnOnce` - /// is called on a coroutine-closure whose closure kind is not `FnOnce`. This - /// will select the body that is produced by the `ByMoveBody` transform, and thus + /// is called on a coroutine-closure whose closure kind greater than `FnOnce`, or + /// similarly for `AsyncFnMut`. + /// + /// This will select the body that is produced by the `ByMoveBody` transform, and thus /// take and use all of its upvars by-move rather than by-ref. - CoroutineByMoveShim { coroutine_def_id: DefId }, + CoroutineKindShim { coroutine_def_id: DefId, target_kind: ty::ClosureKind }, /// Compiler-generated accessor for thread locals which returns a reference to the thread local /// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking @@ -192,7 +194,7 @@ impl<'tcx> InstanceDef<'tcx> { coroutine_closure_def_id: def_id, target_kind: _, } - | ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: def_id } + | ty::InstanceDef::CoroutineKindShim { coroutine_def_id: def_id, target_kind: _ } | InstanceDef::DropGlue(def_id, _) | InstanceDef::CloneShim(def_id, _) | InstanceDef::FnPtrAddrShim(def_id, _) => def_id, @@ -213,7 +215,7 @@ impl<'tcx> InstanceDef<'tcx> { | InstanceDef::Intrinsic(..) | InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } - | ty::InstanceDef::CoroutineByMoveShim { .. } + | ty::InstanceDef::CoroutineKindShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::FnPtrAddrShim(..) => None, @@ -310,7 +312,7 @@ impl<'tcx> InstanceDef<'tcx> { | InstanceDef::DropGlue(_, Some(_)) => false, InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } - | InstanceDef::CoroutineByMoveShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::Item(_) | InstanceDef::Intrinsic(..) @@ -349,7 +351,7 @@ fn fmt_instance( InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"), InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"), InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"), - InstanceDef::CoroutineByMoveShim { .. } => write!(f, " - shim"), + InstanceDef::CoroutineKindShim { .. } => write!(f, " - shim"), InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"), InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"), InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"), @@ -651,13 +653,11 @@ impl<'tcx> Instance<'tcx> { if args.as_coroutine().kind_ty() == id_args.as_coroutine().kind_ty() { Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args }) } else { - assert_eq!( - args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(), - ty::ClosureKind::FnOnce, - "FIXME(async_closures): Generate a by-mut body here." - ); Some(Instance { - def: ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id }, + def: ty::InstanceDef::CoroutineKindShim { + coroutine_def_id, + target_kind: args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(), + }, args, }) } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 9ceb3ec3f61..c9137f374a2 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -1681,7 +1681,7 @@ impl<'tcx> TyCtxt<'tcx> { | ty::InstanceDef::Virtual(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } - | ty::InstanceDef::CoroutineByMoveShim { .. } + | ty::InstanceDef::CoroutineKindShim { .. } | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) | ty::InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_middle/src/ty/print/mod.rs b/compiler/rustc_middle/src/ty/print/mod.rs index 7026d2af298..19f8ba124f1 100644 --- a/compiler/rustc_middle/src/ty/print/mod.rs +++ b/compiler/rustc_middle/src/ty/print/mod.rs @@ -3,6 +3,7 @@ use crate::ty::{self, Ty, TyCtxt}; use rustc_data_structures::fx::FxHashSet; use rustc_data_structures::sso::SsoHashSet; +use rustc_hir as hir; use rustc_hir::def_id::{CrateNum, DefId, LocalDefId}; use rustc_hir::definitions::{DefPathData, DisambiguatedDefPathData}; @@ -130,8 +131,24 @@ pub trait Printer<'tcx>: Sized { parent_args = &args[..generics.parent_count.min(args.len())]; match key.disambiguated_data.data { - // Closures' own generics are only captures, don't print them. - DefPathData::Closure => {} + DefPathData::Closure => { + // FIXME(async_closures): This is somewhat ugly. + // We need to additionally print the `kind` field of a closure if + // it is desugared from a coroutine-closure. + if let Some(hir::CoroutineKind::Desugared( + _, + hir::CoroutineSource::Closure, + )) = self.tcx().coroutine_kind(def_id) + && args.len() >= parent_args.len() + 1 + { + return self.path_generic_args( + |cx| cx.print_def_path(def_id, parent_args), + &args[..parent_args.len() + 1][..1], + ); + } else { + // Closures' own generics are only captures, don't print them. + } + } // This covers both `DefKind::AnonConst` and `DefKind::InlineConst`. // Anon consts doesn't have their own generics, and inline consts' own // generics are their inferred types, so don't print them. diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs index 1cc0a5026d1..fcd4715b9e8 100644 --- a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -6,7 +6,7 @@ use rustc_data_structures::fx::FxIndexSet; use rustc_hir as hir; use rustc_middle::mir::visit::MutVisitor; -use rustc_middle::mir::{self, MirPass}; +use rustc_middle::mir::{self, dump_mir, MirPass}; use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt}; use rustc_target::abi::FieldIdx; @@ -24,7 +24,9 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { }; let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() }; - if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce { + + let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap(); + if coroutine_kind == ty::ClosureKind::FnOnce { return; } @@ -58,14 +60,49 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody { let mut by_move_body = body.clone(); MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body); + dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(())); by_move_body.source = mir::MirSource { - instance: InstanceDef::CoroutineByMoveShim { + instance: InstanceDef::CoroutineKindShim { coroutine_def_id: coroutine_def_id.to_def_id(), + target_kind: ty::ClosureKind::FnOnce, }, promoted: None, }; - body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body); + + // If this is coming from an `AsyncFn` coroutine-closure, we must also create a by-mut body. + // This is actually just a copy of the by-ref body, but with a different self type. + // FIXME(async_closures): We could probably unify this with the by-ref body somehow. + if coroutine_kind == ty::ClosureKind::Fn { + let by_mut_coroutine_ty = Ty::new_coroutine( + tcx, + coroutine_def_id.to_def_id(), + ty::CoroutineArgs::new( + tcx, + ty::CoroutineArgsParts { + parent_args: args.as_coroutine().parent_args(), + kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnMut), + resume_ty: args.as_coroutine().resume_ty(), + yield_ty: args.as_coroutine().yield_ty(), + return_ty: args.as_coroutine().return_ty(), + witness: args.as_coroutine().witness(), + tupled_upvars_ty: args.as_coroutine().tupled_upvars_ty(), + }, + ) + .args, + ); + let mut by_mut_body = body.clone(); + by_mut_body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty = by_mut_coroutine_ty; + dump_mir(tcx, false, "coroutine_by_mut", &0, &by_mut_body, |_, _| Ok(())); + by_mut_body.source = mir::MirSource { + instance: InstanceDef::CoroutineKindShim { + coroutine_def_id: coroutine_def_id.to_def_id(), + target_kind: ty::ClosureKind::FnMut, + }, + promoted: None, + }; + body.coroutine.as_mut().unwrap().by_mut_body = Some(by_mut_body); + } } } diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 24bc84a235c..e77553a03d6 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -318,7 +318,7 @@ impl<'tcx> Inliner<'tcx> { | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } - | InstanceDef::CoroutineByMoveShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index 77ff780393e..5b03bc361dd 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -88,7 +88,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } - | InstanceDef::CoroutineByMoveShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::ThreadLocalShim { .. } | InstanceDef::CloneShim(..) => {} diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs index c7e770904fb..605e1ad46d7 100644 --- a/compiler/rustc_mir_transform/src/pass_manager.rs +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -190,10 +190,13 @@ fn run_passes_inner<'tcx>( body.pass_count = 1; } - if let Some(coroutine) = body.coroutine.as_mut() - && let Some(by_move_body) = coroutine.by_move_body.as_mut() - { - run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each); + if let Some(coroutine) = body.coroutine.as_mut() { + if let Some(by_move_body) = coroutine.by_move_body.as_mut() { + run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each); + } + if let Some(by_mut_body) = coroutine.by_mut_body.as_mut() { + run_passes_inner(tcx, by_mut_body, passes, phase_change, validate_each); + } } } diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 668ccdd8735..7b6de3a5439 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -72,32 +72,70 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' } => match target_kind { ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"), ty::ClosureKind::FnMut => { - let body = build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id); - // No need to optimize the body, it has already been optimized. - return body; + // No need to optimize the body, it has already been optimized + // since we steal it from the `AsyncFn::call` body and just fix + // the return type. + return build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id); } ty::ClosureKind::FnOnce => { build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id) } }, - ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id } => { - return tcx - .optimized_mir(coroutine_def_id) - .coroutine - .as_ref() - .unwrap() - .by_move_body - .as_ref() - .unwrap() - .clone(); - } + ty::InstanceDef::CoroutineKindShim { coroutine_def_id, target_kind } => match target_kind { + ty::ClosureKind::Fn => unreachable!(), + ty::ClosureKind::FnMut => { + return tcx + .optimized_mir(coroutine_def_id) + .coroutine_by_mut_body() + .unwrap() + .clone(); + } + ty::ClosureKind::FnOnce => { + return tcx + .optimized_mir(coroutine_def_id) + .coroutine_by_move_body() + .unwrap() + .clone(); + } + }, ty::InstanceDef::DropGlue(def_id, ty) => { // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end // of this function. Is this intentional? if let Some(ty::Coroutine(coroutine_def_id, args)) = ty.map(Ty::kind) { - let body = tcx.optimized_mir(*coroutine_def_id).coroutine_drop().unwrap(); + let coroutine_body = tcx.optimized_mir(*coroutine_def_id); + + let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind() + else { + bug!() + }; + + // If this is a regular coroutine, grab its drop shim. If this is a coroutine + // that comes from a coroutine-closure, and the kind ty differs from the "maximum" + // kind that it supports, then grab the appropriate drop shim. This ensures that + // the future returned by `<[coroutine-closure] as AsyncFnOnce>::call_once` will + // drop the coroutine-closure's upvars. + let body = if id_args.as_coroutine().kind_ty() == args.as_coroutine().kind_ty() { + coroutine_body.coroutine_drop().unwrap() + } else { + match args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() { + ty::ClosureKind::Fn => { + unreachable!() + } + ty::ClosureKind::FnMut => coroutine_body + .coroutine_by_mut_body() + .unwrap() + .coroutine_drop() + .unwrap(), + ty::ClosureKind::FnOnce => coroutine_body + .coroutine_by_move_body() + .unwrap() + .coroutine_drop() + .unwrap(), + } + }; + let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args); debug!("make_shim({:?}) = {:?}", instance, body); @@ -1076,7 +1114,11 @@ fn build_construct_coroutine_by_move_shim<'tcx>( target_kind: ty::ClosureKind::FnOnce, }); - new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span) + let body = + new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span); + dump_mir(tcx, false, "coroutine_closure_by_move", &0, &body, |_, _| Ok(())); + + body } fn build_construct_coroutine_by_mut_shim<'tcx>( @@ -1110,5 +1152,8 @@ fn build_construct_coroutine_by_mut_shim<'tcx>( target_kind: ty::ClosureKind::FnMut, }); + body.pass_count = 0; + dump_mir(tcx, false, "coroutine_closure_by_mut", &0, &body, |_, _| Ok(())); + body } diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index cf3c8e1fdd3..3376af98653 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -984,7 +984,7 @@ fn visit_instance_use<'tcx>( | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } - | ty::InstanceDef::CoroutineByMoveShim { .. } + | ty::InstanceDef::CoroutineKindShim { .. } | ty::InstanceDef::Item(..) | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 22b35c4344b..4f6ea66df7b 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -621,7 +621,7 @@ fn characteristic_def_id_of_mono_item<'tcx>( | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } - | ty::InstanceDef::CoroutineByMoveShim { .. } + | ty::InstanceDef::CoroutineKindShim { .. } | ty::InstanceDef::Intrinsic(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::Virtual(..) @@ -786,7 +786,7 @@ fn mono_item_visibility<'tcx>( | InstanceDef::Intrinsic(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } - | InstanceDef::CoroutineByMoveShim { .. } + | InstanceDef::CoroutineKindShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::FnPtrAddrShim(..) => return Visibility::Hidden, diff --git a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs index 066348dcb67..959a17d24b7 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs @@ -800,7 +800,7 @@ impl<'tcx> Stable<'tcx> for ty::Instance<'tcx> { | ty::InstanceDef::FnPtrAddrShim(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } - | ty::InstanceDef::CoroutineByMoveShim { .. } + | ty::InstanceDef::CoroutineKindShim { .. } | ty::InstanceDef::ThreadLocalShim(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_symbol_mangling/src/legacy.rs b/compiler/rustc_symbol_mangling/src/legacy.rs index 5af9503087a..646649293fc 100644 --- a/compiler/rustc_symbol_mangling/src/legacy.rs +++ b/compiler/rustc_symbol_mangling/src/legacy.rs @@ -64,16 +64,29 @@ pub(super) fn mangle<'tcx>( ) .unwrap(); - if let ty::InstanceDef::ThreadLocalShim(..) = instance.def { - let _ = printer.write_str("{{tls-shim}}"); - } - - if let ty::InstanceDef::VTableShim(..) = instance.def { - let _ = printer.write_str("{{vtable-shim}}"); - } - - if let ty::InstanceDef::ReifyShim(..) = instance.def { - let _ = printer.write_str("{{reify-shim}}"); + match instance.def { + ty::InstanceDef::ThreadLocalShim(..) => { + printer.write_str("{{tls-shim}}").unwrap(); + } + ty::InstanceDef::VTableShim(..) => { + printer.write_str("{{vtable-shim}}").unwrap(); + } + ty::InstanceDef::ReifyShim(..) => { + printer.write_str("{{reify-shim}}").unwrap(); + } + // FIXME(async_closures): This shouldn't be needed when we fix + // `Instance::ty`/`Instance::def_id`. + ty::InstanceDef::ConstructCoroutineInClosureShim { target_kind, .. } + | ty::InstanceDef::CoroutineKindShim { target_kind, .. } => match target_kind { + ty::ClosureKind::Fn => unreachable!(), + ty::ClosureKind::FnMut => { + printer.write_str("{{fn-mut-shim}}").unwrap(); + } + ty::ClosureKind::FnOnce => { + printer.write_str("{{fn-once-shim}}").unwrap(); + } + }, + _ => {} } printer.path.finish(hash) diff --git a/compiler/rustc_symbol_mangling/src/v0.rs b/compiler/rustc_symbol_mangling/src/v0.rs index d380cb9a19b..530221555c5 100644 --- a/compiler/rustc_symbol_mangling/src/v0.rs +++ b/compiler/rustc_symbol_mangling/src/v0.rs @@ -46,6 +46,13 @@ pub(super) fn mangle<'tcx>( ty::InstanceDef::VTableShim(_) => Some("vtable"), ty::InstanceDef::ReifyShim(_) => Some("reify"), + ty::InstanceDef::ConstructCoroutineInClosureShim { target_kind, .. } + | ty::InstanceDef::CoroutineKindShim { target_kind, .. } => match target_kind { + ty::ClosureKind::Fn => unreachable!(), + ty::ClosureKind::FnMut => Some("fn_mut"), + ty::ClosureKind::FnOnce => Some("fn_once"), + }, + _ => None, }; diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs index e023283a709..7b95b4d03a6 100644 --- a/compiler/rustc_ty_utils/src/abi.rs +++ b/compiler/rustc_ty_utils/src/abi.rs @@ -145,7 +145,7 @@ fn fn_sig_for_fn_abi<'tcx>( ) } ty::Coroutine(did, args) => { - // FIXME(async_closures): This isn't right for `CoroutineByMoveShim`. + // FIXME(async_closures): This isn't right for `CoroutineKindShim`. let coroutine_kind = tcx.coroutine_kind(did).unwrap(); let sig = args.as_coroutine().sig(); diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir new file mode 100644 index 00000000000..1fae40c5f40 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-abort.mir @@ -0,0 +1,47 @@ +// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move + +fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10}, _2: ResumeTy) -> () +yields () + { + debug _task_context => _2; + debug a => (_1.0: i32); + debug b => (_1.1: i32); + let mut _0: (); + let _3: i32; + scope 1 { + debug a => _3; + let _4: &i32; + scope 2 { + debug a => _4; + let _5: &i32; + scope 3 { + debug b => _5; + } + } + } + + bb0: { + StorageLive(_3); + _3 = (_1.0: i32); + FakeRead(ForLet(None), _3); + StorageLive(_4); + _4 = &_3; + FakeRead(ForLet(None), _4); + StorageLive(_5); + _5 = &(_1.1: i32); + FakeRead(ForLet(None), _5); + _0 = const (); + StorageDead(_5); + StorageDead(_4); + StorageDead(_3); + drop(_1) -> [return: bb1, unwind: bb2]; + } + + bb1: { + return; + } + + bb2 (cleanup): { + resume; + } +} diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir new file mode 100644 index 00000000000..1fae40c5f40 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.panic-unwind.mir @@ -0,0 +1,47 @@ +// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move + +fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10}, _2: ResumeTy) -> () +yields () + { + debug _task_context => _2; + debug a => (_1.0: i32); + debug b => (_1.1: i32); + let mut _0: (); + let _3: i32; + scope 1 { + debug a => _3; + let _4: &i32; + scope 2 { + debug a => _4; + let _5: &i32; + scope 3 { + debug b => _5; + } + } + } + + bb0: { + StorageLive(_3); + _3 = (_1.0: i32); + FakeRead(ForLet(None), _3); + StorageLive(_4); + _4 = &_3; + FakeRead(ForLet(None), _4); + StorageLive(_5); + _5 = &(_1.1: i32); + FakeRead(ForLet(None), _5); + _0 = const (); + StorageDead(_5); + StorageDead(_4); + StorageDead(_3); + drop(_1) -> [return: bb1, unwind: bb2]; + } + + bb1: { + return; + } + + bb2 (cleanup): { + resume; + } +} diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_mut.0.panic-abort.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_mut.0.panic-abort.mir new file mode 100644 index 00000000000..9886d6f68a4 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_mut.0.panic-abort.mir @@ -0,0 +1,47 @@ +// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_mut + +fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10}, _2: ResumeTy) -> () +yields () + { + debug _task_context => _2; + debug a => (_1.0: i32); + debug b => (*(_1.1: &i32)); + let mut _0: (); + let _3: i32; + scope 1 { + debug a => _3; + let _4: &i32; + scope 2 { + debug a => _4; + let _5: &i32; + scope 3 { + debug b => _5; + } + } + } + + bb0: { + StorageLive(_3); + _3 = (_1.0: i32); + FakeRead(ForLet(None), _3); + StorageLive(_4); + _4 = &_3; + FakeRead(ForLet(None), _4); + StorageLive(_5); + _5 = &(*(_1.1: &i32)); + FakeRead(ForLet(None), _5); + _0 = const (); + StorageDead(_5); + StorageDead(_4); + StorageDead(_3); + drop(_1) -> [return: bb1, unwind: bb2]; + } + + bb1: { + return; + } + + bb2 (cleanup): { + resume; + } +} diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_mut.0.panic-unwind.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_mut.0.panic-unwind.mir new file mode 100644 index 00000000000..9886d6f68a4 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_mut.0.panic-unwind.mir @@ -0,0 +1,47 @@ +// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_mut + +fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10}, _2: ResumeTy) -> () +yields () + { + debug _task_context => _2; + debug a => (_1.0: i32); + debug b => (*(_1.1: &i32)); + let mut _0: (); + let _3: i32; + scope 1 { + debug a => _3; + let _4: &i32; + scope 2 { + debug a => _4; + let _5: &i32; + scope 3 { + debug b => _5; + } + } + } + + bb0: { + StorageLive(_3); + _3 = (_1.0: i32); + FakeRead(ForLet(None), _3); + StorageLive(_4); + _4 = &_3; + FakeRead(ForLet(None), _4); + StorageLive(_5); + _5 = &(*(_1.1: &i32)); + FakeRead(ForLet(None), _5); + _0 = const (); + StorageDead(_5); + StorageDead(_4); + StorageDead(_3); + drop(_1) -> [return: bb1, unwind: bb2]; + } + + bb1: { + return; + } + + bb2 (cleanup): { + resume; + } +} diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir new file mode 100644 index 00000000000..7df4eb49260 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-abort.mir @@ -0,0 +1,10 @@ +// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move + +fn main::{closure#0}::{closure#0}(_1: {coroutine-closure@$DIR/async_closure_shims.rs:39:33: 39:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10} { + let mut _0: {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10}; + + bb0: { + _0 = {coroutine@$DIR/async_closure_shims.rs:39:53: 42:10 (#0)} { a: move _2, b: move (_1.0: i32) }; + return; + } +} diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir new file mode 100644 index 00000000000..7df4eb49260 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.panic-unwind.mir @@ -0,0 +1,10 @@ +// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move + +fn main::{closure#0}::{closure#0}(_1: {coroutine-closure@$DIR/async_closure_shims.rs:39:33: 39:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10} { + let mut _0: {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10}; + + bb0: { + _0 = {coroutine@$DIR/async_closure_shims.rs:39:53: 42:10 (#0)} { a: move _2, b: move (_1.0: i32) }; + return; + } +} diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_mut.0.panic-abort.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_mut.0.panic-abort.mir new file mode 100644 index 00000000000..517b8d0dd88 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_mut.0.panic-abort.mir @@ -0,0 +1,16 @@ +// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_mut + +fn main::{closure#0}::{closure#0}(_1: &mut {coroutine-closure@$DIR/async_closure_shims.rs:39:33: 39:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10} { + debug a => _2; + debug b => ((*_1).0: i32); + let mut _0: {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10}; + let mut _3: &i32; + + bb0: { + StorageLive(_3); + _3 = &((*_1).0: i32); + _0 = {coroutine@$DIR/async_closure_shims.rs:39:53: 42:10 (#0)} { a: _2, b: move _3 }; + StorageDead(_3); + return; + } +} diff --git a/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_mut.0.panic-unwind.mir b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_mut.0.panic-unwind.mir new file mode 100644 index 00000000000..517b8d0dd88 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_mut.0.panic-unwind.mir @@ -0,0 +1,16 @@ +// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_mut + +fn main::{closure#0}::{closure#0}(_1: &mut {coroutine-closure@$DIR/async_closure_shims.rs:39:33: 39:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10} { + debug a => _2; + debug b => ((*_1).0: i32); + let mut _0: {async closure body@$DIR/async_closure_shims.rs:39:53: 42:10}; + let mut _3: &i32; + + bb0: { + StorageLive(_3); + _3 = &((*_1).0: i32); + _0 = {coroutine@$DIR/async_closure_shims.rs:39:53: 42:10 (#0)} { a: _2, b: move _3 }; + StorageDead(_3); + return; + } +} diff --git a/tests/mir-opt/async_closure_shims.rs b/tests/mir-opt/async_closure_shims.rs new file mode 100644 index 00000000000..ef3bdaaa145 --- /dev/null +++ b/tests/mir-opt/async_closure_shims.rs @@ -0,0 +1,46 @@ +// edition:2021 +// skip-filecheck +// EMIT_MIR_FOR_EACH_PANIC_STRATEGY + +#![feature(async_closure, noop_waker, async_fn_traits)] + +use std::future::Future; +use std::ops::{AsyncFnMut, AsyncFnOnce}; +use std::pin::pin; +use std::task::*; + +pub fn block_on(fut: impl Future) -> T { + let mut fut = pin!(fut); + let ctx = &mut Context::from_waker(Waker::noop()); + + loop { + match fut.as_mut().poll(ctx) { + Poll::Pending => {} + Poll::Ready(t) => break t, + } + } +} + +async fn call_mut(f: &mut impl AsyncFnMut(i32)) { + f(0).await; +} + +async fn call_once(f: impl AsyncFnOnce(i32)) { + f(1).await; +} + +// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_move.0.mir +// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}.coroutine_closure_by_mut.0.mir +// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_mut.0.mir +// EMIT_MIR async_closure_shims.main-{closure#0}-{closure#0}-{closure#0}.coroutine_by_move.0.mir +fn main() { + block_on(async { + let b = 2i32; + let mut async_closure = async move |a: i32| { + let a = &a; + let b = &b; + }; + call_mut(&mut async_closure).await; + call_once(async_closure).await; + }); +} diff --git a/tests/ui/async-await/async-closures/async-fn-mut-for-async-fn.rs b/tests/ui/async-await/async-closures/async-fn-mut-for-async-fn.rs new file mode 100644 index 00000000000..8d7dc6a276b --- /dev/null +++ b/tests/ui/async-await/async-closures/async-fn-mut-for-async-fn.rs @@ -0,0 +1,23 @@ +// aux-build:block-on.rs +// edition:2021 +// run-pass + +// FIXME(async_closures): When `fn_sig_for_fn_abi` is fixed, remove this. +// ignore-pass (test emits codegen-time warnings) + +#![feature(async_closure, async_fn_traits)] + +extern crate block_on; + +use std::ops::AsyncFnMut; + +fn main() { + block_on::block_on(async { + let x = async || {}; + + async fn needs_async_fn_mut(mut x: impl AsyncFnMut()) { + x().await; + } + needs_async_fn_mut(x).await; + }); +} diff --git a/tests/ui/async-await/async-closures/async-fn-mut-for-async-fn.stderr b/tests/ui/async-await/async-closures/async-fn-mut-for-async-fn.stderr new file mode 100644 index 00000000000..48917e8b23f --- /dev/null +++ b/tests/ui/async-await/async-closures/async-fn-mut-for-async-fn.stderr @@ -0,0 +1 @@ +WARN rustc_codegen_ssa::mir::locals Unexpected initial operand type: expected std::pin::Pin<&ReErased mut Coroutine(DefId(0:8 ~ async_fn_mut_for_async_fn[3241]::main::{closure#0}::{closure#0}::{closure#0}), [i16, std::future::ResumeTy, (), (), CoroutineWitness(DefId(0:8 ~ async_fn_mut_for_async_fn[3241]::main::{closure#0}::{closure#0}::{closure#0}), []), ()])>, found std::pin::Pin<&ReErased mut Coroutine(DefId(0:8 ~ async_fn_mut_for_async_fn[3241]::main::{closure#0}::{closure#0}::{closure#0}), [i8, std::future::ResumeTy, (), (), CoroutineWitness(DefId(0:8 ~ async_fn_mut_for_async_fn[3241]::main::{closure#0}::{closure#0}::{closure#0}), []), ()])>.See . diff --git a/tests/ui/async-await/async-closures/async-fn-once-for-async-fn.stderr b/tests/ui/async-await/async-closures/async-fn-once-for-async-fn.stderr index 9ae4692f003..978a5a653f9 100644 --- a/tests/ui/async-await/async-closures/async-fn-once-for-async-fn.stderr +++ b/tests/ui/async-await/async-closures/async-fn-once-for-async-fn.stderr @@ -1,2 +1 @@ -WARN rustc_codegen_ssa::mir::locals Unexpected initial operand type: expected std::pin::Pin<&ReErased mut Coroutine(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#1}), [i32, std::future::ResumeTy, (), (), CoroutineWitness(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#1}), []), ()])>, found std::pin::Pin<&ReErased mut Coroutine(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#1}), [i8, std::future::ResumeTy, (), (), CoroutineWitness(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#1}), []), ()])>.See . -WARN rustc_codegen_ssa::mir::locals Unexpected initial operand type: expected *mut Coroutine(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#1}), [i8, std::future::ResumeTy, (), (), CoroutineWitness(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#1}), []), ()]), found *mut Coroutine(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#1}), [i32, std::future::ResumeTy, (), (), CoroutineWitness(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#1}), []), ()]).See . +WARN rustc_codegen_ssa::mir::locals Unexpected initial operand type: expected std::pin::Pin<&ReErased mut Coroutine(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#0}::{closure#0}), [i32, std::future::ResumeTy, (), (), CoroutineWitness(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#0}::{closure#0}), []), ()])>, found std::pin::Pin<&ReErased mut Coroutine(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#0}::{closure#0}), [i8, std::future::ResumeTy, (), (), CoroutineWitness(DefId(0:8 ~ async_fn_once_for_async_fn[6cdf]::main::{closure#0}::{closure#0}::{closure#0}), []), ()])>.See . diff --git a/tests/ui/async-await/async-closures/def-path.rs b/tests/ui/async-await/async-closures/def-path.rs index 2883a1715b0..87e99ddda64 100644 --- a/tests/ui/async-await/async-closures/def-path.rs +++ b/tests/ui/async-await/async-closures/def-path.rs @@ -8,7 +8,7 @@ fn main() { //~^ NOTE the expected `async` closure body let () = x(); //~^ ERROR mismatched types - //~| NOTE this expression has type `{static main::{closure#0}::{closure#0} upvar_tys= + //~| NOTE this expression has type `{static main::{closure#0}::{closure#0}< //~| NOTE expected `async` closure body, found `()` - //~| NOTE expected `async` closure body `{static main::{closure#0}::{closure#0} + //~| NOTE expected `async` closure body `{static main::{closure#0}::{closure#0}< } diff --git a/tests/ui/async-await/async-closures/def-path.stderr b/tests/ui/async-await/async-closures/def-path.stderr index 4b37e50aac4..dae45825f37 100644 --- a/tests/ui/async-await/async-closures/def-path.stderr +++ b/tests/ui/async-await/async-closures/def-path.stderr @@ -5,11 +5,11 @@ LL | let x = async || {}; | -- the expected `async` closure body LL | LL | let () = x(); - | ^^ --- this expression has type `{static main::{closure#0}::{closure#0} upvar_tys=?7t witness=?8t}` + | ^^ --- this expression has type `{static main::{closure#0}::{closure#0} upvar_tys=?15t witness=?6t}` | | | expected `async` closure body, found `()` | - = note: expected `async` closure body `{static main::{closure#0}::{closure#0} upvar_tys=?7t witness=?8t}` + = note: expected `async` closure body `{static main::{closure#0}::{closure#0} upvar_tys=?15t witness=?6t}` found unit type `()` error: aborting due to 1 previous error diff --git a/tests/ui/async-await/async-closures/drop.rs b/tests/ui/async-await/async-closures/drop.rs new file mode 100644 index 00000000000..1b7f2f8a600 --- /dev/null +++ b/tests/ui/async-await/async-closures/drop.rs @@ -0,0 +1,40 @@ +// aux-build:block-on.rs +// edition:2018 +// run-pass +// check-run-results + +#![feature(async_closure, async_fn_traits)] +#![allow(unused)] + +extern crate block_on; + +use std::ops::AsyncFnOnce; + +struct DropMe(i32); + +impl Drop for DropMe { + fn drop(&mut self) { + println!("{} was dropped", self.0); + } +} + +async fn call_once(f: impl AsyncFnOnce()) { + println!("before call"); + let fut = Box::pin(f()); + println!("after call"); + drop(fut); + println!("future dropped"); +} + +fn main() { + block_on::block_on(async { + let d = DropMe(42); + let async_closure = async move || { + let d = &d; + println!("called"); + }; + + call_once(async_closure).await; + println!("after"); + }); +} diff --git a/tests/ui/async-await/async-closures/drop.run.stdout b/tests/ui/async-await/async-closures/drop.run.stdout new file mode 100644 index 00000000000..ab233f491ba --- /dev/null +++ b/tests/ui/async-await/async-closures/drop.run.stdout @@ -0,0 +1,5 @@ +before call +after call +42 was dropped +future dropped +after diff --git a/tests/ui/async-await/async-closures/mangle.rs b/tests/ui/async-await/async-closures/mangle.rs new file mode 100644 index 00000000000..98065c3c711 --- /dev/null +++ b/tests/ui/async-await/async-closures/mangle.rs @@ -0,0 +1,36 @@ +// aux-build:block-on.rs +// edition:2021 +// build-pass +// revisions: v0 legacy +//[v0] compile-flags: -Csymbol-mangling-version=v0 +//[legacy] compile-flags: -Csymbol-mangling-version=legacy -Zunstable-options + +// FIXME(async_closures): When `fn_sig_for_fn_abi` is fixed, remove this. +// ignore-pass (test emits codegen-time warnings) + +#![feature(async_closure, noop_waker, async_fn_traits)] + +extern crate block_on; + +use std::future::Future; +use std::ops::{AsyncFnMut, AsyncFnOnce}; +use std::pin::pin; +use std::task::*; + +async fn call_mut(f: &mut impl AsyncFnMut()) { + f().await; +} + +async fn call_once(f: impl AsyncFnOnce()) { + f().await; +} + +fn main() { + block_on::block_on(async { + let mut async_closure = async move || { + println!("called"); + }; + call_mut(&mut async_closure).await; + call_once(async_closure).await; + }); +} -- cgit 1.4.1-3-g733a5