about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-01-24 22:27:25 +0000
committerMichael Goulet <michael@errs.io>2024-02-06 02:22:58 +0000
commita82bae2172499864c12a1d0b412931ad884911f7 (patch)
treec7299fdfd83be3818fcffdb86639146c9d29bb69 /compiler/rustc_trait_selection
parentc567eddec2c628d4f13707866731e1b2013ad236 (diff)
downloadrust-a82bae2172499864c12a1d0b412931ad884911f7.tar.gz
rust-a82bae2172499864c12a1d0b412931ad884911f7.zip
Teach typeck/borrowck/solvers how to deal with async closures
Diffstat (limited to 'compiler/rustc_trait_selection')
-rw-r--r--compiler/rustc_trait_selection/src/solve/assembly/mod.rs18
-rw-r--r--compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs111
-rw-r--r--compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs113
-rw-r--r--compiler/rustc_trait_selection/src/solve/trait_goals.rs60
-rw-r--r--compiler/rustc_trait_selection/src/traits/project.rs182
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs46
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/confirmation.rs50
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/mod.rs12
8 files changed, 590 insertions, 2 deletions
diff --git a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs
index e1c68039e79..8451fbcc434 100644
--- a/compiler/rustc_trait_selection/src/solve/assembly/mod.rs
+++ b/compiler/rustc_trait_selection/src/solve/assembly/mod.rs
@@ -182,6 +182,20 @@ pub(super) trait GoalKind<'tcx>:
         kind: ty::ClosureKind,
     ) -> QueryResult<'tcx>;
 
+    /// An async closure is known to implement the `AsyncFn<A>` family of traits
+    /// where `A` is given by the signature of the type.
+    fn consider_builtin_async_fn_trait_candidates(
+        ecx: &mut EvalCtxt<'_, 'tcx>,
+        goal: Goal<'tcx, Self>,
+        kind: ty::ClosureKind,
+    ) -> QueryResult<'tcx>;
+
+    /// TODO:
+    fn consider_builtin_async_fn_kind_helper_candidate(
+        ecx: &mut EvalCtxt<'_, 'tcx>,
+        goal: Goal<'tcx, Self>,
+    ) -> QueryResult<'tcx>;
+
     /// `Tuple` is implemented if the `Self` type is a tuple.
     fn consider_builtin_tuple_candidate(
         ecx: &mut EvalCtxt<'_, 'tcx>,
@@ -461,6 +475,10 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
             G::consider_builtin_fn_ptr_trait_candidate(self, goal)
         } else if let Some(kind) = self.tcx().fn_trait_kind_from_def_id(trait_def_id) {
             G::consider_builtin_fn_trait_candidates(self, goal, kind)
+        } else if let Some(kind) = self.tcx().async_fn_trait_kind_from_def_id(trait_def_id) {
+            G::consider_builtin_async_fn_trait_candidates(self, goal, kind)
+        } else if lang_items.async_fn_kind_helper() == Some(trait_def_id) {
+            G::consider_builtin_async_fn_kind_helper_candidate(self, goal)
         } else if lang_items.tuple_trait() == Some(trait_def_id) {
             G::consider_builtin_tuple_candidate(self, goal)
         } else if lang_items.pointee_trait() == Some(trait_def_id) {
diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
index 3c571e1d96f..c35134c78eb 100644
--- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
+++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
@@ -1,11 +1,12 @@
 //! Code which is used by built-in goals that match "structurally", such a auto
 //! traits, `Copy`/`Clone`.
 use rustc_data_structures::fx::FxHashMap;
+use rustc_hir::LangItem;
 use rustc_hir::{def_id::DefId, Movability, Mutability};
 use rustc_infer::traits::query::NoSolution;
 use rustc_middle::traits::solve::Goal;
 use rustc_middle::ty::{
-    self, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
+    self, ToPredicate, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
 };
 
 use crate::solve::EvalCtxt;
@@ -306,6 +307,114 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_callable<'tcx>(
     }
 }
 
