about summary refs log tree commit diff
diff options
context:
space:
mode:
authorlcnr <rust@lcnr.de>2023-01-17 12:26:28 +0100
committerlcnr <rust@lcnr.de>2023-01-18 08:11:15 +0100
commit9a757d6ee4229bd7e022a2f3f26946cf05fddbcf (patch)
treef36a2b6b59fb2b1af5f07a1352c1a70d4ad9631e
parent660c28391c79bd12e116724a8877a2148630dee5 (diff)
downloadrust-9a757d6ee4229bd7e022a2f3f26946cf05fddbcf.tar.gz
rust-9a757d6ee4229bd7e022a2f3f26946cf05fddbcf.zip
add `eq` to `InferCtxtExt`
-rw-r--r--compiler/rustc_trait_selection/src/solve/infcx_ext.rs40
-rw-r--r--compiler/rustc_trait_selection/src/solve/project_goals.rs29
-rw-r--r--compiler/rustc_trait_selection/src/solve/trait_goals.rs18
3 files changed, 49 insertions, 38 deletions
diff --git a/compiler/rustc_trait_selection/src/solve/infcx_ext.rs b/compiler/rustc_trait_selection/src/solve/infcx_ext.rs
index 8a8c3091d54..f92d6463134 100644
--- a/compiler/rustc_trait_selection/src/solve/infcx_ext.rs
+++ b/compiler/rustc_trait_selection/src/solve/infcx_ext.rs
@@ -1,11 +1,28 @@
+use rustc_infer::infer::at::ToTrace;
 use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
-use rustc_infer::infer::InferCtxt;
-use rustc_middle::ty::Ty;
+use rustc_infer::infer::{InferCtxt, InferOk};
+use rustc_infer::traits::query::NoSolution;
+use rustc_infer::traits::ObligationCause;
+use rustc_middle::ty::{self, Ty};
 use rustc_span::DUMMY_SP;
 
+use super::Goal;
+
 /// Methods used inside of the canonical queries of the solver.
+///
+/// Most notably these do not care about diagnostics information.
+/// If you find this while looking for methods to use outside of the
+/// solver, you may look at the implementation of these method for
+/// help.
 pub(super) trait InferCtxtExt<'tcx> {
     fn next_ty_infer(&self) -> Ty<'tcx>;
+
+    fn eq<T: ToTrace<'tcx>>(
+        &self,
+        param_env: ty::ParamEnv<'tcx>,
+        lhs: T,
+        rhs: T,
+    ) -> Result<Vec<Goal<'tcx, ty::Predicate<'tcx>>>, NoSolution>;
 }
 
 impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
@@ -15,4 +32,23 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
             span: DUMMY_SP,
         })
     }
+
+    #[instrument(level = "debug", skip(self, param_env), ret)]
+    fn eq<T: ToTrace<'tcx>>(
+        &self,
+        param_env: ty::ParamEnv<'tcx>,
+        lhs: T,
+        rhs: T,
+    ) -> Result<Vec<Goal<'tcx, ty::Predicate<'tcx>>>, NoSolution> {
+        self.at(&ObligationCause::dummy(), param_env)
+            .define_opaque_types(false)
+            .eq(lhs, rhs)
+            .map(|InferOk { value: (), obligations }| {
+                obligations.into_iter().map(|o| o.into()).collect()
+            })
+            .map_err(|e| {
+                debug!(?e, "failed to equate");
+                NoSolution
+            })
+    }
 }
diff --git a/compiler/rustc_trait_selection/src/solve/project_goals.rs b/compiler/rustc_trait_selection/src/solve/project_goals.rs
index 92c5d4e53f5..cf07926f85a 100644
--- a/compiler/rustc_trait_selection/src/solve/project_goals.rs
+++ b/compiler/rustc_trait_selection/src/solve/project_goals.rs
@@ -1,14 +1,15 @@
 use crate::traits::{specialization_graph, translate_substs};
 
 use super::assembly::{self, Candidate, CandidateSource};
+use super::infcx_ext::InferCtxtExt;
 use super::{Certainty, EvalCtxt, Goal, QueryResult};
 use rustc_errors::ErrorGuaranteed;
 use rustc_hir::def::DefKind;
 use rustc_hir::def_id::DefId;
-use rustc_infer::infer::{InferCtxt, InferOk};
+use rustc_infer::infer::InferCtxt;
 use rustc_infer::traits::query::NoSolution;
 use rustc_infer::traits::specialization_graph::LeafDef;
