about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-03-09 22:12:33 +0000
committerMichael Goulet <michael@errs.io>2024-04-01 22:48:23 -0400
commit5f59b7f7634ac4b05923b47de00f2a13b213af1f (patch)
treeeff28d21a19ed4202f947a3079079e6a76091f5f /compiler
parentdd5e502d4b095be205c07da1481846f0a07c43b4 (diff)
downloadrust-5f59b7f7634ac4b05923b47de00f2a13b213af1f.tar.gz
rust-5f59b7f7634ac4b05923b47de00f2a13b213af1f.zip
Instantiate closure-like bounds with placeholders to deal with binders correctly
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/confirmation.rs150
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/mod.rs20
2 files changed, 88 insertions, 82 deletions
diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
index 6f512a1173f..47f136cb2b7 100644
--- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs
@@ -678,17 +678,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         fn_host_effect: ty::Const<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
         debug!(?obligation, "confirm_fn_pointer_candidate");
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
 
         let tcx = self.tcx();
-
-        let Some(self_ty) = self.infcx.shallow_resolve(obligation.self_ty().no_bound_vars()) else {
-            // FIXME: Ideally we'd support `for<'a> fn(&'a ()): Fn(&'a ())`,
-            // but we do not currently. Luckily, such a bound is not
-            // particularly useful, so we don't expect users to write
-            // them often.
-            return Err(SelectionError::Unimplemented);
-        };
-
         let sig = self_ty.fn_sig(tcx);
         let trait_ref = closure_trait_ref_and_return_type(
             tcx,
@@ -700,7 +693,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         )
         .map_bound(|(trait_ref, _)| trait_ref);
 
-        let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
+        let mut nested = self.equate_trait_refs(
+            &obligation.cause,
+            obligation.param_env,
+            placeholder_predicate.trait_ref,
+            trait_ref,
+        )?;
         let cause = obligation.derived_cause(BuiltinDerivedObligation);
 
         // Confirm the `type Output: Sized;` bound that is present on `FnOnce`
@@ -748,10 +746,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on coroutine types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
         let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
             bug!("closure candidate for non-closure {:?}", obligation);
         };
@@ -760,15 +756,6 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
 
         let coroutine_sig = args.as_coroutine().sig();
 
-        // NOTE: The self-type is a coroutine type and hence is
-        // in fact unparameterized (or at least does not reference any
-        // regions bound in the obligation).
-        let self_ty = obligation
-            .predicate
-            .self_ty()
-            .no_bound_vars()
-            .expect("unboxed closure type should not capture bound vars from the predicate");
-
         let (trait_ref, _, _) = super::util::coroutine_trait_ref_and_outputs(
             self.tcx(),
             obligation.predicate.def_id(),
@@ -776,7 +763,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
             coroutine_sig,
         );
 
-        let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
+        let nested = self.equate_trait_refs(
+            &obligation.cause,
+            obligation.param_env,
+            placeholder_predicate.trait_ref,
+            ty::Binder::dummy(trait_ref),
+        )?;
         debug!(?trait_ref, ?nested, "coroutine candidate obligations");
 
         Ok(nested)
@@ -786,10 +778,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on coroutine types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
         let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
             bug!("closure candidate for non-closure {:?}", obligation);
         };
@@ -801,11 +791,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         let (trait_ref, _) = super::util::future_trait_ref_and_outputs(
             self.tcx(),
             obligation.predicate.def_id(),
-            obligation.predicate.no_bound_vars().expect("future has no bound vars").self_ty(),
+            self_ty,
             coroutine_sig,
         );
 
-        let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
+        let nested = self.equate_trait_refs(
+            &obligation.cause,
+            obligation.param_env,
+            placeholder_predicate.trait_ref,
+            ty::Binder::dummy(trait_ref),
+        )?;
         debug!(?trait_ref, ?nested, "future candidate obligations");
 
         Ok(nested)
@@ -815,10 +810,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on coroutine types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
         let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
             bug!("closure candidate for non-closure {:?}", obligation);
         };
@@ -830,11 +823,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         let (trait_ref, _) = super::util::iterator_trait_ref_and_outputs(
             self.tcx(),
             obligation.predicate.def_id(),
-            obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
+            self_ty,
             gen_sig,
         );
 
-        let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
+        let nested = self.equate_trait_refs(
+            &obligation.cause,
+            obligation.param_env,
+            placeholder_predicate.trait_ref,
+            ty::Binder::dummy(trait_ref),
+        )?;
         debug!(?trait_ref, ?nested, "iterator candidate obligations");
 
         Ok(nested)
@@ -844,10 +842,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on coroutine types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
         let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
             bug!("closure candidate for non-closure {:?}", obligation);
         };
@@ -859,11 +855,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         let (trait_ref, _) = super::util::async_iterator_trait_ref_and_outputs(
             self.tcx(),
             obligation.predicate.def_id(),
-            obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
+            self_ty,
             gen_sig,
         );
 
-        let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
+        let nested = self.equate_trait_refs(
+            &obligation.cause,
+            obligation.param_env,
+            placeholder_predicate.trait_ref,
+            ty::Binder::dummy(trait_ref),
+        )?;
         debug!(?trait_ref, ?nested, "iterator candidate obligations");
 
         Ok(nested)