+// Returns a binder of the tupled inputs types, output type, and coroutine type
+// from a builtin async closure type.
+pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    self_ty: Ty<'tcx>,
+    goal_kind: ty::ClosureKind,
+    env_region: ty::Region<'tcx>,
+) -> Result<(ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Vec<ty::Predicate<'tcx>>), NoSolution>
+{
+    match *self_ty.kind() {
+        ty::CoroutineClosure(def_id, args) => {
+            let args = args.as_coroutine_closure();
+            let kind_ty = args.kind_ty();
+
+            if let Some(closure_kind) = kind_ty.to_opt_closure_kind() {
+                if !closure_kind.extends(goal_kind) {
+                    return Err(NoSolution);
+                }
+                Ok((
+                    args.coroutine_closure_sig().map_bound(|sig| {
+                        let coroutine_ty = sig.to_coroutine_given_kind_and_upvars(
+                            tcx,
+                            args.parent_args(),
+                            tcx.coroutine_for_closure(def_id),
+                            goal_kind,
+                            env_region,
+                            args.tupled_upvars_ty(),
+                            args.coroutine_captures_by_ref_ty(),
+                        );
+                        (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty)
+                    }),
+                    vec![],
+                ))
+            } else {
+                let helper_trait_def_id = tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
+                // FIXME(async_closures): Make this into a lang item.
+                let upvars_projection_def_id = tcx
+                    .associated_items(helper_trait_def_id)
+                    .in_definition_order()
+                    .next()
+                    .unwrap()
+                    .def_id;
+                Ok((
+                    args.coroutine_closure_sig().map_bound(|sig| {
+                        let tupled_upvars_ty = Ty::new_projection(
+                            tcx,
+                            upvars_projection_def_id,
+                            [
+                                ty::GenericArg::from(kind_ty),
+                                Ty::from_closure_kind(tcx, goal_kind).into(),
+                                env_region.into(),
+                                sig.tupled_inputs_ty.into(),
+                                args.tupled_upvars_ty().into(),
+                                args.coroutine_captures_by_ref_ty().into(),
+                            ],
+                        );
+                        let coroutine_ty = sig.to_coroutine(
+                            tcx,
+                            args.parent_args(),
+                            tcx.coroutine_for_closure(def_id),
+                            tupled_upvars_ty,
+                        );
+                        (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty)
+                    }),
+                    vec![
+                        ty::TraitRef::new(
+                            tcx,
+                            helper_trait_def_id,
+                            [kind_ty, Ty::from_closure_kind(tcx, goal_kind)],
+                        )
+                        .to_predicate(tcx),
+                    ],
+                ))
+            }
+        }
+
+        ty::FnDef(..) | ty::FnPtr(..) | ty::Closure(..) => Err(NoSolution),
+
+        ty::Bool
+        | ty::Char
+        | ty::Int(_)
+        | ty::Uint(_)
+        | ty::Float(_)
+        | ty::Adt(_, _)
+        | ty::Foreign(_)
+        | ty::Str
+        | ty::Array(_, _)
+        | ty::Slice(_)
+        | ty::RawPtr(_)
+        | ty::Ref(_, _, _)
+        | ty::Dynamic(_, _, _)
+        | ty::Coroutine(_, _)
+        | ty::CoroutineWitness(..)
+        | ty::Never
+        | ty::Tuple(_)
+        | ty::Alias(_, _)
+        | ty::Param(_)
+        | ty::Placeholder(..)
+        | ty::Infer(ty::IntVar(_) | ty::FloatVar(_))
+        | ty::Error(_) => Err(NoSolution),
+
+        ty::Bound(..)
+        | ty::Infer(ty::TyVar(_) | ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_)) => {
+            bug!("unexpected type `{self_ty}`")
+        }
+    }
+}
+
 /// Assemble a list of predicates that would be present on a theoretical
 /// user impl for an object type. These predicates must be checked any time
 /// we assemble a built-in object candidate for an object type, since they
