about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-02-05 18:33:41 +0000
committerMichael Goulet <michael@errs.io>2024-02-06 17:20:40 +0000
commitf3d32f2f0cd03885686470c48250bd6773c1b9aa (patch)
treea5334f8375966aee3d7558b34a835d35262188b6
parent4a2fe4491ea616983a0cf0cbbd145a39768f4e7a (diff)
downloadrust-f3d32f2f0cd03885686470c48250bd6773c1b9aa.tar.gz
rust-f3d32f2f0cd03885686470c48250bd6773c1b9aa.zip
Flatten confirmation logic
-rw-r--r--compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs105
-rw-r--r--compiler/rustc_trait_selection/src/traits/project.rs138
2 files changed, 107 insertions, 136 deletions
diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
index d02578c4846..4fd9a29c0b2 100644
--- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
+++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs
@@ -318,34 +318,27 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc
     self_ty: Ty<'tcx>,
     goal_kind: ty::ClosureKind,
     env_region: ty::Region<'tcx>,
-) -> Result<
-    (ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Option<ty::Predicate<'tcx>>),
-    NoSolution,
-> {
+) -> Result<(ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Vec<ty::Predicate<'tcx>>), NoSolution>
+{
     match *self_ty.kind() {
         ty::CoroutineClosure(def_id, args) => {
             let args = args.as_coroutine_closure();
             let kind_ty = args.kind_ty();
-
-            if let Some(closure_kind) = kind_ty.to_opt_closure_kind() {
+            let sig = args.coroutine_closure_sig().skip_binder();
+            let mut nested = vec![];
+            let coroutine_ty = if let Some(closure_kind) = kind_ty.to_opt_closure_kind() {
                 if !closure_kind.extends(goal_kind) {
                     return Err(NoSolution);
                 }
-                Ok((
-                    args.coroutine_closure_sig().map_bound(|sig| {
-                        let coroutine_ty = sig.to_coroutine_given_kind_and_upvars(
-                            tcx,
-                            args.parent_args(),
-                            tcx.coroutine_for_closure(def_id),
-                            goal_kind,
-                            env_region,
-                            args.tupled_upvars_ty(),
-                            args.coroutine_captures_by_ref_ty(),
-                        );
-                        (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty)
-                    }),
-                    None,
-                ))
+                sig.to_coroutine_given_kind_and_upvars(
+                    tcx,
+                    args.parent_args(),
+                    tcx.coroutine_for_closure(def_id),
+                    goal_kind,
+                    env_region,
+                    args.tupled_upvars_ty(),
+                    args.coroutine_captures_by_ref_ty(),
+                )
             } else {
                 let async_fn_kind_trait_def_id =
                     tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
@@ -362,39 +355,43 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc
                 // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars`
                 // will project to the right upvars for the generator, appending the inputs and
                 // coroutine upvars respecting the closure kind.
-                Ok((
-                    args.coroutine_closure_sig().map_bound(|sig| {
-                        let tupled_upvars_ty = Ty::new_projection(
-                            tcx,
-                            upvars_projection_def_id,
-                            [
-                                ty::GenericArg::from(kind_ty),
-                                Ty::from_closure_kind(tcx, goal_kind).into(),
-                                env_region.into(),
-                                sig.tupled_inputs_ty.into(),
-                                args.tupled_upvars_ty().into(),
-                                args.coroutine_captures_by_ref_ty().into(),
-                            ],
-                        );
-                        let coroutine_ty = sig.to_coroutine(
-                            tcx,
-                            args.parent_args(),
-                            Ty::from_closure_kind(tcx, goal_kind),
-                            tcx.coroutine_for_closure(def_id),
-                            tupled_upvars_ty,
-                        );
-                        (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty)
-                    }),
-                    Some(
-                        ty::TraitRef::new(
-                            tcx,
-                            async_fn_kind_trait_def_id,
-                            [kind_ty, Ty::from_closure_kind(tcx, goal_kind)],
-                        )
-                        .to_predicate(tcx),
-                    ),
-                ))
-            }
+                nested.push(
+                    ty::TraitRef::new(
+                        tcx,
+                        async_fn_kind_trait_def_id,
+                        [kind_ty, Ty::from_closure_kind(tcx, goal_kind)],
+                    )
+                    .to_predicate(tcx),
+                );
+                let tupled_upvars_ty = Ty::new_projection(
+                    tcx,
+                    upvars_projection_def_id,
+                    [
+                        ty::GenericArg::from(kind_ty),
+                        Ty::from_closure_kind(tcx, goal_kind).into(),
+                        env_region.into(),
+                        sig.tupled_inputs_ty.into(),
+                        args.tupled_upvars_ty().into(),
+                        args.coroutine_captures_by_ref_ty().into(),
+                    ],
+                );
+                sig.to_coroutine(
+                    tcx,
+                    args.parent_args(),
+                    Ty::from_closure_kind(tcx, goal_kind),
+                    tcx.coroutine_for_closure(def_id),
+                    tupled_upvars_ty,
+                )
+            };
+
+            Ok((
+                args.coroutine_closure_sig().rebind((
+                    sig.tupled_inputs_ty,
+                    sig.return_ty,
+                    coroutine_ty,
+                )),
+                nested,
+            ))
         }
 
         ty::FnDef(..) | ty::FnPtr(..) | ty::Closure(..) => Err(NoSolution),
diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs
index 955c81eee6b..88c28761d25 100644
--- a/compiler/rustc_trait_selection/src/traits/project.rs
+++ b/compiler/rustc_trait_selection/src/traits/project.rs
@@ -2446,8 +2446,9 @@ fn confirm_callable_candidate<'cx, 'tcx>(
 fn confirm_async_closure_candidate<'cx, 'tcx>(
     selcx: &mut SelectionContext<'cx, 'tcx>,
     obligation: &ProjectionTyObligation<'tcx>,
-    mut nested: Vec<PredicateObligation<'tcx>>,
+    nested: Vec<PredicateObligation<'tcx>>,
 ) -> Progress<'tcx> {
+    let tcx = selcx.tcx();
     let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());
     let ty::CoroutineClosure(def_id, args) = *self_ty.kind() else {
         unreachable!(
@@ -2456,76 +2457,48 @@ fn confirm_async_closure_candidate<'cx, 'tcx>(
     };
     let args = args.as_coroutine_closure();
     let kind_ty = args.kind_ty();
+    let sig = args.coroutine_closure_sig().skip_binder();
 
-    let tcx = selcx.tcx();
     let goal_kind =
         tcx.async_fn_trait_kind_from_def_id(obligation.predicate.trait_def_id(tcx)).unwrap();
-
-    let async_fn_kind_helper_trait_def_id =
-        tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
-    nested.push(obligation.with(
-        tcx,
-        ty::TraitRef::new(
-            tcx,
-            async_fn_kind_helper_trait_def_id,
-            [kind_ty, Ty::from_closure_kind(tcx, goal_kind)],
-        ),
-    ));
-
     let env_region = match goal_kind {
         ty::ClosureKind::Fn | ty::ClosureKind::FnMut => obligation.predicate.args.region_at(2),
         ty::ClosureKind::FnOnce => tcx.lifetimes.re_static,
     };
 
-    let upvars_projection_def_id = tcx
-        .associated_items(async_fn_kind_helper_trait_def_id)
-        .filter_by_name_unhygienic(sym::Upvars)
-        .next()
-        .unwrap()
-        .def_id;
-
-    // FIXME(async_closures): Confirmation is kind of a mess here. Ideally,
-    // we'd short-circuit when we know that the goal_kind >= closure_kind, and not
-    // register a nested predicate or create a new projection ty here. But I'm too
-    // lazy to make this more efficient atm, and we can always tweak it later,
-    // since all this does is make the solver do more work.
-    //
-    // The code duplication due to the different length args is kind of weird, too.
-    //
-    // See the logic in `structural_traits` in the new solver to understand a bit
-    // more clearly how this *should* look.
-    let poly_cache_entry = args.coroutine_closure_sig().map_bound(|sig| {
-        let (projection_ty, term) = match tcx.item_name(obligation.predicate.def_id) {
-            sym::CallOnceFuture => {
-                let tupled_upvars_ty = Ty::new_projection(
-                    tcx,
-                    upvars_projection_def_id,
-                    [
-                        ty::GenericArg::from(kind_ty),
-                        Ty::from_closure_kind(tcx, goal_kind).into(),
-                        env_region.into(),
-                        sig.tupled_inputs_ty.into(),
-                        args.tupled_upvars_ty().into(),
-                        args.coroutine_captures_by_ref_ty().into(),
-                    ],
-                );
-                let coroutine_ty = sig.to_coroutine(
+    let item_name = tcx.item_name(obligation.predicate.def_id);
+    let term = match item_name {
+        sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => {
+            if let Some(closure_kind) = kind_ty.to_opt_closure_kind() {
+                if !closure_kind.extends(goal_kind) {
+                    bug!("we should not be confirming if the closure kind is not met");
+                }
+                sig.to_coroutine_given_kind_and_upvars(
                     tcx,
                     args.parent_args(),
-                    Ty::from_closure_kind(tcx, goal_kind),
                     tcx.coroutine_for_closure(def_id),
-                    tupled_upvars_ty,
-                );
-                (
-                    ty::AliasTy::new(
-                        tcx,
-                        obligation.predicate.def_id,
-                        [self_ty, sig.tupled_inputs_ty],
-                    ),
-                    coroutine_ty.into(),
+                    goal_kind,
+                    env_region,
+                    args.tupled_upvars_ty(),
+                    args.coroutine_captures_by_ref_ty(),
                 )
-            }
-            sym::CallMutFuture | sym::CallFuture => {
+            } else {
+                let async_fn_kind_trait_def_id =
+                    tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
+                let upvars_projection_def_id = tcx
+                    .associated_items(async_fn_kind_trait_def_id)
+                    .filter_by_name_unhygienic(sym::Upvars)
+                    .next()
+                    .unwrap()
+                    .def_id;
+                // When we don't know the closure kind (and therefore also the closure's upvars,
+                // which are computed at the same time), we must delay the computation of the
+                // generator's upvars. We do this using the `AsyncFnKindHelper`, which as a trait
+                // goal functions similarly to the old `ClosureKind` predicate, and ensures that
+                // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars`
+                // will project to the right upvars for the generator, appending the inputs and
+                // coroutine upvars respecting the closure kind.
+                // N.B. No need to register a `AsyncFnKindHelper` goal here, it's already in `nested`.
                 let tupled_upvars_ty = Ty::new_projection(
                     tcx,
                     upvars_projection_def_id,
@@ -2538,37 +2511,38 @@ fn confirm_async_closure_candidate<'cx, 'tcx>(
                         args.coroutine_captures_by_ref_ty().into(),
                     ],
                 );
-                let coroutine_ty = sig.to_coroutine(
+                sig.to_coroutine(
                     tcx,
                     args.parent_args(),
                     Ty::from_closure_kind(tcx, goal_kind),
                     tcx.coroutine_for_closure(def_id),
                     tupled_upvars_ty,
-                );
-                (
-                    ty::AliasTy::new(
-                        tcx,
-                        obligation.predicate.def_id,
-                        [
-                            ty::GenericArg::from(self_ty),
-                            sig.tupled_inputs_ty.into(),
-                            env_region.into(),
-                        ],
-                    ),
-                    coroutine_ty.into(),
                 )
             }
-            sym::Output => (
-                ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.tupled_inputs_ty]),
-                sig.return_ty.into(),
-            ),
-            name => bug!("no such associated type: {name}"),
-        };
-        ty::ProjectionPredicate { projection_ty, term }
-    });
+        }
+        sym::Output => sig.return_ty,
+        name => bug!("no such associated type: {name}"),
+    };
+    let projection_ty = match item_name {
+        sym::CallOnceFuture | sym::Output => {
+            ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.tupled_inputs_ty])
+        }
+        sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new(
+            tcx,
+            obligation.predicate.def_id,
+            [ty::GenericArg::from(self_ty), sig.tupled_inputs_ty.into(), env_region.into()],
+        ),
+        name => bug!("no such associated type: {name}"),
+    };
 
-    confirm_param_env_candidate(selcx, obligation, poly_cache_entry, true)
-        .with_addl_obligations(nested)
+    confirm_param_env_candidate(
+        selcx,
+        obligation,
+        args.coroutine_closure_sig()
+            .rebind(ty::ProjectionPredicate { projection_ty, term: term.into() }),
+        true,
+    )
+    .with_addl_obligations(nested)
 }
 
 fn confirm_async_fn_kind_helper_candidate<'cx, 'tcx>(