-use rustc_infer::traits::{ObligationCause, Reveal};
+use rustc_infer::traits::Reveal;
 use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
 use rustc_middle::ty::ProjectionPredicate;
 use rustc_middle::ty::TypeVisitable;
@@ -112,14 +113,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {
             let impl_substs = ecx.infcx.fresh_substs_for_item(DUMMY_SP, impl_def_id);
             let impl_trait_ref = impl_trait_ref.subst(tcx, impl_substs);
 
-            let Ok(InferOk { obligations, .. }) = ecx.infcx
-                .at(&ObligationCause::dummy(), goal.param_env)
-                .define_opaque_types(false)
-                .eq(goal_trait_ref, impl_trait_ref)
-                .map_err(|e| debug!("failed to equate trait refs: {e:?}"))
-            else {
-                return Err(NoSolution)
-            };
+            let mut nested_goals = ecx.infcx.eq(goal.param_env, goal_trait_ref, impl_trait_ref)?;
             let where_clause_bounds = tcx
                 .predicates_of(impl_def_id)
                 .instantiate(tcx, impl_substs)
@@ -127,8 +121,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {
                 .into_iter()
                 .map(|pred| goal.with(tcx, pred));
 
-            let nested_goals =
-                obligations.into_iter().map(|o| o.into()).chain(where_clause_bounds).collect();
+            nested_goals.extend(where_clause_bounds);
             let trait_ref_certainty = ecx.evaluate_all(nested_goals)?;
 
             let Some(assoc_def) = fetch_eligible_assoc_item_def(
@@ -185,16 +178,8 @@ impl<'tcx> assembly::GoalKind<'tcx> for ProjectionPredicate<'tcx> {
                 ty.map_bound(|ty| ty.into())
             };
 
-            let Ok(InferOk { obligations, .. }) = ecx.infcx
-                .at(&ObligationCause::dummy(), goal.param_env)
-                .define_opaque_types(false)
-                .eq(goal.predicate.term,  term.subst(tcx, substs))
-                .map_err(|e| debug!("failed to equate trait refs: {e:?}"))
-            else {
-                return Err(NoSolution);
-            };
-
-            let nested_goals = obligations.into_iter().map(|o| o.into()).collect();
+            let nested_goals =
+                ecx.infcx.eq(goal.param_env, goal.predicate.term, term.subst(tcx, substs))?;
             let rhs_certainty = ecx.evaluate_all(nested_goals)?;
 
             Ok(trait_ref_certainty.unify_and(rhs_certainty))
diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs
index 3c8314aa565..bbe175d5cc8 100644
--- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs
+++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs
@@ -3,11 +3,10 @@
 use std::iter;
 
 use super::assembly::{self, Candidate, CandidateSource};
+use super::infcx_ext::InferCtxtExt;
 use super::{Certainty, EvalCtxt, Goal, QueryResult};
 use rustc_hir::def_id::DefId;
-use rustc_infer::infer::InferOk;
 use rustc_infer::traits::query::NoSolution;
-use rustc_infer::traits::ObligationCause;
 use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
 use rustc_middle::ty::TraitPredicate;
 use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -45,24 +44,15 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {
             let impl_substs = ecx.infcx.fresh_substs_for_item(DUMMY_SP, impl_def_id);
             let impl_trait_ref = impl_trait_ref.subst(tcx, impl_substs);
 
-            let Ok(InferOk { obligations, .. }) = ecx.infcx
-                .at(&ObligationCause::dummy(), goal.param_env)
-                .define_opaque_types(false)
-                .eq(goal.predicate.trait_ref, impl_trait_ref)
-                .map_err(|e| debug!("failed to equate trait refs: {e:?}"))
-            else {
-                return Err(NoSolution);
-            };
+            let mut nested_goals =
+                ecx.infcx.eq(goal.param_env, goal.predicate.trait_ref, impl_trait_ref)?;
             let where_clause_bounds = tcx
                 .predicates_of(impl_def_id)
                 .instantiate(tcx, impl_substs)
                 .predicates
                 .into_iter()
                 .map(|pred| goal.with(tcx, pred));
-
-            let nested_goals =
-                obligations.into_iter().map(|o| o.into()).chain(where_clause_bounds).collect();
-
+            nested_goals.extend(where_clause_bounds);
             ecx.evaluate_all(nested_goals)
         })
     }