diff options
Diffstat (limited to 'compiler/rustc_trait_selection/src')
4 files changed, 481 insertions, 201 deletions
diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs index 8dec04e2c4f..819b070cf8b 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs @@ -323,34 +323,27 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc self_ty: Ty<'tcx>, goal_kind: ty::ClosureKind, env_region: ty::Region<'tcx>, -) -> Result< - (ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Option<ty::Predicate<'tcx>>), - NoSolution, -> { +) -> Result<(ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Vec<ty::Predicate<'tcx>>), NoSolution> +{ match *self_ty.kind() { ty::CoroutineClosure(def_id, args) => { let args = args.as_coroutine_closure(); let kind_ty = args.kind_ty(); - - if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { + let sig = args.coroutine_closure_sig().skip_binder(); + let mut nested = vec![]; + let coroutine_ty = if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { if !closure_kind.extends(goal_kind) { return Err(NoSolution); } - Ok(( - args.coroutine_closure_sig().map_bound(|sig| { - let coroutine_ty = sig.to_coroutine_given_kind_and_upvars( - tcx, - args.parent_args(), - tcx.coroutine_for_closure(def_id), - goal_kind, - env_region, - args.tupled_upvars_ty(), - args.coroutine_captures_by_ref_ty(), - ); - (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty) - }), - None, - )) + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + goal_kind, + env_region, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + ) } else { let async_fn_kind_trait_def_id = tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); @@ -367,42 +360,117 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars` // will project to the right upvars for the generator, appending the inputs and // coroutine upvars respecting the closure kind. - Ok(( - args.coroutine_closure_sig().map_bound(|sig| { - let tupled_upvars_ty = Ty::new_projection( - tcx, - upvars_projection_def_id, - [ - ty::GenericArg::from(kind_ty), - Ty::from_closure_kind(tcx, goal_kind).into(), - env_region.into(), - sig.tupled_inputs_ty.into(), - args.tupled_upvars_ty().into(), - args.coroutine_captures_by_ref_ty().into(), - ], - ); - let coroutine_ty = sig.to_coroutine( - tcx, - args.parent_args(), - Ty::from_closure_kind(tcx, goal_kind), - tcx.coroutine_for_closure(def_id), - tupled_upvars_ty, - ); - (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty) - }), - Some( - ty::TraitRef::new( - tcx, - async_fn_kind_trait_def_id, - [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], - ) - .to_predicate(tcx), - ), - )) - } + nested.push( + ty::TraitRef::new( + tcx, + async_fn_kind_trait_def_id, + [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], + ) + .to_predicate(tcx), + ); + let tupled_upvars_ty = Ty::new_projection( + tcx, + upvars_projection_def_id, + [ + ty::GenericArg::from(kind_ty), + Ty::from_closure_kind(tcx, goal_kind).into(), + env_region.into(), + sig.tupled_inputs_ty.into(), + args.tupled_upvars_ty().into(), + args.coroutine_captures_by_ref_ty().into(), + ], + ); + sig.to_coroutine( + tcx, + args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), + tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ) + }; + + Ok(( + args.coroutine_closure_sig().rebind(( + sig.tupled_inputs_ty, + sig.return_ty, + coroutine_ty, + )), + nested, + )) } - ty::FnDef(..) | ty::FnPtr(..) | ty::Closure(..) => Err(NoSolution), + ty::FnDef(..) | ty::FnPtr(..) => { + let bound_sig = self_ty.fn_sig(tcx); + let sig = bound_sig.skip_binder(); + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + // `FnDef` and `FnPtr` only implement `AsyncFn*` when their + // return type implements `Future`. + let nested = vec![ + bound_sig + .rebind(ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()])) + .to_predicate(tcx), + ]; + let future_output_def_id = tcx + .associated_items(future_trait_def_id) + .filter_by_name_unhygienic(sym::Output) + .next() + .unwrap() + .def_id; + let future_output_ty = Ty::new_projection(tcx, future_output_def_id, [sig.output()]); + Ok(( + bound_sig.rebind((Ty::new_tup(tcx, sig.inputs()), sig.output(), future_output_ty)), + nested, + )) + } + ty::Closure(_, args) => { + let args = args.as_closure(); + let bound_sig = args.sig(); + let sig = bound_sig.skip_binder(); + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + // `Closure`s only implement `AsyncFn*` when their return type + // implements `Future`. + let mut nested = vec![ + bound_sig + .rebind(ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()])) + .to_predicate(tcx), + ]; + + // Additionally, we need to check that the closure kind + // is still compatible. + let kind_ty = args.kind_ty(); + if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { + if !closure_kind.extends(goal_kind) { + return Err(NoSolution); + } + } else { + let async_fn_kind_trait_def_id = + tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); + // When we don't know the closure kind (and therefore also the closure's upvars, + // which are computed at the same time), we must delay the computation of the + // generator's upvars. We do this using the `AsyncFnKindHelper`, which as a trait + // goal functions similarly to the old `ClosureKind` predicate, and ensures that + // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars` + // will project to the right upvars for the generator, appending the inputs and + // coroutine upvars respecting the closure kind. + nested.push( + ty::TraitRef::new( + tcx, + async_fn_kind_trait_def_id, + [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], + ) + .to_predicate(tcx), + ); + } + + let future_output_def_id = tcx + .associated_items(future_trait_def_id) + .filter_by_name_unhygienic(sym::Output) + .next() + .unwrap() + .def_id; + let future_output_ty = Ty::new_projection(tcx, future_output_def_id, [sig.output()]); + Ok((bound_sig.rebind((sig.inputs()[0], sig.output(), future_output_ty)), nested)) + } ty::Bool | ty::Char diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 95f833372fb..054402acb5c 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -2087,7 +2087,9 @@ fn confirm_select_candidate<'cx, 'tcx>( } else if lang_items.async_iterator_trait() == Some(trait_def_id) { confirm_async_iterator_candidate(selcx, obligation, data) } else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() { - if obligation.predicate.self_ty().is_closure() { + if obligation.predicate.self_ty().is_closure() + || obligation.predicate.self_ty().is_coroutine_closure() + { confirm_closure_candidate(selcx, obligation, data) } else { confirm_fn_pointer_candidate(selcx, obligation, data) @@ -2410,11 +2412,75 @@ fn confirm_closure_candidate<'cx, 'tcx>( obligation: &ProjectionTyObligation<'tcx>, nested: Vec<PredicateObligation<'tcx>>, ) -> Progress<'tcx> { + let tcx = selcx.tcx(); let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty()); - let ty::Closure(_, args) = self_ty.kind() else { - unreachable!("expected closure self type for closure candidate, found {self_ty}") + let closure_sig = match *self_ty.kind() { + ty::Closure(_, args) => args.as_closure().sig(), + + // Construct a "normal" `FnOnce` signature for coroutine-closure. This is + // basically duplicated with the `AsyncFnOnce::CallOnce` confirmation, but + // I didn't see a good way to unify those. + ty::CoroutineClosure(def_id, args) => { + let args = args.as_coroutine_closure(); + let kind_ty = args.kind_ty(); + args.coroutine_closure_sig().map_bound(|sig| { + // If we know the kind and upvars, use that directly. + // Otherwise, defer to `AsyncFnKindHelper::Upvars` to delay + // the projection, like the `AsyncFn*` traits do. + let output_ty = if let Some(_) = kind_ty.to_opt_closure_kind() { + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + ty::ClosureKind::FnOnce, + tcx.lifetimes.re_static, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + ) + } else { + let async_fn_kind_trait_def_id = + tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); + let upvars_projection_def_id = tcx + .associated_items(async_fn_kind_trait_def_id) + .filter_by_name_unhygienic(sym::Upvars) + .next() + .unwrap() + .def_id; + let tupled_upvars_ty = Ty::new_projection( + tcx, + upvars_projection_def_id, + [ + ty::GenericArg::from(kind_ty), + Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce).into(), + tcx.lifetimes.re_static.into(), + sig.tupled_inputs_ty.into(), + args.tupled_upvars_ty().into(), + args.coroutine_captures_by_ref_ty().into(), + ], + ); + sig.to_coroutine( + tcx, + args.parent_args(), + Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce), + tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ) + }; + tcx.mk_fn_sig( + [sig.tupled_inputs_ty], + output_ty, + sig.c_variadic, + sig.unsafety, + sig.abi, + ) + }) + } + + _ => { + unreachable!("expected closure self type for closure candidate, found {self_ty}"); + } }; - let closure_sig = args.as_closure().sig(); + let Normalized { value: closure_sig, obligations } = normalize_with_depth( selcx, obligation.param_env, @@ -2470,126 +2536,171 @@ fn confirm_callable_candidate<'cx, 'tcx>( fn confirm_async_closure_candidate<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTyObligation<'tcx>, - mut nested: Vec<PredicateObligation<'tcx>>, + nested: Vec<PredicateObligation<'tcx>>, ) -> Progress<'tcx> { + let tcx = selcx.tcx(); let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty()); - let ty::CoroutineClosure(def_id, args) = *self_ty.kind() else { - unreachable!( - "expected coroutine-closure self type for coroutine-closure candidate, found {self_ty}" - ) - }; - let args = args.as_coroutine_closure(); - let kind_ty = args.kind_ty(); - let tcx = selcx.tcx(); let goal_kind = tcx.async_fn_trait_kind_from_def_id(obligation.predicate.trait_def_id(tcx)).unwrap(); - - let async_fn_kind_helper_trait_def_id = - tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); - nested.push(obligation.with( - tcx, - ty::TraitRef::new( - tcx, - async_fn_kind_helper_trait_def_id, - [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], - ), - )); - let env_region = match goal_kind { ty::ClosureKind::Fn | ty::ClosureKind::FnMut => obligation.predicate.args.region_at(2), ty::ClosureKind::FnOnce => tcx.lifetimes.re_static, }; - - let upvars_projection_def_id = tcx - .associated_items(async_fn_kind_helper_trait_def_id) - .filter_by_name_unhygienic(sym::Upvars) - .next() - .unwrap() - .def_id; - - // FIXME(async_closures): Confirmation is kind of a mess here. Ideally, - // we'd short-circuit when we know that the goal_kind >= closure_kind, and not - // register a nested predicate or create a new projection ty here. But I'm too - // lazy to make this more efficient atm, and we can always tweak it later, - // since all this does is make the solver do more work. - // - // The code duplication due to the different length args is kind of weird, too. - // - // See the logic in `structural_traits` in the new solver to understand a bit - // more clearly how this *should* look. - let poly_cache_entry = args.coroutine_closure_sig().map_bound(|sig| { - let (projection_ty, term) = match tcx.item_name(obligation.predicate.def_id) { - sym::CallOnceFuture => { - let tupled_upvars_ty = Ty::new_projection( + let item_name = tcx.item_name(obligation.predicate.def_id); + + let poly_cache_entry = match *self_ty.kind() { + ty::CoroutineClosure(def_id, args) => { + let args = args.as_coroutine_closure(); + let kind_ty = args.kind_ty(); + let sig = args.coroutine_closure_sig().skip_binder(); + + let term = match item_name { + sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => { + if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { + if !closure_kind.extends(goal_kind) { + bug!("we should not be confirming if the closure kind is not met"); + } + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + goal_kind, + env_region, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + ) + } else { + let async_fn_kind_trait_def_id = + tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); + let upvars_projection_def_id = tcx + .associated_items(async_fn_kind_trait_def_id) + .filter_by_name_unhygienic(sym::Upvars) + .next() + .unwrap() + .def_id; + // When we don't know the closure kind (and therefore also the closure's upvars, + // which are computed at the same time), we must delay the computation of the + // generator's upvars. We do this using the `AsyncFnKindHelper`, which as a trait + // goal functions similarly to the old `ClosureKind` predicate, and ensures that + // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars` + // will project to the right upvars for the generator, appending the inputs and + // coroutine upvars respecting the closure kind. + // N.B. No need to register a `AsyncFnKindHelper` goal here, it's already in `nested`. + let tupled_upvars_ty = Ty::new_projection( + tcx, + upvars_projection_def_id, + [ + ty::GenericArg::from(kind_ty), + Ty::from_closure_kind(tcx, goal_kind).into(), + env_region.into(), + sig.tupled_inputs_ty.into(), + args.tupled_upvars_ty().into(), + args.coroutine_captures_by_ref_ty().into(), + ], + ); + sig.to_coroutine( + tcx, + args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), + tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ) + } + } + sym::Output => sig.return_ty, + name => bug!("no such associated type: {name}"), + }; + let projection_ty = match item_name { + sym::CallOnceFuture | sym::Output => ty::AliasTy::new( tcx, - upvars_projection_def_id, - [ - ty::GenericArg::from(kind_ty), - Ty::from_closure_kind(tcx, goal_kind).into(), - env_region.into(), - sig.tupled_inputs_ty.into(), - args.tupled_upvars_ty().into(), - args.coroutine_captures_by_ref_ty().into(), - ], - ); - let coroutine_ty = sig.to_coroutine( + obligation.predicate.def_id, + [self_ty, sig.tupled_inputs_ty], + ), + sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new( tcx, - args.parent_args(), - Ty::from_closure_kind(tcx, goal_kind), - tcx.coroutine_for_closure(def_id), - tupled_upvars_ty, - ); - ( - ty::AliasTy::new( - tcx, - obligation.predicate.def_id, - [self_ty, sig.tupled_inputs_ty], - ), - coroutine_ty.into(), - ) - } - sym::CallMutFuture | sym::CallFuture => { - let tupled_upvars_ty = Ty::new_projection( + obligation.predicate.def_id, + [ty::GenericArg::from(self_ty), sig.tupled_inputs_ty.into(), env_region.into()], + ), + name => bug!("no such associated type: {name}"), + }; + + args.coroutine_closure_sig() + .rebind(ty::ProjectionPredicate { projection_ty, term: term.into() }) + } + ty::FnDef(..) | ty::FnPtr(..) => { + let bound_sig = self_ty.fn_sig(tcx); + let sig = bound_sig.skip_binder(); + + let term = match item_name { + sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => sig.output(), + sym::Output => { + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + let future_output_def_id = tcx + .associated_items(future_trait_def_id) + .filter_by_name_unhygienic(sym::Output) + .next() + .unwrap() + .def_id; + Ty::new_projection(tcx, future_output_def_id, [sig.output()]) + } + name => bug!("no such associated type: {name}"), + }; + let projection_ty = match item_name { + sym::CallOnceFuture | sym::Output => ty::AliasTy::new( + tcx, + obligation.predicate.def_id, + [self_ty, Ty::new_tup(tcx, sig.inputs())], + ), + sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new( tcx, - upvars_projection_def_id, + obligation.predicate.def_id, [ - ty::GenericArg::from(kind_ty), - Ty::from_closure_kind(tcx, goal_kind).into(), + ty::GenericArg::from(self_ty), + Ty::new_tup(tcx, sig.inputs()).into(), env_region.into(), - sig.tupled_inputs_ty.into(), - args.tupled_upvars_ty().into(), - args.coroutine_captures_by_ref_ty().into(), ], - ); - let coroutine_ty = sig.to_coroutine( + ), + name => bug!("no such associated type: {name}"), + }; + + bound_sig.rebind(ty::ProjectionPredicate { projection_ty, term: term.into() }) + } + ty::Closure(_, args) => { + let args = args.as_closure(); + let bound_sig = args.sig(); + let sig = bound_sig.skip_binder(); + + let term = match item_name { + sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => sig.output(), + sym::Output => { + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + let future_output_def_id = tcx + .associated_items(future_trait_def_id) + .filter_by_name_unhygienic(sym::Output) + .next() + .unwrap() + .def_id; + Ty::new_projection(tcx, future_output_def_id, [sig.output()]) + } + name => bug!("no such associated type: {name}"), + }; + let projection_ty = match item_name { + sym::CallOnceFuture | sym::Output => { + ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.inputs()[0]]) + } + sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new( tcx, - args.parent_args(), - Ty::from_closure_kind(tcx, goal_kind), - tcx.coroutine_for_closure(def_id), - tupled_upvars_ty, - ); - ( - ty::AliasTy::new( - tcx, - obligation.predicate.def_id, - [ - ty::GenericArg::from(self_ty), - sig.tupled_inputs_ty.into(), - env_region.into(), - ], - ), - coroutine_ty.into(), - ) - } - sym::Output => ( - ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.tupled_inputs_ty]), - sig.return_ty.into(), - ), - name => bug!("no such associated type: {name}"), - }; - ty::ProjectionPredicate { projection_ty, term } - }); + obligation.predicate.def_id, + [ty::GenericArg::from(self_ty), sig.inputs()[0].into(), env_region.into()], + ), + name => bug!("no such associated type: {name}"), + }; + + bound_sig.rebind(ty::ProjectionPredicate { projection_ty, term: term.into() }) + } + _ => bug!("expected callable type for AsyncFn candidate"), + }; confirm_param_env_candidate(selcx, obligation, poly_cache_entry, true) .with_addl_obligations(nested) diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 27dbe0351da..f9a292c2bd7 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -374,6 +374,43 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { } } } + ty::CoroutineClosure(def_id, args) => { + let is_const = self.tcx().is_const_fn_raw(def_id); + match self.infcx.closure_kind(self_ty) { + Some(closure_kind) => { + let no_borrows = match self + .infcx + .shallow_resolve(args.as_coroutine_closure().tupled_upvars_ty()) + .kind() + { + ty::Tuple(tys) => tys.is_empty(), + ty::Error(_) => false, + _ => bug!("tuple_fields called on non-tuple"), + }; + // A coroutine-closure implements `FnOnce` *always*, since it may + // always be called once. It additionally implements `Fn`/`FnMut` + // only if it has no upvars (therefore no borrows from the closure + // that would need to be represented with a lifetime) and if the + // closure kind permits it. + // FIXME(async_closures): Actually, it could also implement `Fn`/`FnMut` + // if it takes all of its upvars by copy, and none by ref. This would + // require us to record a bit more information during upvar analysis. + if no_borrows && closure_kind.extends(kind) { + candidates.vec.push(ClosureCandidate { is_const }); + } else if kind == ty::ClosureKind::FnOnce { + candidates.vec.push(ClosureCandidate { is_const }); + } + } + None => { + if kind == ty::ClosureKind::FnOnce { + candidates.vec.push(ClosureCandidate { is_const }); + } else { + // This stays ambiguous until kind+upvars are determined. + candidates.ambiguous = true; + } + } + } + } ty::Infer(ty::TyVar(_)) => { debug!("assemble_unboxed_closure_candidates: ambiguous self-type"); candidates.ambiguous = true; @@ -403,8 +440,18 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { } candidates.vec.push(AsyncClosureCandidate); } - ty::Infer(ty::TyVar(_)) => { - candidates.ambiguous = true; + // Closures and fn pointers implement `AsyncFn*` if their return types + // implement `Future`, which is checked later. + ty::Closure(_, args) => { + if let Some(closure_kind) = args.as_closure().kind_ty().to_opt_closure_kind() + && !closure_kind.extends(goal_kind) + { + return; + } + candidates.vec.push(AsyncClosureCandidate); + } + ty::FnDef(..) | ty::FnPtr(..) => { + candidates.vec.push(AsyncClosureCandidate); } _ => {} } diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs index 42074f4a079..6ca24933979 100644 --- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs +++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs @@ -872,17 +872,25 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { // touch bound regions, they just capture the in-scope // type/region parameters. let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); - let ty::Closure(closure_def_id, args) = *self_ty.kind() else { - bug!("closure candidate for non-closure {:?}", obligation); + let trait_ref = match *self_ty.kind() { + ty::Closure(_, args) => { + self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_) + } + ty::CoroutineClosure(_, args) => { + args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { + ty::TraitRef::new( + self.tcx(), + obligation.predicate.def_id(), + [self_ty, sig.tupled_inputs_ty], + ) + }) + } + _ => { + bug!("closure candidate for non-closure {:?}", obligation); + } }; - let trait_ref = - self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_); - let nested = self.confirm_poly_trait_refs(obligation, trait_ref)?; - - debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations"); - - Ok(nested) + self.confirm_poly_trait_refs(obligation, trait_ref) } #[instrument(skip(self), level = "debug")] @@ -890,40 +898,86 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { &mut self, obligation: &PolyTraitObligation<'tcx>, ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> { - // Okay to skip binder because the args on closure types never - // touch bound regions, they just capture the in-scope - // type/region parameters. + let tcx = self.tcx(); let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); - let ty::CoroutineClosure(closure_def_id, args) = *self_ty.kind() else { - bug!("async closure candidate for non-coroutine-closure {:?}", obligation); - }; - let trait_ref = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { - ty::TraitRef::new( - self.tcx(), - obligation.predicate.def_id(), - [self_ty, sig.tupled_inputs_ty], - ) - }); + let mut nested = vec![]; + let (trait_ref, kind_ty) = match *self_ty.kind() { + ty::CoroutineClosure(_, args) => { + let args = args.as_coroutine_closure(); + let trait_ref = args.coroutine_closure_sig().map_bound(|sig| { + ty::TraitRef::new( + self.tcx(), + obligation.predicate.def_id(), + [self_ty, sig.tupled_inputs_ty], + ) + }); + (trait_ref, args.kind_ty()) + } + ty::FnDef(..) | ty::FnPtr(..) => { + let sig = self_ty.fn_sig(tcx); + let trait_ref = sig.map_bound(|sig| { + ty::TraitRef::new( + self.tcx(), + obligation.predicate.def_id(), + [self_ty, Ty::new_tup(tcx, sig.inputs())], + ) + }); + // We must additionally check that the return type impls `Future`. + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + nested.push(obligation.with( + tcx, + sig.map_bound(|sig| { + ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()]) + }), + )); + (trait_ref, Ty::from_closure_kind(tcx, ty::ClosureKind::Fn)) + } + ty::Closure(_, args) => { + let sig = args.as_closure().sig(); + let trait_ref = sig.map_bound(|sig| { + ty::TraitRef::new( + self.tcx(), + obligation.predicate.def_id(), + [self_ty, sig.inputs()[0]], + ) + }); + // We must additionally check that the return type impls `Future`. + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + nested.push(obligation.with( + tcx, + sig.map_bound(|sig| { + ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()]) + }), + )); + (trait_ref, Ty::from_closure_kind(tcx, ty::ClosureKind::Fn)) + } + _ => bug!("expected callable type for AsyncFn candidate"), + }; - let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?; + nested.extend(self.confirm_poly_trait_refs(obligation, trait_ref)?); let goal_kind = self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap(); - nested.push(obligation.with( - self.tcx(), - ty::TraitRef::from_lang_item( + + // If we have not yet determiend the `ClosureKind` of the closure or coroutine-closure, + // then additionally register an `AsyncFnKindHelper` goal which will fail if the kind + // is constrained to an insufficient type later on. + if let Some(closure_kind) = self.infcx.shallow_resolve(kind_ty).to_opt_closure_kind() { + if !closure_kind.extends(goal_kind) { + return Err(SelectionError::Unimplemented); + } + } else { + nested.push(obligation.with( self.tcx(), - LangItem::AsyncFnKindHelper, - obligation.cause.span, - [ - args.as_coroutine_closure().kind_ty(), - Ty::from_closure_kind(self.tcx(), goal_kind), - ], - ), - )); - - debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations"); + ty::TraitRef::from_lang_item( + self.tcx(), + LangItem::AsyncFnKindHelper, + obligation.cause.span, + [kind_ty, Ty::from_closure_kind(self.tcx(), goal_kind)], + ), + )); + } Ok(nested) } |