diff --git a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs
index e1c6f67f05e..47ba549022d 100644
--- a/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs
+++ b/compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs
@@ -366,6 +366,119 @@ impl<'tcx> assembly::GoalKind<'tcx> for NormalizesTo<'tcx> {
         Self::consider_implied_clause(ecx, goal, pred, [goal.with(tcx, output_is_sized_pred)])
     }
 
+    fn consider_builtin_async_fn_trait_candidates(
+        ecx: &mut EvalCtxt<'_, 'tcx>,
+        goal: Goal<'tcx, Self>,
+        goal_kind: ty::ClosureKind,
+    ) -> QueryResult<'tcx> {
+        let tcx = ecx.tcx();
+
+        let env_region = match goal_kind {
+            ty::ClosureKind::Fn | ty::ClosureKind::FnMut => goal.predicate.alias.args.region_at(2),
+            // Doesn't matter what this region is
+            ty::ClosureKind::FnOnce => tcx.lifetimes.re_static,
+        };
+        let (tupled_inputs_and_output_and_coroutine, nested_preds) =
+            structural_traits::extract_tupled_inputs_and_output_from_async_callable(
+                tcx,
+                goal.predicate.self_ty(),
+                goal_kind,
+                env_region,
+            )?;
+        let output_is_sized_pred =
+            tupled_inputs_and_output_and_coroutine.map_bound(|(_, output, _)| {
+                ty::TraitRef::from_lang_item(tcx, LangItem::Sized, DUMMY_SP, [output])
+            });
+
+        let pred = tupled_inputs_and_output_and_coroutine
+            .map_bound(|(inputs, output, coroutine)| {
+                let (projection_ty, term) = match tcx.item_name(goal.predicate.def_id()) {
+                    sym::CallOnceFuture => (
+                        ty::AliasTy::new(
+                            tcx,
+                            goal.predicate.def_id(),
+                            [goal.predicate.self_ty(), inputs],
+                        ),
+                        coroutine.into(),
+                    ),
+                    sym::CallMutFuture | sym::CallFuture => (
+                        ty::AliasTy::new(
+                            tcx,
+                            goal.predicate.def_id(),
+                            [
+                                ty::GenericArg::from(goal.predicate.self_ty()),
+                                inputs.into(),
+                                env_region.into(),
+                            ],
+                        ),
+                        coroutine.into(),
+                    ),
+                    sym::Output => (
+                        ty::AliasTy::new(
+                            tcx,
+                            goal.predicate.def_id(),
+                            [ty::GenericArg::from(goal.predicate.self_ty()), inputs.into()],
+                        ),
+                        output.into(),
+                    ),
+                    name => bug!("no such associated type: {name}"),
+                };
+                ty::ProjectionPredicate { projection_ty, term }
+            })
+            .to_predicate(tcx);
+
+        // A built-in `AsyncFn` impl only holds if the output is sized.
+        // (FIXME: technically we only need to check this if the type is a fn ptr...)
+        Self::consider_implied_clause(
+            ecx,
+            goal,
+            pred,
+            [goal.with(tcx, output_is_sized_pred)]
+                .into_iter()
+                .chain(nested_preds.into_iter().map(|pred| goal.with(tcx, pred))),
+        )
+    }
+
+    fn consider_builtin_async_fn_kind_helper_candidate(
+        ecx: &mut EvalCtxt<'_, 'tcx>,
+        goal: Goal<'tcx, Self>,
+    ) -> QueryResult<'tcx> {
+        let [
+            closure_fn_kind_ty,
+            goal_kind_ty,
+            borrow_region,
+            tupled_inputs_ty,
+            tupled_upvars_ty,
+            coroutine_captures_by_ref_ty,
+        ] = **goal.predicate.alias.args
+        else {
+            bug!();
+        };
+
+        let Some(closure_kind) = closure_fn_kind_ty.expect_ty().to_opt_closure_kind() else {
+            // We don't need to worry about the self type being an infer var.
+            return Err(NoSolution);
+        };
+        let Some(goal_kind) = goal_kind_ty.expect_ty().to_opt_closure_kind() else {
+            return Err(NoSolution);
+        };
+        if !closure_kind.extends(goal_kind) {
+            return Err(NoSolution);
+        }
+
+        let upvars_ty = ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind(
+            ecx.tcx(),
+            goal_kind,
+            tupled_inputs_ty.expect_ty(),
+            tupled_upvars_ty.expect_ty(),
+            coroutine_captures_by_ref_ty.expect_ty(),
+            borrow_region.expect_region(),
+        );
+
+        ecx.eq(goal.param_env, goal.predicate.term.ty().unwrap(), upvars_ty)?;
+        ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
+    }
+
     fn consider_builtin_tuple_candidate(
         _ecx: &mut EvalCtxt<'_, 'tcx>,
         goal: Goal<'tcx, Self>,
diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs
index efaad47b6dd..fd09a6b671d 100644
--- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs
+++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs
@@ -303,6 +303,66 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {
         Self::consider_implied_clause(ecx, goal, pred, [goal.with(tcx, output_is_sized_pred)])
     }
 
+    fn consider_builtin_async_fn_trait_candidates(
+        ecx: &mut EvalCtxt<'_, 'tcx>,
+        goal: Goal<'tcx, Self>,
+        goal_kind: ty::ClosureKind,
+    ) -> QueryResult<'tcx> {
+        if goal.predicate.polarity != ty::ImplPolarity::Positive {
+            return Err(NoSolution);
+        }
+
+        let tcx = ecx.tcx();
+        let (tupled_inputs_and_output_and_coroutine, nested_preds) =
+            structural_traits::extract_tupled_inputs_and_output_from_async_callable(
+                tcx,
+                goal.predicate.self_ty(),
+                goal_kind,
+                // This region doesn't matter because we're throwing away the coroutine type
+                tcx.lifetimes.re_static,
+            )?;
+        let output_is_sized_pred =
+            tupled_inputs_and_output_and_coroutine.map_bound(|(_, output, _)| {
+                ty::TraitRef::from_lang_item(tcx, LangItem::Sized, DUMMY_SP, [output])
+            });
+
+        let pred = tupled_inputs_and_output_and_coroutine
+            .map_bound(|(inputs, _, _)| {
+                ty::TraitRef::new(tcx, goal.predicate.def_id(), [goal.predicate.self_ty(), inputs])
+            })
+            .to_predicate(tcx);
+        // A built-in `AsyncFn` impl only holds if the output is sized.
+        // (FIXME: technically we only need to check this if the type is a fn ptr...)
+        Self::consider_implied_clause(
+            ecx,
+            goal,
+            pred,
+            [goal.with(tcx, output_is_sized_pred)]
+                .into_iter()
+                .chain(nested_preds.into_iter().map(|pred| goal.with(tcx, pred))),
+        )
+    }
+
+    fn consider_builtin_async_fn_kind_helper_candidate(
+        ecx: &mut EvalCtxt<'_, 'tcx>,
+        goal: Goal<'tcx, Self>,
+    ) -> QueryResult<'tcx> {
+        let [closure_fn_kind_ty, goal_kind_ty] = **goal.predicate.trait_ref.args else {
+            bug!();
+        };
+
+        let Some(closure_kind) = closure_fn_kind_ty.expect_ty().to_opt_closure_kind() else {
+            // We don't need to worry about the self type being an infer var.
+            return Err(NoSolution);
+        };
+        let goal_kind = goal_kind_ty.expect_ty().to_opt_closure_kind().unwrap();
+        if closure_kind.extends(goal_kind) {
+            ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
+        } else {
+            Err(NoSolution)
+        }
+    }
+
     fn consider_builtin_tuple_candidate(
         ecx: &mut EvalCtxt<'_, 'tcx>,
         goal: Goal<'tcx, Self>,
diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs
index a960befcd4b..648f14beaa7 100644
--- a/compiler/rustc_trait_selection/src/traits/project.rs
+++ b/compiler/rustc_trait_selection/src/traits/project.rs
@@ -1833,10 +1833,28 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
                     lang_items.fn_trait(),
                     lang_items.fn_mut_trait(),
                     lang_items.fn_once_trait(),
+                    lang_items.async_fn_trait(),
+                    lang_items.async_fn_mut_trait(),
+                    lang_items.async_fn_once_trait(),
                 ].contains(&Some(trait_ref.def_id))
                 {
                     true
-                }else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) {
+                } else if lang_items.async_fn_kind_helper() == Some(trait_ref.def_id) {
+                    // FIXME(async_closures): Validity constraints here could be cleaned up.
+                    if obligation.predicate.args.type_at(0).is_ty_var()
+                        || obligation.predicate.args.type_at(4).is_ty_var()
+                        || obligation.predicate.args.type_at(5).is_ty_var()
+                    {
+                        candidate_set.mark_ambiguous();
+                        true
+                    } else if obligation.predicate.args.type_at(0).to_opt_closure_kind().is_some()
+                        && obligation.predicate.args.type_at(1).to_opt_closure_kind().is_some()
+                    {
+                        true
+                    } else {
+                        false
+                    }
+                } else if lang_items.discriminant_kind_trait() == Some(trait_ref.def_id) {
                     match self_ty.kind() {
                         ty::Bool
                         | ty::Char
@@ -2061,6 +2079,10 @@ fn confirm_select_candidate<'cx, 'tcx>(
                 } else {
                     confirm_fn_pointer_candidate(selcx, obligation, data)
                 }
+            } else if selcx.tcx().async_fn_trait_kind_from_def_id(trait_def_id).is_some() {
+                confirm_async_closure_candidate(selcx, obligation, data)
+            } else if lang_items.async_fn_kind_helper() == Some(trait_def_id) {
+                confirm_async_fn_kind_helper_candidate(selcx, obligation, data)
             } else {
                 confirm_builtin_candidate(selcx, obligation, data)
             }
@@ -2421,6 +2443,164 @@ fn confirm_callable_candidate<'cx, 'tcx>(
     confirm_param_env_candidate(selcx, obligation, predicate, true)
 }
 
+fn confirm_async_closure_candidate<'cx, 'tcx>(
+    selcx: &mut SelectionContext<'cx, 'tcx>,
+    obligation: &ProjectionTyObligation<'tcx>,
+    mut nested: Vec<PredicateObligation<'tcx>>,
+) -> Progress<'tcx> {
+    let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());
+    let ty::CoroutineClosure(def_id, args) = *self_ty.kind() else {
+        unreachable!(
+            "expected coroutine-closure self type for coroutine-closure candidate, found {self_ty}"
+        )
+    };
+    let args = args.as_coroutine_closure();
+    let kind_ty = args.kind_ty();
+
+    let tcx = selcx.tcx();
+    let goal_kind =
+        tcx.async_fn_trait_kind_from_def_id(obligation.predicate.trait_def_id(tcx)).unwrap();
+
+    let helper_trait_def_id = tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
+    nested.push(obligation.with(
+        tcx,
+        ty::TraitRef::new(
+            tcx,
+            helper_trait_def_id,
+            [kind_ty, Ty::from_closure_kind(tcx, goal_kind)],
+        ),
+    ));
+
+    let env_region = match goal_kind {
+        ty::ClosureKind::Fn | ty::ClosureKind::FnMut => obligation.predicate.args.region_at(2),
+        ty::ClosureKind::FnOnce => tcx.lifetimes.re_static,
+    };
+
+    // FIXME(async_closures): Make this into a lang item.
+    let upvars_projection_def_id =
+        tcx.associated_items(helper_trait_def_id).in_definition_order().next().unwrap().def_id;
+
+    // FIXME(async_closures): Confirmation is kind of a mess here. Ideally,
+    // we'd short-circuit when we know that the goal_kind >= closure_kind, and not
+    // register a nested predicate or create a new projection ty here. But I'm too
+    // lazy to make this more efficient atm, and we can always tweak it later,
+    // since all this does is make the solver do more work.
+    //
+    // The code duplication due to the different length args is kind of weird, too.
+    let poly_cache_entry = args.coroutine_closure_sig().map_bound(|sig| {
+        let (projection_ty, term) = match tcx.item_name(obligation.predicate.def_id) {
+            sym::CallOnceFuture => {
+                let tupled_upvars_ty = Ty::new_projection(
+                    tcx,
+                    upvars_projection_def_id,
+                    [
+                        ty::GenericArg::from(kind_ty),
+                        Ty::from_closure_kind(tcx, goal_kind).into(),
+                        env_region.into(),
+                        sig.tupled_inputs_ty.into(),
+                        args.tupled_upvars_ty().into(),
+                        args.coroutine_captures_by_ref_ty().into(),
+                    ],
+                );
+                let coroutine_ty = sig.to_coroutine(
+                    tcx,
+                    args.parent_args(),
+                    tcx.coroutine_for_closure(def_id),
+                    tupled_upvars_ty,
+                );
+                (
+                    ty::AliasTy::new(
+                        tcx,
+                        obligation.predicate.def_id,
+                        [self_ty, sig.tupled_inputs_ty],
+                    ),
+                    coroutine_ty.into(),
+                )
+            }
+            sym::CallMutFuture | sym::CallFuture => {
+                let tupled_upvars_ty = Ty::new_projection(
+                    tcx,
+                    upvars_projection_def_id,
+                    [
+                        ty::GenericArg::from(kind_ty),
+                        Ty::from_closure_kind(tcx, goal_kind).into(),
+                        env_region.into(),
+                        sig.tupled_inputs_ty.into(),
+                        args.tupled_upvars_ty().into(),
+                        args.coroutine_captures_by_ref_ty().into(),
+                    ],
+                );
+                let coroutine_ty = sig.to_coroutine(
+                    tcx,
+                    args.parent_args(),
+                    tcx.coroutine_for_closure(def_id),
+                    tupled_upvars_ty,
+                );
+                (
+                    ty::AliasTy::new(
+                        tcx,
+                        obligation.predicate.def_id,
+                        [
+                            ty::GenericArg::from(self_ty),
+                            sig.tupled_inputs_ty.into(),
+                            env_region.into(),
+                        ],
+                    ),
+                    coroutine_ty.into(),
+                )
+            }
+            sym::Output => (
+                ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.tupled_inputs_ty]),
+                sig.return_ty.into(),
+            ),
+            name => bug!("no such associated type: {name}"),
+        };
+        ty::ProjectionPredicate { projection_ty, term }
+    });
+
+    confirm_param_env_candidate(selcx, obligation, poly_cache_entry, true)
+        .with_addl_obligations(nested)
+}
+
+fn confirm_async_fn_kind_helper_candidate<'cx, 'tcx>(
+    selcx: &mut SelectionContext<'cx, 'tcx>,
+    obligation: &ProjectionTyObligation<'tcx>,
+    nested: Vec<PredicateObligation<'tcx>>,
+) -> Progress<'tcx> {
+    let [
+        // We already checked that the goal_kind >= closure_kind
+        _closure_kind_ty,
+        goal_kind_ty,
+        borrow_region,
+        tupled_inputs_ty,
+        tupled_upvars_ty,
+        coroutine_captures_by_ref_ty,
+    ] = **obligation.predicate.args
+    else {
+        bug!();
+    };
+
+    let predicate = ty::ProjectionPredicate {
+        projection_ty: ty::AliasTy::new(
+            selcx.tcx(),
+            obligation.predicate.def_id,
+            obligation.predicate.args,
+        ),
+        term: ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind(
+            selcx.tcx(),
+            goal_kind_ty.expect_ty().to_opt_closure_kind().unwrap(),
+            tupled_inputs_ty.expect_ty(),
+            tupled_upvars_ty.expect_ty(),
+            coroutine_captures_by_ref_ty.expect_ty(),
+            borrow_region.expect_region(),
+        )
+        .into(),
+    };
+
+    confirm_param_env_candidate(selcx, obligation, ty::Binder::dummy(predicate), false)
+        .with_addl_obligations(nested)
+}
+
 fn confirm_param_env_candidate<'cx, 'tcx>(
     selcx: &mut SelectionContext<'cx, 'tcx>,
     obligation: &ProjectionTyObligation<'tcx>,
diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
index b354ebf111f..75aedd5cd22 100644
--- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
@@ -117,9 +117,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
                     self.assemble_iterator_candidates(obligation, &mut candidates);
                 } else if lang_items.async_iterator_trait() == Some(def_id) {
                     self.assemble_async_iterator_candidates(obligation, &mut candidates);
+                } else if lang_items.async_fn_kind_helper() == Some(def_id) {
+                    self.assemble_async_fn_kind_helper_candidates(obligation, &mut candidates);
                 }
 
                 self.assemble_closure_candidates(obligation, &mut candidates);
