about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-09-26 23:20:59 -0400
committerMichael Goulet <michael@errs.io>2024-09-27 15:43:18 -0400
commit4fb097a5de14399b86dda519e1e597104a15b3b2 (patch)
tree8a2a3d7ca1690bab89d870fb44059358c00005c8
parentd4ee408afc59b36ff59b6fd12d47c1beeba8e985 (diff)
downloadrust-4fb097a5de14399b86dda519e1e597104a15b3b2.tar.gz
rust-4fb097a5de14399b86dda519e1e597104a15b3b2.zip
Instantiate binders when checking supertrait upcasting
-rw-r--r--compiler/rustc_infer/src/infer/at.rs19
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs16
-rw-r--r--compiler/rustc_next_trait_solver/src/solve/trait_goals.rs49
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/mod.rs96
4 files changed, 132 insertions, 48 deletions
diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs
index 9f6a1763866..37f7f04db3f 100644
--- a/compiler/rustc_infer/src/infer/at.rs
+++ b/compiler/rustc_infer/src/infer/at.rs
@@ -159,7 +159,24 @@ impl<'a, 'tcx> At<'a, 'tcx> {
             ToTrace::to_trace(self.cause, true, expected, actual),
             self.param_env,
             define_opaque_types,
-        );
+            ToTrace::to_trace(self.cause, expected, actual),
+            expected,
+            actual,
+        )
+    }
+
+    /// Makes `expected == actual`.
+    pub fn eq_trace<T>(
+        self,
+        define_opaque_types: DefineOpaqueTypes,
+        trace: TypeTrace<'tcx>,
+        expected: T,
+        actual: T,
+    ) -> InferResult<'tcx, ()>
+    where
+        T: Relate<TyCtxt<'tcx>>,
+    {
+        let mut fields = CombineFields::new(self.infcx, trace, self.param_env, define_opaque_types);
         fields.equate(StructurallyRelateAliases::No).relate(expected, actual)?;
         Ok(InferOk {
             value: (),
diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
index 12ad0724b5c..932875b759f 100644
--- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
@@ -448,10 +448,10 @@ where
                 }
             }
         } else {
-            self.delegate.enter_forall(kind, |kind| {
-                let goal = goal.with(self.cx(), ty::Binder::dummy(kind));
-                self.add_goal(GoalSource::InstantiateHigherRanked, goal);
-                self.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
+            self.enter_forall(kind, |ecx, kind| {
+                let goal = goal.with(ecx.cx(), ty::Binder::dummy(kind));
+                ecx.add_goal(GoalSource::InstantiateHigherRanked, goal);
+                ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
             })
         }
     }
@@ -840,12 +840,14 @@ where
         self.delegate.instantiate_binder_with_infer(value)
     }
 
+    /// `enter_forall`, but takes `&mut self` and passes it back through the
+    /// callback since it can't be aliased during the call.
     pub(super) fn enter_forall<T: TypeFoldable<I> + Copy, U>(
-        &self,
+        &mut self,
         value: ty::Binder<I, T>,
-        f: impl FnOnce(T) -> U,
+        f: impl FnOnce(&mut Self, T) -> U,
     ) -> U {
-        self.delegate.enter_forall(value, f)
+        self.delegate.enter_forall(value, |value| f(self, value))
     }
 
     pub(super) fn resolve_vars_if_possible<T>(&self, value: T) -> T
diff --git a/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs b/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
index 0befe7f5e8a..493bbf1e665 100644
--- a/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
+++ b/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
@@ -895,10 +895,13 @@ where
                 source_projection.item_def_id() == target_projection.item_def_id()
                     && ecx
                         .probe(|_| ProbeKind::UpcastProjectionCompatibility)