@@ -874,14 +875,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        // Okay to skip binder because the args on closure types never
-        // touch bound regions, they just capture the in-scope
-        // type/region parameters.
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty: Ty<'_> = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
+
         let trait_ref = match *self_ty.kind() {
-            ty::Closure(_, args) => {
-                self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
-            }
+            ty::Closure(..) => self.closure_trait_ref_unnormalized(
+                self_ty,
+                obligation.predicate.def_id(),
+                self.tcx().consts.true_,
+            ),
             ty::CoroutineClosure(_, args) => {
                 args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
                     ty::TraitRef::new(
@@ -896,7 +898,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
             }
         };
 
-        self.confirm_poly_trait_refs(obligation, trait_ref)
+        self.equate_trait_refs(
+            &obligation.cause,
+            obligation.param_env,
+            placeholder_predicate.trait_ref,
+            trait_ref,
+        )
     }
 
     #[instrument(skip(self), level = "debug")]
@@ -904,8 +911,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         &mut self,
         obligation: &PolyTraitObligation<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
+        let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
+        let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
+
         let tcx = self.tcx();
-        let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
 
         let mut nested = vec![];
         let (trait_ref, kind_ty) = match *self_ty.kind() {
@@ -972,7 +981,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
             _ => bug!("expected callable type for AsyncFn candidate"),
         };
 
-        nested.extend(self.confirm_poly_trait_refs(obligation, trait_ref)?);
+        nested.extend(self.equate_trait_refs(
+            &obligation.cause,
+            obligation.param_env,
+            placeholder_predicate.trait_ref,
+            trait_ref,
+        )?);
 
         let goal_kind =
             self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap();
@@ -1025,34 +1039,34 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
     /// selection of the impl. Therefore, if there is a mismatch, we
     /// report an error to the user.
     #[instrument(skip(self), level = "trace")]
-    fn confirm_poly_trait_refs(
+    fn equate_trait_refs(
         &mut self,
-        obligation: &PolyTraitObligation<'tcx>,
-        self_ty_trait_ref: ty::PolyTraitRef<'tcx>,
+        cause: &ObligationCause<'tcx>,
+        param_env: ty::ParamEnv<'tcx>,
+        obligation_trait_ref: ty::TraitRef<'tcx>,
+        found_trait_ref: ty::PolyTraitRef<'tcx>,
     ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
-        let obligation_trait_ref =
-            self.infcx.enter_forall_and_leak_universe(obligation.predicate.to_poly_trait_ref());
-        let self_ty_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
-            obligation.cause.span,
+        let found_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
+            cause.span,
             HigherRankedType,
-            self_ty_trait_ref,
+            found_trait_ref,
         );
         // Normalize the obligation and expected trait refs together, because why not
-        let Normalized { obligations: nested, value: (obligation_trait_ref, expected_trait_ref) } =
+        let Normalized { obligations: nested, value: (obligation_trait_ref, found_trait_ref) } =
             ensure_sufficient_stack(|| {
                 normalize_with_depth(
                     self,
-                    obligation.param_env,
-                    obligation.cause.clone(),
-                    obligation.recursion_depth + 1,
-                    (obligation_trait_ref, self_ty_trait_ref),
+                    param_env,
+                    cause.clone(),
+                    0,
+                    (obligation_trait_ref, found_trait_ref),
                 )
             });
 
         // needed to define opaque types for tests/ui/type-alias-impl-trait/assoc-projection-ice.rs
         self.infcx
-            .at(&obligation.cause, obligation.param_env)
-            .eq(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
+            .at(&cause, param_env)
+            .eq(DefineOpaqueTypes::Yes, obligation_trait_ref, found_trait_ref)
             .map(|InferOk { mut obligations, .. }| {
                 obligations.extend(nested);
                 obligations
@@ -1060,7 +1074,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
             .map_err(|terr| {
                 SignatureMismatch(Box::new(SignatureMismatchData {
                     expected_trait_ref: ty::Binder::dummy(obligation_trait_ref),
-                    found_trait_ref: ty::Binder::dummy(expected_trait_ref),
+                    found_trait_ref: ty::Binder::dummy(found_trait_ref),
                     terr,
                 }))
             })
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index 1894fbba302..084f7b1a8c2 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -2679,26 +2679,18 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
     #[instrument(skip(self), level = "debug")]
     fn closure_trait_ref_unnormalized(
         &mut self,
-        obligation: &PolyTraitObligation<'tcx>,
-        args: GenericArgsRef<'tcx>,
+        self_ty: Ty<'tcx>,
+        fn_trait_def_id: DefId,
         fn_host_effect: ty::Const<'tcx>,
     ) -> ty::PolyTraitRef<'tcx> {
+        let ty::Closure(_, args) = *self_ty.kind() else {
+            bug!("expected closure, found {self_ty}");
+        };
         let closure_sig = args.as_closure().sig();
 
-        debug!(?closure_sig);
-
-        // NOTE: The self-type is an unboxed closure type and hence is
-        // in fact unparameterized (or at least does not reference any
-        // regions bound in the obligation).
-        let self_ty = obligation
-            .predicate
-            .self_ty()
-            .no_bound_vars()
-            .expect("unboxed closure type should not capture bound vars from the predicate");
-
         closure_trait_ref_and_return_type(
             self.tcx(),
-            obligation.predicate.def_id(),
+            fn_trait_def_id,
             self_ty,
             closure_sig,
             util::TupleArgumentsFlag::No,