+                self.assemble_async_closure_candidates(obligation, &mut candidates);
                 self.assemble_fn_pointer_candidates(obligation, &mut candidates);
                 self.assemble_candidates_from_impls(obligation, &mut candidates);
                 self.assemble_candidates_from_object_ty(obligation, &mut candidates);
@@ -335,6 +338,49 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         }
     }
 
+    fn assemble_async_closure_candidates(
+        &mut self,
+        obligation: &PolyTraitObligation<'tcx>,
+        candidates: &mut SelectionCandidateSet<'tcx>,
+    ) {
+        let Some(goal_kind) =
+            self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id())
+        else {
+            return;
+        };
+
+        match *obligation.self_ty().skip_binder().kind() {
+            ty::CoroutineClosure(_, args) => {
+                if let Some(closure_kind) =
+                    args.as_coroutine_closure().kind_ty().to_opt_closure_kind()
+                    && !closure_kind.extends(goal_kind)
+                {
+                    return;
+                }
+                candidates.vec.push(AsyncClosureCandidate);
+            }
+            ty::Infer(ty::TyVar(_)) => {
+                candidates.ambiguous = true;
+            }
+            _ => {}
+        }
+    }
+
+    fn assemble_async_fn_kind_helper_candidates(
+        &mut self,
+        obligation: &PolyTraitObligation<'tcx>,
+        candidates: &mut SelectionCandidateSet<'tcx>,
+    ) {
+        if let Some(closure_kind) = obligation.self_ty().skip_binder().to_opt_closure_kind()
+            && let Some(goal_kind) =
+                obligation.predicate.skip_binder().trait_ref.args.type_at(1).to_opt_closure_kind()
+        {
+            if closure_kind.extends(goal_kind) {
+                candidates.vec.push(AsyncFnKindHelperCandidate);
+            }
+        }
+    }
+
     /// Implements one of the `Fn()` family for a fn pointer.
     fn assemble_fn_pointer_candidates(
         &mut self,
diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
index 74f388e53a3..79336a9d4e8 100644
--- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
@@ -83,6 +83,13 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
                 ImplSource::Builtin(BuiltinImplSource::Misc, vtable_closure)
             }
 
