about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_trait_selection/src')
-rw-r--r--compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs178
-rw-r--r--compiler/rustc_trait_selection/src/traits/project.rs327
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs51
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/confirmation.rs126
4 files changed, 481 insertions, 201 deletions
diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
index 8dec04e2c4f..819b070cf8b 100644
--- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
+++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
@@ -323,34 +323,27 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc
     self_ty: Ty<'tcx>,
     goal_kind: ty::ClosureKind,
     env_region: ty::Region<'tcx>,
-) -> Result<
-    (ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Option<ty::Predicate<'tcx>>),
-    NoSolution,
-> {
+) -> Result<(ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Vec<ty::Predicate<'tcx>>), NoSolution>
+{
     match *self_ty.kind() {
         ty::CoroutineClosure(def_id, args) => {
             let args = args.as_coroutine_closure();
             let kind_ty = args.kind_ty();
-
-            if let Some(closure_kind) = kind_ty.to_opt_closure_kind() {
+            let sig = args.coroutine_closure_sig().skip_binder();
+            let mut nested = vec![];
+            let coroutine_ty = if let Some(closure_kind) = kind_ty.to_opt_closure_kind() {
                 if !closure_kind.extends(goal_kind) {
                     return Err(NoSolution);
                 }
-                Ok((
-                    args.coroutine_closure_sig().map_bound(|sig| {
-                        let coroutine_ty = sig.to_coroutine_given_kind_and_upvars(
-                            tcx,
-                            args.parent_args(),
-                            tcx.coroutine_for_closure(def_id),
-                            goal_kind,
-                            env_region,
-                            args.tupled_upvars_ty(),
-                            args.coroutine_captures_by_ref_ty(),
-                        );
-                        (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty)
-                    }),
-                    None,
-                ))
+                sig.to_coroutine_given_kind_and_upvars(
+                    tcx,
+                    args.parent_args(),
+                    tcx.coroutine_for_closure(def_id),
+                    goal_kind,
+                    env_region,
+                    args.tupled_upvars_ty(),
+                    args.coroutine_captures_by_ref_ty(),
+                )
             } else {
                 let async_fn_kind_trait_def_id =
                     tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
@@ -367,42 +360,117 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc
                 // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars`
                 // will project to the right upvars for the generator, appending the inputs and
                 // coroutine upvars respecting the closure kind.
-                Ok((
-                    args.coroutine_closure_sig().map_bound(|sig| {
-                        let tupled_upvars_ty = Ty::new_projection(
-                            tcx,
-                            upvars_projection_def_id,
-                            [
-                                ty::GenericArg::from(kind_ty),
-                                Ty::from_closure_kind(tcx, goal_kind).into(),
-                                env_region.into(),
-                                sig.tupled_inputs_ty.into(),
-                                args.tupled_upvars_ty().into(),
-                                args.coroutine_captures_by_ref_ty().into(),
-                            ],
-                        );
-                        let coroutine_ty = sig.to_coroutine(
-                            tcx,
-                            args.parent_args(),
-                            Ty::from_closure_kind(tcx, goal_kind),
-                            tcx.coroutine_for_closure(def_id),
-                            tupled_upvars_ty,
-                        );
-                        (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty)
-                    }),
-                    Some(
-                        ty::TraitRef::new(
-                            tcx,
-                            async_fn_kind_trait_def_id,
-                            [kind_ty, Ty::from_closure_kind(tcx, goal_kind)],
-                        )
-                        .to_predicate(tcx),
-                    ),
-                ))
-            }
+                nested.push(
+                    ty::TraitRef::new(
+                        tcx,
+                        async_fn_kind_trait_def_id,
+                        [kind_ty, Ty::from_closure_kind(tcx, goal_kind)],
+                    )
+                    .to_predicate(tcx),
+                );
+                let tupled_upvars_ty = Ty::new_projection(
+                    tcx,
+                    upvars_projection_def_id,
+                    [
+                        ty::GenericArg::from(kind_ty),
+                        Ty::from_closure_kind(tcx, goal_kind).into(),
+                        env_region.into(),
+                        sig.tupled_inputs_ty.into(),
+                        args.tupled_upvars_ty().into(),
+                        args.coroutine_captures_by_ref_ty().into(),
+                    ],
+                );
+                sig.to_coroutine(
+                    tcx,
+                    args.parent_args(),
+                    Ty::from_closure_kind(tcx, goal_kind),
+                    tcx.coroutine_for_closure(def_id),
+                    tupled_upvars_ty,
+                )
+            };
+
+            Ok((
+                args.coroutine_closure_sig().rebind((
+                    sig.tupled_inputs_ty,
+                    sig.return_ty,
+                    coroutine_ty,
+                )),
+                nested,
+            ))
         }
 
-        ty::FnDef(..) | ty::FnPtr(..) | ty::Closure(..) => Err(NoSolution),
+        ty::FnDef(..) | ty::FnPtr(..) => {
+            let bound_sig = self_ty.fn_sig(tcx);
+            let sig = bound_sig.skip_binder();
+            let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None);
+            // `FnDef` and `FnPtr` only implement `AsyncFn*` when their
+            // return type implements `Future`.
+            let nested = vec![
+                bound_sig
+                    .rebind(ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()]))
+                    .to_predicate(tcx),
+            ];
+            let future_output_def_id = tcx
+                .associated_items(future_trait_def_id)
+                .filter_by_name_unhygienic(sym::Output)
+                .next()
+                .unwrap()
+                .def_id;
+            let future_output_ty = Ty::new_projection(tcx, future_output_def_id, [sig.output()]);
+            Ok((
+                bound_sig.rebind((Ty::new_tup(tcx, sig.inputs()), sig.output(), future_output_ty)),
+                nested,
+            ))
+        }
+        ty::Closure(_, args) => {
+            let args = args.as_closure();
+            let bound_sig = args.sig();
+            let sig = bound_sig.skip_binder();
+            let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None);
+            // `Closure`s only implement `AsyncFn*` when their return type
+            // implements `Future`.
+            let mut nested = vec![
+                bound_sig
+                    .rebind(ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()]))
+                    .to_predicate(tcx),
+            ];
+
+            // Additionally, we need to check that the closure kind
+            // is still compatible.
+            let kind_ty = args.kind_ty();
+            if let Some(closure_kind) = kind_ty.to_opt_closure_kind() {
+                if !closure_kind.extends(goal_kind) {
+                    return Err(NoSolution);
+                }
+            } else {
+                let async_fn_kind_trait_def_id =
+                    tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
+                // When we don't know the closure kind (and therefore also the closure's upvars,
+                // which are computed at the same time), we must delay the computation of the
+                // generator's upvars. We do this using the `AsyncFnKindHelper`, which as a trait
+                // goal functions similarly to the old `ClosureKind` predicate, and ensures that
+                // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars`
+                // will project to the right upvars for the generator, appending the inputs and
+                // coroutine upvars respecting the closure kind.
+                nested.push(
+                    ty::TraitRef::new(
+                        tcx,
+                        async_fn_kind_trait_def_id,
+                        [kind_ty, Ty::from_closure_kind(tcx, goal_kind)],
+                    )
+                    .to_predicate(tcx),
+                );
+            }
+
+            let future_output_def_id = tcx
+                .associated_items(future_trait_def_id)
+                .filter_by_name_unhygienic(sym::Output)
+                .next()
+                .unwrap()
+                .def_id;
+            let future_output_ty = Ty::new_projection(tcx, future_output_def_id, [sig.output()]);
+            Ok((bound_sig.rebind((sig.inputs()[0], sig.output(), future_output_ty)), nested))
+        }
 
         ty::Bool
         | ty::Char
diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs
index 95f833372fb..054402acb5c 100644
--- a/compiler/rustc_trait_selection/src/traits/project.rs
+++ b/compiler/rustc_trait_selection/src/traits/project.rs
@@ -2087,7 +2087,9 @@ fn confirm_select_candidate<'cx, 'tcx>(
             } else if lang_items.async_iterator_trait() == Some(trait_def_id) {
                 confirm_async_iterator_candidate(selcx, obligation, data)
             } else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() {
-                if obligation.predicate.self_ty().is_closure() {
+                if obligation.predicate.self_ty().is_closure()
+                    || obligation.predicate.self_ty().is_coroutine_closure()
+                {
                     confirm_closure_candidate(selcx, obligation, data)
                 } else {
                     confirm_fn_pointer_candidate(selcx, obligation, data)
@@ -2410,11 +2412,75 @@ fn confirm_closure_candidate<'cx, 'tcx>(
     obligation: &ProjectionTyObligation<'tcx>,
     nested: Vec<PredicateObligation<'tcx>>,
 ) -> Progress<'tcx> {
+    let tcx = selcx.tcx();
     let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());
-    let ty::Closure(_, args) = self_ty.kind() else {
-        unreachable!("expected closure self type for closure candidate, found {self_ty}")
+    let closure_sig = match *self_ty.kind() {
+        ty::Closure(_, args) => args.as_closure().sig(),
+
+        // Construct a "normal" `FnOnce` signature for coroutine-closure. This is
+        // basically duplicated with the `AsyncFnOnce::CallOnce` confirmation, but
+        // I didn't see a good way to unify those.
+        ty::CoroutineClosure(def_id, args) => {
+            let args = args.as_coroutine_closure();
+            let kind_ty = args.kind_ty();
+            args.coroutine_closure_sig().map_bound(|sig| {
+                // If we know the kind and upvars, use that directly.
+                // Otherwise, defer to `AsyncFnKindHelper::Upvars` to delay
+                // the projection, like the `AsyncFn*` traits do.
+                let output_ty = if let Some(_) = kind_ty.to_opt_closure_kind() {
+                    sig.to_coroutine_given_kind_and_upvars(
+                        tcx,
+                        args.parent_args(),
+                        tcx.coroutine_for_closure(def_id),
+                        ty::ClosureKind::FnOnce,
+                        tcx.lifetimes.re_static,
+                        args.tupled_upvars_ty(),
+                        args.coroutine_captures_by_ref_ty(),
+                    )
+                } else {
+                    let async_fn_kind_trait_def_id =
+                        tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
+                    let upvars_projection_def_id = tcx
+                        .associated_items(async_fn_kind_trait_def_id)
+                        .filter_by_name_unhygienic(sym::Upvars)
+                        .next()
+                        .unwrap()
+                        .def_id;
+                    let tupled_upvars_ty = Ty::new_projection(
+                        tcx,
+                        upvars_projection_def_id,
+                        [
+                            ty::GenericArg::from(kind_ty),
+                            Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce).into(),
+                            tcx.lifetimes.re_static.into(),
+                            sig.tupled_inputs_ty.into(),
+                            args.tupled_upvars_ty().into(),
+                            args.coroutine_captures_by_ref_ty().into(),
+                        ],
+                    );
+                    sig.to_coroutine(
+                        tcx,
+                        args.parent_args(),
+                        Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
+                        tcx.coroutine_for_closure(def_id),
+                        tupled_upvars_ty,
+                    )
+                };
+                tcx.mk_fn_sig(
+                    [sig.tupled_inputs_ty],
+                    output_ty,
+                    sig.c_variadic,
+                    sig.unsafety,
+                    sig.abi,
+                )
+            })
+        }
+
+        _ => {
+            unreachable!("expected closure self type for closure candidate, found {self_ty}");
+        }
     };
-    let closure_sig = args.as_closure().sig();
+
     let Normalized { value: closure_sig, obligations } = normalize_with_depth(
         selcx,
         obligation.param_env,
@@ -2470,126 +2536,171 @@ fn confirm_callable_candidate<'cx, 'tcx>(
 fn confirm_async_closure_candidate<'cx, 'tcx>(
     selcx: &mut SelectionContext<'cx, 'tcx>,
     obligation: &ProjectionTyObligation<'tcx>,
-    mut nested: Vec<PredicateObligation<'tcx>>,
+    nested: Vec<PredicateObligation<'tcx>>,
 ) -> Progress<'tcx> {
+    let tcx = selcx.tcx();
     let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());
-    let ty::CoroutineClosure(def_id, args) = *self_ty.kind() else {
-        unreachable!(
-            "expected coroutine-closure self type for coroutine-closure candidate, found {self_ty}"
-        )
-    };
-    let args = args.as_coroutine_closure();
-    let kind_ty = args.kind_ty();
 
-    let tcx = selcx.tcx();
     let goal_kind =
         tcx.async_fn_trait_kind_from_def_id(obligation.predicate.trait_def_id(tcx)).unwrap();
-
-    let async_fn_kind_helper_trait_def_id =
-        tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
-    nested.push(obligation.with(
-        tcx,
-        ty::TraitRef::new(
-            tcx,
-            async_fn_kind_helper_trait_def_id,
-            [kind_ty, Ty::from_closure_kind(tcx, goal_kind)],
-        ),
-    ));
-
     let env_region = match goal_kind {
         ty::ClosureKind::Fn | ty::ClosureKind::FnMut => obligation.predicate.args.region_at(2),
         ty::ClosureKind::FnOnce => tcx.lifetimes.re_static,
     };
-
-    let upvars_projection_def_id = tcx
-        .associated_items(async_fn_kind_helper_trait_def_id)
-        .filter_by_name_unhygienic(sym::Upvars)
-        .next()
-        .unwrap()
-        .def_id;
-
-    // FIXME(async_closures): Confirmation is kind of a mess here. Ideally,
-    // we'd short-circuit when we know that the goal_kind >= closure_kind, and not
-    // register a nested predicate or create a new projection ty here. But I'm too
-    // lazy to make this more efficient atm, and we can always tweak it later,
-    // since all this does is make the solver do more work.
-    //
-    // The code duplication due to the different length args is kind of weird, too.
-    //
-    // See the logic in `structural_traits` in the new solver to understand a bit
-    // more clearly how this *should* look.
-    let poly_cache_entry = args.coroutine_closure_sig().map_bound(|sig| {
-        let (projection_ty, term) = match tcx.item_name(obligation.predicate.def_id) {
-            sym::CallOnceFuture => {
-                let tupled_upvars_ty = Ty::new_projection(
+    let item_name = tcx.item_name(obligation.predicate.def_id);
+
+    let poly_cache_entry = match *self_ty.kind() {
+        ty::CoroutineClosure(def_id, args) => {
+            let args = args.as_coroutine_closure();
+            let kind_ty = args.kind_ty();
+            let sig = args.coroutine_closure_sig().skip_binder();
+
+            let term = match item_name {
+                sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => {
+                    if let Some(closure_kind) = kind_ty.to_opt_closure_kind() {
+                        if !closure_kind.extends(goal_kind) {
+                            bug!("we should not be confirming if the closure kind is not met");
+                        }
+                        sig.to_coroutine_given_kind_and_upvars(
+                            tcx,
+                            args.parent_args(),
+                            tcx.coroutine_for_closure(def_id),
+                            goal_kind,
+                            env_region,
+                            args.tupled_upvars_ty(),
+                            args.coroutine_captures_by_ref_ty(),
+                        )
+                    } else {
+                        let async_fn_kind_trait_def_id =
+                            tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
+                        let upvars_projection_def_id = tcx
+                            .associated_items(async_fn_kind_trait_def_id)
+                            .filter_by_name_unhygienic(sym::Upvars)
+                            .next()
+                            .unwrap()
+                            .def_id;
+                        // When we don't know the closure kind (and therefore also the closure's upvars,
+                        // which are computed at the same time), we must delay the computation of the
+                        // generator's upvars. We do this using the `AsyncFnKindHelper`, which as a trait
+                        // goal functions similarly to the old `ClosureKind` predicate, and ensures that
+                        // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars`
+                        // will project to the right upvars for the generator, appending the inputs and
+                        // coroutine upvars respecting the closure kind.
+                        // N.B. No need to register a `AsyncFnKindHelper` goal here, it's already in `nested`.
+                        let tupled_upvars_ty = Ty::new_projection(
+                            tcx,
+                            upvars_projection_def_id,
+                            [
+                                ty::GenericArg::from(kind_ty),
+                                Ty::from_closure_kind(tcx, goal_kind).into(),
+                                env_region.into(),
+                                sig.tupled_inputs_ty.into(),
+                                args.tupled_upvars_ty().into(),
+                                args.coroutine_captures_by_ref_ty().into(),
+                            ],
+                        );
+                        sig.to_coroutine(
+                            tcx,
+                            args.parent_args(),
+                            Ty::from_closure_kind(tcx, goal_kind),
+                            tcx.coroutine_for_closure(def_id),
+                            tupled_upvars_ty,
+                        )
+                    }
+                }
+                sym::Output => sig.return_ty,
+                name => bug!("no such associated type: {name}"),
+            };
+            let projection_ty = match item_name {
+                sym::CallOnceFuture | sym::Output => ty::AliasTy::new(
                     tcx,
-                    upvars_projection_def_id,
-                    [
-                        ty::GenericArg::from(kind_ty),
-                        Ty::from_closure_kind(tcx, goal_kind).into(),
-                        env_region.into(),
-                        sig.tupled_inputs_ty.into(),
-                        args.tupled_upvars_ty().into(),
-                        args.coroutine_captures_by_ref_ty().into(),
-                    ],
-                );
-                let coroutine_ty = sig.to_coroutine(
+                    obligation.predicate.def_id,
+                    [self_ty, sig.tupled_inputs_ty],
+                ),
+                sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new(
                     tcx,
-                    args.parent_args(),
-                    Ty::from_closure_kind(tcx, goal_kind),
-                    tcx.coroutine_for_closure(def_id),
-                    tupled_upvars_ty,
-                );
-                (
-                    ty::AliasTy::new(
-                        tcx,
-                        obligation.predicate.def_id,
-                        [self_ty, sig.tupled_inputs_ty],
-                    ),
-                    coroutine_ty.into(),
-                )
-            }
-            sym::CallMutFuture | sym::CallFuture => {
-                let tupled_upvars_ty = Ty::new_projection(
+                    obligation.predicate.def_id,
+                    [ty::GenericArg::from(self_ty), sig.tupled_inputs_ty.into(), env_region.into()],
+                ),
+                name => bug!("no such associated type: {name}"),
+            };
+
+            args.coroutine_closure_sig()
+                .rebind(ty::ProjectionPredicate { projection_ty, term: term.into() })
+        }
+        ty::FnDef(..) | ty::FnPtr(..) => {
+            let bound_sig = self_ty.fn_sig(tcx);
+            let sig = bound_sig.skip_binder();
+
+            let term = match item_name {
+                sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => sig.output(),
+                sym::Output => {
+                    let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None);
+                    let future_output_def_id = tcx
+                        .associated_items(future_trait_def_id)
+                        .filter_by_name_unhygienic(sym::Output)
+                        .next()
+                        .unwrap()
+                        .def_id;
+                    Ty::new_projection(tcx, future_output_def_id, [sig.output()])
+                }
+                name => bug!("no such associated type: {name}"),
+            };
+            let projection_ty = match item_name {
+                sym::CallOnceFuture | sym::Output => ty::AliasTy::new(
+                    tcx,
+                    obligation.predicate.def_id,
+                    [self_ty, Ty::new_tup(tcx, sig.inputs())],
+                ),
+                sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new(
                     tcx,
-                    upvars_projection_def_id,
+                    obligation.predicate.def_id,
                     [
-                        ty::GenericArg::from(kind_ty),
-                        Ty::from_closure_kind(tcx, goal_kind).into(),
+                        ty::GenericArg::from(self_ty),
+                        Ty::new_tup(tcx, sig.inputs()).into(),
                         env_region.into(),
-                        sig.tupled_inputs_ty.into(),
-                        args.tupled_upvars_ty().into(),
-                        args.coroutine_captures_by_ref_ty().into(),
                     ],
-                );
-                let coroutine_ty = sig.to_coroutine(
+                ),
+                name => bug!("no such associated type: {name}"),
+            };
+
+            bound_sig.rebind(ty::ProjectionPredicate { projection_ty, term: term.into() })
+        }
+        ty::Closure(_, args) => {
+            let args = args.as_closure();
+            let bound_sig = args.sig();
+            let sig = bound_sig.skip_binder();
+
+            let term = match item_name {
+                sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => sig.output(),
+                sym::Output => {
+                    let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None);
+                    let future_output_def_id = tcx
+                        .associated_items(future_trait_def_id)
+                        .filter_by_name_unhygienic(sym::Output)
+                        .next()
+                        .unwrap()
+                        .def_id;
+                    Ty::new_projection(tcx, future_output_def_id, [sig.output()])
+                }
+                name => bug!("no such associated type: {name}"),
+            };
+            let projection_ty = match item_name {
+                sym::CallOnceFuture | sym::Output => {
+                    ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.inputs()[0]])
+                }
+                sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new(
                     tcx,
-                    args.parent_args(),
-                    Ty::from_closure_kind(tcx, goal_kind),
-                    tcx.coroutine_for_closure(def_id),
-                    tupled_upvars_ty,
-                );
-                (
-                    ty::AliasTy::new(
-                        tcx,
-                        obligation.predicate.def_id,
-                        [
-                            ty::GenericArg::from(self_ty),
-                            sig.tupled_inputs_ty.into(),
-                            env_region.into(),
-                        ],
-                    ),
-                    coroutine_ty.into(),
-                )
-            }
-            sym::Output => (
-                ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.tupled_inputs_ty]),
-                sig.return_ty.into(),
-            ),
-            name => bug!("no such associated type: {name}"),
-        };
-        ty::ProjectionPredicate { projection_ty, term }
-    });
+                    obligation.predicate.def_id,
+                    [ty::GenericArg::from(self_ty), sig.inputs()[0].into(), env_region.into()],
+                ),
+                name => bug!("no such associated type: {name}"),
+            };
+
+            bound_sig.rebind(ty::ProjectionPredicate { projection_ty, term: term.into() })
+        }
+        _ => bug!("expected callable type for AsyncFn candidate"),
+    };
 
     confirm_param_env_candidate(selcx, obligation, poly_cache_entry, true)
         .with_addl_obligations(nested)
diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
index 27dbe0351da..f9a292c2bd7 100644
--- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
@@ -374,6 +374,43 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
                     }
                 }
             }
+            ty::CoroutineClosure(def_id, args) => {
+                let is_const = self.tcx().is_const_fn_raw(def_id);
+                match self.infcx.closure_kind(self_ty) {
+                    Some(closure_kind) => {
+                        let no_borrows = match self
+                            .infcx
+                            .shallow_resolve(args.as_coroutine_closure().tupled_upvars_ty())
+                            .kind()
+                        {
+                            ty::Tuple(tys) => tys.is_empty(),
+                            ty::Error(_) => false,
+                            _ => bug!("tuple_fields called on non-tuple"),
+                        };
+                        // A coroutine-closure implements `FnOnce` *always*, since it may
+                        // always be called once. It additionally implements `Fn`/`FnMut`
+                        // only if it has no upvars (therefore no borrows from the closure
+                        // that would need to be represented with a lifetime) and if the
+                        // closure kind permits it.
+                        // FIXME(async_closures): Actually, it could also implement `Fn`/`FnMut`
+                        // if it takes all of its upvars by copy, and none by ref. This would
+                        // require us to record a bit more information during upvar analysis.
+                        if no_borrows && closure_kind.extends(kind) {
+                            candidates.vec.push(ClosureCandidate { is_const });
+                        } else if kind == ty::ClosureKind::FnOnce {
+                            candidates.vec.push(ClosureCandidate { is_const });
+                        }
+                    }
+                    None => {
+                        if kind == ty::ClosureKind::FnOnce {
+                            candidates.vec.push(ClosureCandidate { is_const });
+                        } else {
+                            // This stays ambiguous until kind+upvars are determined.
+                            candidates.ambiguous = true;
+                        }
+                    }
+                }
+            }
             ty::Infer(ty::TyVar(_)) => {
                 debug!("assemble_unboxed_closure_candidates: ambiguous self-type");
                 candidates.ambiguous = true;
@@ -403,8 +440,18 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
                 }
                 candidates.vec.push(AsyncClosureCandidate);
             }
