about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-02-05 19:59:05 +0000
committerMichael Goulet <michael@errs.io>2024-02-06 20:52:13 +0000
commitb8c93f1223695217cbabc1f3f1e428c358bb4e7a (patch)
tree26e652c7a756bc99db9b2c1c95bf03c3e72b531d
parent08af64e96be28c3680d6e8c96d437a560d3a9ae3 (diff)
downloadrust-b8c93f1223695217cbabc1f3f1e428c358bb4e7a.tar.gz
rust-b8c93f1223695217cbabc1f3f1e428c358bb4e7a.zip
Coroutine closures implement regular Fn traits, when possible
-rw-r--r--compiler/rustc_hir_typeck/src/closure.rs17
-rw-r--r--compiler/rustc_trait_selection/src/traits/project.rs74
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs25
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/confirmation.rs26
-rw-r--r--compiler/rustc_ty_utils/src/instance.rs18
5 files changed, 142 insertions, 18 deletions
diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs
index a985fa201d0..5bdd9412d0e 100644
--- a/compiler/rustc_hir_typeck/src/closure.rs
+++ b/compiler/rustc_hir_typeck/src/closure.rs
@@ -56,11 +56,18 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         // It's always helpful for inference if we know the kind of
         // closure sooner rather than later, so first examine the expected
         // type, and see if can glean a closure kind from there.
-        let (expected_sig, expected_kind) = match expected.to_option(self) {
-            Some(ty) => {
-                self.deduce_closure_signature(self.try_structurally_resolve_type(expr_span, ty))
-            }
-            None => (None, None),
+        let (expected_sig, expected_kind) = match closure.kind {
+            hir::ClosureKind::Closure => match expected.to_option(self) {
+                Some(ty) => {
+                    self.deduce_closure_signature(self.try_structurally_resolve_type(expr_span, ty))
+                }
+                None => (None, None),
+            },
+            // We don't want to deduce a signature from `Fn` bounds for coroutines
+            // or coroutine-closures, because the former does not implement `Fn`
+            // ever, and the latter's signature doesn't correspond to the coroutine
+            // type that it returns.
+            hir::ClosureKind::Coroutine(_) | hir::ClosureKind::CoroutineClosure(_) => (None, None),
         };
 
         let ClosureSignatures { bound_sig, mut liberated_sig } =
diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs
index f45a20ccd32..0dc11d785c4 100644
--- a/compiler/rustc_trait_selection/src/traits/project.rs
+++ b/compiler/rustc_trait_selection/src/traits/project.rs
@@ -2074,7 +2074,9 @@ fn confirm_select_candidate<'cx, 'tcx>(
             } else if lang_items.async_iterator_trait() == Some(trait_def_id) {
                 confirm_async_iterator_candidate(selcx, obligation, data)
             } else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() {
-                if obligation.predicate.self_ty().is_closure() {
+                if obligation.predicate.self_ty().is_closure()
+                    || obligation.predicate.self_ty().is_coroutine_closure()
+                {
                     confirm_closure_candidate(selcx, obligation, data)
                 } else {
                     confirm_fn_pointer_candidate(selcx, obligation, data)
@@ -2386,11 +2388,75 @@ fn confirm_closure_candidate<'cx, 'tcx>(
     obligation: &ProjectionTyObligation<'tcx>,
     nested: Vec<PredicateObligation<'tcx>>,
 ) -> Progress<'tcx> {
+    let tcx = selcx.tcx();
     let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());
-    let ty::Closure(_, args) = self_ty.kind() else {
-        unreachable!("expected closure self type for closure candidate, found {self_ty}")
+    let closure_sig = match *self_ty.kind() {
+        ty::Closure(_, args) => args.as_closure().sig(),
+
+        // Construct a "normal" `FnOnce` signature for coroutine-closure. This is
+        // basically duplicated with the `AsyncFnOnce::CallOnce` confirmation, but
+        // I didn't see a good way to unify those.
+        ty::CoroutineClosure(def_id, args) => {
+            let args = args.as_coroutine_closure();
+            let kind_ty = args.kind_ty();
+            args.coroutine_closure_sig().map_bound(|sig| {
+                // If we know the kind and upvars, use that directly.
+                // Otherwise, defer to `AsyncFnKindHelper::Upvars` to delay
+                // the projection, like the `AsyncFn*` traits do.
+                let output_ty = if let Some(_) = kind_ty.to_opt_closure_kind() {
+                    sig.to_coroutine_given_kind_and_upvars(
+                        tcx,
+                        args.parent_args(),
+                        tcx.coroutine_for_closure(def_id),
+                        ty::ClosureKind::FnOnce,
+                        tcx.lifetimes.re_static,
+                        args.tupled_upvars_ty(),
+                        args.coroutine_captures_by_ref_ty(),
+                    )
+                } else {
+                    let async_fn_kind_trait_def_id =
+                        tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
+                    let upvars_projection_def_id = tcx
+                        .associated_items(async_fn_kind_trait_def_id)
+                        .filter_by_name_unhygienic(sym::Upvars)
+                        .next()
+                        .unwrap()
+                        .def_id;
+                    let tupled_upvars_ty = Ty::new_projection(
+                        tcx,
+                        upvars_projection_def_id,
+                        [
+                            ty::GenericArg::from(kind_ty),
+                            Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce).into(),
+                            tcx.lifetimes.re_static.into(),
+                            sig.tupled_inputs_ty.into(),
+                            args.tupled_upvars_ty().into(),
+                            args.coroutine_captures_by_ref_ty().into(),
+                        ],
+                    );
+                    sig.to_coroutine(
+                        tcx,
+                        args.parent_args(),
+                        Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
+                        tcx.coroutine_for_closure(def_id),
+                        tupled_upvars_ty,
+                    )
+                };
+                tcx.mk_fn_sig(
+                    [sig.tupled_inputs_ty],
+                    output_ty,
+                    sig.c_variadic,
+                    sig.unsafety,
+                    sig.abi,
+                )
+            })
+        }
+
+        _ => {
+            unreachable!("expected closure self type for closure candidate, found {self_ty}");
+        }
     };
-    let closure_sig = args.as_closure().sig();
+
     let Normalized { value: closure_sig, obligations } = normalize_with_depth(
         selcx,
         obligation.param_env,
diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
index 34dc8553714..a82acc3ba05 100644
--- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs
@@ -332,6 +332,31 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
                     }
                 }
             }
+            ty::CoroutineClosure(def_id, args) => {
+                let is_const = self.tcx().is_const_fn_raw(def_id);
+                match self.infcx.closure_kind(self_ty) {
+                    Some(closure_kind) => {
+                        let no_borrows = self
+                            .infcx
+                            .shallow_resolve(args.as_coroutine_closure().tupled_upvars_ty())
+                            .tuple_fields()
+                            .is_empty();
+                        if no_borrows && closure_kind.extends(kind) {
+                            candidates.vec.push(ClosureCandidate { is_const });
+                        } else if kind == ty::ClosureKind::FnOnce {
+                            candidates.vec.push(ClosureCandidate { is_const });
+                        }
+                    }
+                    None => {
+                        if kind == ty::ClosureKind::FnOnce {
+                            candidates.vec.push(ClosureCandidate { is_const });
+                        } else {
+                            // This stays ambiguous until kind+upvars are determined.
+                            candidates.ambiguous = true;
+                        }
+                    }
+                }
+            }
             ty::Infer(ty::TyVar(_)) => {
                 debug!("assemble_unboxed_closure_candidates: ambiguous self-type");
                 candidates.ambiguous = true;
diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
index 42845169549..f2dc4b1be73 100644
--- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
@@ -865,17 +865,25 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         // touch bound regions, they just capture the in-scope
         // type/region parameters.
         let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
-        let ty::Closure(closure_def_id, args) = *self_ty.kind() else {
-            bug!("closure candidate for non-closure {:?}", obligation);
+        let trait_ref = match *self_ty.kind() {
+            ty::Closure(_, args) => {
+                self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
+            }
+            ty::CoroutineClosure(_, args) => {
+                args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
+                    ty::TraitRef::new(
+                        self.tcx(),
+                        obligation.predicate.def_id(),
+                        [self_ty, sig.tupled_inputs_ty],
+                    )
+                })
+            }
+            _ => {
+                bug!("closure candidate for non-closure {:?}", obligation);
+            }
         };
 
-        let trait_ref =
-            self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_);
-        let nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
-
-        debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations");
-
-        Ok(nested)
+        self.confirm_poly_trait_refs(obligation, trait_ref)
     }
 
     #[instrument(skip(self), level = "debug")]
diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs
index bcc7c98ed69..eae80199ce5 100644
--- a/compiler/rustc_ty_utils/src/instance.rs
+++ b/compiler/rustc_ty_utils/src/instance.rs
@@ -278,6 +278,24 @@ fn resolve_associated_item<'tcx>(
                         def: ty::InstanceDef::FnPtrShim(trait_item_id, rcvr_args.type_at(0)),
                         args: rcvr_args,
                     }),
+                    ty::CoroutineClosure(coroutine_closure_def_id, args) => {
+                        // When a coroutine-closure implements the `Fn` traits, then it
+                        // always dispatches to the `FnOnce` implementation. This is to
+                        // ensure that the `closure_kind` of the resulting closure is in
+                        // sync with the built-in trait implementations (since all of the
+                        // implementations return `FnOnce::Output`).
+                        if ty::ClosureKind::FnOnce == args.as_coroutine_closure().kind() {
+                            Some(Instance::new(coroutine_closure_def_id, args))
+                        } else {
+                            Some(Instance {
+                                def: ty::InstanceDef::ConstructCoroutineInClosureShim {
+                                    coroutine_closure_def_id,
+                                    target_kind: ty::ClosureKind::FnOnce,
+                                },
+                                args,
+                            })
+                        }
+                    }
                     _ => bug!(
                         "no built-in definition for `{trait_ref}::{}` for non-fn type",
                         tcx.item_name(trait_item_id)