+            AsyncClosureCandidate => {
+                let vtable_closure = self.confirm_async_closure_candidate(obligation)?;
+                ImplSource::Builtin(BuiltinImplSource::Misc, vtable_closure)
+            }
+
+            AsyncFnKindHelperCandidate => ImplSource::Builtin(BuiltinImplSource::Misc, vec![]),
+
             CoroutineCandidate => {
                 let vtable_coroutine = self.confirm_coroutine_candidate(obligation)?;
                 ImplSource::Builtin(BuiltinImplSource::Misc, vtable_coroutine)
@@ -869,6 +876,49 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         Ok(nested)
     }
 
+    #[instrument(skip(self), level = "debug")]
+    fn confirm_async_closure_candidate(
+        &mut self,
+        obligation: &PolyTraitObligation<'tcx>,
+    ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
+        // Okay to skip binder because the args on closure types never
+        // touch bound regions, they just capture the in-scope
+        // type/region parameters.
+        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let ty::CoroutineClosure(closure_def_id, args) = *self_ty.kind() else {
+            bug!("async closure candidate for non-coroutine-closure {:?}", obligation);
+        };
+
+        let trait_ref = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
+            ty::TraitRef::new(
+                self.tcx(),
+                obligation.predicate.def_id(),
+                [self_ty, sig.tupled_inputs_ty],
+            )
+        });
+
+        let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
+
+        let goal_kind =
+            self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap();
+        nested.push(obligation.with(
+            self.tcx(),
+            ty::TraitRef::from_lang_item(
+                self.tcx(),
+                LangItem::AsyncFnKindHelper,
+                obligation.cause.span,
+                [
+                    args.as_coroutine_closure().kind_ty(),
+                    Ty::from_closure_kind(self.tcx(), goal_kind),
+                ],
+            ),
+        ));
+
+        debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations");
+
+        Ok(nested)
+    }
+
     /// In the case of closure types and fn pointers,
     /// we currently treat the input type parameters on the trait as
     /// outputs. This means that when we have a match we have only
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index e79277b89c4..7f41c73b72f 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -1864,6 +1864,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
                 ImplCandidate(..)
                 | AutoImplCandidate
                 | ClosureCandidate { .. }