-            ty::Infer(ty::TyVar(_)) => {
-                candidates.ambiguous = true;
+            // Closures and fn pointers implement `AsyncFn*` if their return types
+            // implement `Future`, which is checked later.
+            ty::Closure(_, args) => {
+                if let Some(closure_kind) = args.as_closure().kind_ty().to_opt_closure_kind()
+                    && !closure_kind.extends(goal_kind)
+                {
+                    return;
+                }
+                candidates.vec.push(AsyncClosureCandidate);
+            }
+            ty::FnDef(..) | ty::FnPtr(..) => {
+                candidates.vec.push(AsyncClosureCandidate);
             }
             _ => {}
         }
diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
index 42074f4a079..6ca24933979 100644
--- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
@@ -872,17 +872,25 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         // touch bound regions, they just capture the in-scope
         // type/region parameters.
         let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
-        let ty::Closure(closure_def_id, args) = *self_ty.kind() else {
-            bug!("closure candidate for non-closure {:?}", obligation);
+        let trait_ref = match *self_ty.kind() {
+            ty::Closure(_, args) => {
+                self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
+            }
+            ty::CoroutineClosure(_, args) => {
+                args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
+                    ty::TraitRef::new(
+                        self.tcx(),
+                        obligation.predicate.def_id(),
+                        [self_ty, sig.tupled_inputs_ty],
+                    )
+                })
+            }
+            _ => {
+                bug!("closure candidate for non-closure {:?}", obligation);
+            }
         };
 