-                        .enter(|ecx| -> Result<(), NoSolution> {
-                            ecx.sub(param_env, source_projection, target_projection)?;
-                            let _ = ecx.try_evaluate_added_goals()?;
-                            Ok(())
+                        .enter(|ecx| -> Result<_, NoSolution> {
+                            ecx.enter_forall(target_projection, |ecx, target_projection| {
+                                let source_projection =
+                                    ecx.instantiate_binder_with_infer(source_projection);
+                                ecx.eq(param_env, source_projection, target_projection)?;
+                                ecx.try_evaluate_added_goals()
+                            })
                         })
                         .is_ok()
             };
@@ -909,11 +912,14 @@ where
                     // Check that a's supertrait (upcast_principal) is compatible
                     // with the target (b_ty).
                     ty::ExistentialPredicate::Trait(target_principal) => {
-                        ecx.sub(
-                            param_env,
-                            upcast_principal.unwrap(),
-                            bound.rebind(target_principal),
-                        )?;
+                        let source_principal = upcast_principal.unwrap();
+                        let target_principal = bound.rebind(target_principal);
+                        ecx.enter_forall(target_principal, |ecx, target_principal| {
+                            let source_principal =
+                                ecx.instantiate_binder_with_infer(source_principal);
+                            ecx.eq(param_env, source_principal, target_principal)?;
+                            ecx.try_evaluate_added_goals()
+                        })?;
                     }
                     // Check that b_ty's projection is satisfied by exactly one of
                     // a_ty's projections. First, we look through the list to see if
@@ -934,7 +940,12 @@ where
                                 Certainty::AMBIGUOUS,
                             );
                         }