+                | AsyncClosureCandidate
+                | AsyncFnKindHelperCandidate
                 | CoroutineCandidate
                 | FutureCandidate
                 | IteratorCandidate
@@ -1894,6 +1896,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
                 ImplCandidate(_)
                 | AutoImplCandidate
                 | ClosureCandidate { .. }
+                | AsyncClosureCandidate
+                | AsyncFnKindHelperCandidate
                 | CoroutineCandidate
                 | FutureCandidate
                 | IteratorCandidate
@@ -1930,6 +1934,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
                 ImplCandidate(..)
                 | AutoImplCandidate
                 | ClosureCandidate { .. }
+                | AsyncClosureCandidate
+                | AsyncFnKindHelperCandidate
                 | CoroutineCandidate
                 | FutureCandidate
                 | IteratorCandidate
@@ -1946,6 +1952,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
                 ImplCandidate(..)
                 | AutoImplCandidate
                 | ClosureCandidate { .. }
+                | AsyncClosureCandidate
+                | AsyncFnKindHelperCandidate
                 | CoroutineCandidate
                 | FutureCandidate
                 | IteratorCandidate
@@ -2054,6 +2062,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
             (
                 ImplCandidate(_)
                 | ClosureCandidate { .. }
+                | AsyncClosureCandidate
+                | AsyncFnKindHelperCandidate
                 | CoroutineCandidate
                 | FutureCandidate
                 | IteratorCandidate
@@ -2066,6 +2076,8 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
                 | TraitAliasCandidate,
                 ImplCandidate(_)
                 | ClosureCandidate { .. }
+                | AsyncClosureCandidate
+                | AsyncFnKindHelperCandidate
                 | CoroutineCandidate
                 | FutureCandidate
                 | IteratorCandidate