-        let trait_ref =
-            self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_);
-        let nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
-
-        debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations");
-
-        Ok(nested)
+        self.confirm_poly_trait_refs(obligation, trait_ref)
     }
 
     #[instrument(skip(self), level = "debug")]
@@ -890,40 +898,86 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on closure types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
+        let tcx = self.tcx();
         let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
-        let ty::CoroutineClosure(closure_def_id, args) = *self_ty.kind() else {
-            bug!("async closure candidate for non-coroutine-closure {:?}", obligation);
-        };
 
-        let trait_ref = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
-            ty::TraitRef::new(
-                self.tcx(),
-                obligation.predicate.def_id(),
-                [self_ty, sig.tupled_inputs_ty],
-            )
-        });
+        let mut nested = vec![];
+        let (trait_ref, kind_ty) = match *self_ty.kind() {
+            ty::CoroutineClosure(_, args) => {
+                let args = args.as_coroutine_closure();
+                let trait_ref = args.coroutine_closure_sig().map_bound(|sig| {
+                    ty::TraitRef::new(
+                        self.tcx(),
+                        obligation.predicate.def_id(),
+                        [self_ty, sig.tupled_inputs_ty],
+                    )
+                });
+                (trait_ref, args.kind_ty())
+            }
+            ty::FnDef(..) | ty::FnPtr(..) => {
+                let sig = self_ty.fn_sig(tcx);
+                let trait_ref = sig.map_bound(|sig| {
+                    ty::TraitRef::new(
+                        self.tcx(),
+                        obligation.predicate.def_id(),
+                        [self_ty, Ty::new_tup(tcx, sig.inputs())],
+                    )
+                });
+                // We must additionally check that the return type impls `Future`.
+                let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None);
+                nested.push(obligation.with(
+                    tcx,
+                    sig.map_bound(|sig| {
+                        ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()])
+                    }),
+                ));
+                (trait_ref, Ty::from_closure_kind(tcx, ty::ClosureKind::Fn))
+            }
+            ty::Closure(_, args) => {
+                let sig = args.as_closure().sig();
+                let trait_ref = sig.map_bound(|sig| {
+                    ty::TraitRef::new(
+                        self.tcx(),
+                        obligation.predicate.def_id(),
+                        [self_ty, sig.inputs()[0]],
+                    )
+                });
+                // We must additionally check that the return type impls `Future`.
+                let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None);
+                nested.push(obligation.with(
+                    tcx,
+                    sig.map_bound(|sig| {
+                        ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()])
+                    }),
+                ));
+                (trait_ref, Ty::from_closure_kind(tcx, ty::ClosureKind::Fn))
+            }
+            _ => bug!("expected callable type for AsyncFn candidate"),
+        };
 