-                        ecx.sub(param_env, source_projection, target_projection)?;
+                        ecx.enter_forall(target_projection, |ecx, target_projection| {
+                            let source_projection =
+                                ecx.instantiate_binder_with_infer(source_projection);
+                            ecx.eq(param_env, source_projection, target_projection)?;
+                            ecx.try_evaluate_added_goals()
+                        })?;
                     }
                     // Check that b_ty's auto traits are present in a_ty's bounds.
                     ty::ExistentialPredicate::AutoTrait(def_id) => {
@@ -1187,17 +1198,15 @@ where
         ) -> Result<Vec<ty::Binder<I, I::Ty>>, NoSolution>,
     ) -> Result<Candidate<I>, NoSolution> {
         self.probe_trait_candidate(source).enter(|ecx| {
-            ecx.add_goals(
-                GoalSource::ImplWhereBound,
-                constituent_tys(ecx, goal.predicate.self_ty())?
-                    .into_iter()
-                    .map(|ty| {
-                        ecx.enter_forall(ty, |ty| {
-                            goal.with(ecx.cx(), goal.predicate.with_self_ty(ecx.cx(), ty))
-                        })
+            let goals = constituent_tys(ecx, goal.predicate.self_ty())?
+                .into_iter()
+                .map(|ty| {
+                    ecx.enter_forall(ty, |ecx, ty| {
+                        goal.with(ecx.cx(), goal.predicate.with_self_ty(ecx.cx(), ty))
                     })
-                    .collect::<Vec<_>>(),
-            );
+                })
+                .collect::<Vec<_>>();
+            ecx.add_goals(GoalSource::ImplWhereBound, goals);
             ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
         })
     }
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index 2922a4898e9..c3042ab9636 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -16,6 +16,7 @@ use rustc_hir::LangItem;
 use rustc_hir::def_id::DefId;
 use rustc_infer::infer::BoundRegionConversionTime::{self, HigherRankedType};
 use rustc_infer::infer::DefineOpaqueTypes;
+use rustc_infer::infer::at::ToTrace;
 use rustc_infer::infer::relate::TypeRelation;
 use rustc_infer::traits::TraitObligation;
 use rustc_middle::bug;
@@ -44,7 +45,7 @@ use super::{
     TraitQueryMode, const_evaluatable, project, util, wf,
 };
 use crate::error_reporting::InferCtxtErrorExt;
-use crate::infer::{InferCtxt, InferCtxtExt, InferOk, TypeFreshener};
+use crate::infer::{InferCtxt, InferOk, TypeFreshener};
 use crate::solve::InferCtxtSelectExt as _;
 use crate::traits::normalize::{normalize_with_depth, normalize_with_depth_to};
 use crate::traits::project::{ProjectAndUnifyResult, ProjectionCacheKeyExt};
@@ -2579,16 +2580,32 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
                 // Check that a_ty's supertrait (upcast_principal) is compatible
                 // with the target (b_ty).
                 ty::ExistentialPredicate::Trait(target_principal) => {
+                    let hr_source_principal = upcast_principal.map_bound(|trait_ref| {
+                        ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref)
+                    });
+                    let hr_target_principal = bound.rebind(target_principal);
+
                     nested.extend(
                         self.infcx
-                            .at(&obligation.cause, obligation.param_env)
-                            .sup(
-                                DefineOpaqueTypes::Yes,
-                                bound.rebind(target_principal),
-                                upcast_principal.map_bound(|trait_ref| {
-                                    ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref)
-                                }),
-                            )
+                            .enter_forall(hr_target_principal, |target_principal| {
+                                let source_principal =
+                                    self.infcx.instantiate_binder_with_fresh_vars(
+                                        obligation.cause.span,
+                                        HigherRankedType,
+                                        hr_source_principal,
+                                    );
+                                self.infcx.at(&obligation.cause, obligation.param_env).eq_trace(
+                                    DefineOpaqueTypes::Yes,
+                                    ToTrace::to_trace(
+                                        &obligation.cause,
+                                        true,
+                                        hr_target_principal,
+                                        hr_source_principal,
+                                    ),
+                                    target_principal,
+                                    source_principal,
+                                )
+                            })
                             .map_err(|_| SelectionError::Unimplemented)?
                             .into_obligations(),
                     );
@@ -2599,19 +2616,41 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
                 // return ambiguity. Otherwise, if exactly one matches, equate
                 // it with b_ty's projection.
                 ty::ExistentialPredicate::Projection(target_projection) => {
-                    let target_projection = bound.rebind(target_projection);
+                    let hr_target_projection = bound.rebind(target_projection);
+
                     let mut matching_projections =
-                        a_data.projection_bounds().filter(|source_projection| {
+                        a_data.projection_bounds().filter(|&hr_source_projection| {
                             // Eager normalization means that we can just use can_eq
                             // here instead of equating and processing obligations.
-                            source_projection.item_def_id() == target_projection.item_def_id()
-                                && self.infcx.can_eq(
-                                    obligation.param_env,
-                                    *source_projection,
-                                    target_projection,
-                                )
+                            hr_source_projection.item_def_id() == hr_target_projection.item_def_id()
+                                && self.infcx.probe(|_| {
+                                    self.infcx
+                                        .enter_forall(hr_target_projection, |target_projection| {
+                                            let source_projection =
+                                                self.infcx.instantiate_binder_with_fresh_vars(
+                                                    obligation.cause.span,
+                                                    HigherRankedType,
+                                                    hr_source_projection,
+                                                );
+                                            self.infcx
+                                                .at(&obligation.cause, obligation.param_env)
+                                                .eq_trace(
+                                                    DefineOpaqueTypes::Yes,
+                                                    ToTrace::to_trace(
+                                                        &obligation.cause,
+                                                        true,
+                                                        hr_target_projection,
+                                                        hr_source_projection,
+                                                    ),
+                                                    target_projection,
+                                                    source_projection,
+                                                )
+                                        })
+                                        .is_ok()
+                                })
                         });
-                    let Some(source_projection) = matching_projections.next() else {
+
+                    let Some(hr_source_projection) = matching_projections.next() else {
                         return Err(SelectionError::Unimplemented);
                     };
                     if matching_projections.next().is_some() {
@@ -2619,8 +2658,25 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
                     }
                     nested.extend(
                         self.infcx
-                            .at(&obligation.cause, obligation.param_env)
-                            .sup(DefineOpaqueTypes::Yes, target_projection, source_projection)
+                            .enter_forall(hr_target_projection, |target_projection| {
+                                let source_projection =
+                                    self.infcx.instantiate_binder_with_fresh_vars(
+                                        obligation.cause.span,
+                                        HigherRankedType,
+                                        hr_source_projection,
+                                    );
+                                self.infcx.at(&obligation.cause, obligation.param_env).eq_trace(
+                                    DefineOpaqueTypes::Yes,
+                                    ToTrace::to_trace(
+                                        &obligation.cause,
+                                        true,
+                                        hr_target_projection,
+                                        hr_source_projection,
+                                    ),
+                                    target_projection,
+                                    source_projection,
+                                )
+                            })
                             .map_err(|_| SelectionError::Unimplemented)?
                             .into_obligations(),
                     );