-        let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
+        nested.extend(self.confirm_poly_trait_refs(obligation, trait_ref)?);
 
         let goal_kind =
             self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap();
-        nested.push(obligation.with(
-            self.tcx(),
-            ty::TraitRef::from_lang_item(
+
+        // If we have not yet determiend the `ClosureKind` of the closure or coroutine-closure,
+        // then additionally register an `AsyncFnKindHelper` goal which will fail if the kind
+        // is constrained to an insufficient type later on.
+        if let Some(closure_kind) = self.infcx.shallow_resolve(kind_ty).to_opt_closure_kind() {
+            if !closure_kind.extends(goal_kind) {
+                return Err(SelectionError::Unimplemented);
+            }
+        } else {
+            nested.push(obligation.with(
                 self.tcx(),
-                LangItem::AsyncFnKindHelper,
-                obligation.cause.span,
-                [
-                    args.as_coroutine_closure().kind_ty(),
-                    Ty::from_closure_kind(self.tcx(), goal_kind),
-                ],
-            ),
-        ));
-
-        debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations");
+                ty::TraitRef::from_lang_item(
+                    self.tcx(),
+                    LangItem::AsyncFnKindHelper,
+                    obligation.cause.span,
+                    [kind_ty, Ty::from_closure_kind(self.tcx(), goal_kind)],
+                ),
+            ));
+        }
 
         Ok(nested)
     }