about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-06-07 10:05:47 -0400
committerMichael Goulet <michael@errs.io>2024-06-11 13:52:51 -0400
commit44a6f72a725d1e274d734473c95f95caa5c6fbb6 (patch)
tree2a1dc953eafb3ba3834704d15365032f2400b6fc
parent4038010436d18725dd5c6c8cd91f9f1fca617373 (diff)
downloadrust-44a6f72a725d1e274d734473c95f95caa5c6fbb6.tar.gz
rust-44a6f72a725d1e274d734473c95f95caa5c6fbb6.zip
Make ObligationEmittingRelation deal with Goals only
-rw-r--r--compiler/rustc_borrowck/src/type_check/relate_tys.rs38
-rw-r--r--compiler/rustc_infer/src/infer/at.rs52
-rw-r--r--compiler/rustc_infer/src/infer/relate/combine.rs35
-rw-r--r--compiler/rustc_infer/src/infer/relate/glb.rs10
-rw-r--r--compiler/rustc_infer/src/infer/relate/lattice.rs9
-rw-r--r--compiler/rustc_infer/src/infer/relate/lub.rs10
-rw-r--r--compiler/rustc_infer/src/infer/relate/type_relating.rs25
7 files changed, 119 insertions, 60 deletions
diff --git a/compiler/rustc_borrowck/src/type_check/relate_tys.rs b/compiler/rustc_borrowck/src/type_check/relate_tys.rs
index cd51d73ba55..a87b9f7a23d 100644
--- a/compiler/rustc_borrowck/src/type_check/relate_tys.rs
+++ b/compiler/rustc_borrowck/src/type_check/relate_tys.rs
@@ -3,7 +3,8 @@ use rustc_errors::ErrorGuaranteed;
 use rustc_infer::infer::relate::{ObligationEmittingRelation, StructurallyRelateAliases};
 use rustc_infer::infer::relate::{Relate, RelateResult, TypeRelation};
 use rustc_infer::infer::NllRegionVariableOrigin;
-use rustc_infer::traits::{Obligation, PredicateObligation};
+use rustc_infer::traits::solve::Goal;
+use rustc_infer::traits::Obligation;
 use rustc_middle::mir::ConstraintCategory;
 use rustc_middle::span_bug;
 use rustc_middle::traits::query::NoSolution;
@@ -154,8 +155,13 @@ impl<'me, 'bccx, 'tcx> NllTypeRelating<'me, 'bccx, 'tcx> {
             ),
         };
         let cause = ObligationCause::dummy_with_span(self.span());
-        let obligations = infcx.handle_opaque_type(a, b, &cause, self.param_env())?.obligations;
-        self.register_obligations(obligations);
+        self.register_obligations(
+            infcx
+                .handle_opaque_type(a, b, &cause, self.param_env())?
+                .obligations
+                .into_iter()
+                .map(Goal::from),
+        );
         Ok(())
     }
 
@@ -550,22 +556,32 @@ impl<'bccx, 'tcx> ObligationEmittingRelation<'tcx> for NllTypeRelating<'_, 'bccx
         &mut self,
         obligations: impl IntoIterator<Item: ty::Upcast<TyCtxt<'tcx>, ty::Predicate<'tcx>>>,
     ) {
+        let tcx = self.tcx();
+        let param_env = self.param_env();
         self.register_obligations(
-            obligations
-                .into_iter()
-                .map(|to_pred| {
-                    Obligation::new(self.tcx(), ObligationCause::dummy(), self.param_env(), to_pred)
-                })
-                .collect(),
+            obligations.into_iter().map(|to_pred| Goal::new(tcx, param_env, to_pred)),
         );
     }
 
-    fn register_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
+    fn register_obligations(
+        &mut self,
+        obligations: impl IntoIterator<Item = Goal<'tcx, ty::Predicate<'tcx>>>,
+    ) {
         let _: Result<_, ErrorGuaranteed> = self.type_checker.fully_perform_op(
             self.locations,
             self.category,
             InstantiateOpaqueType {
-                obligations,
+                obligations: obligations
+                    .into_iter()
+                    .map(|goal| {
+                        Obligation::new(
+                            self.tcx(),
+                            ObligationCause::dummy_with_span(self.span()),
+                            goal.param_env,
+                            goal.predicate,
+                        )
+                    })
+                    .collect(),
                 // These fields are filled in during execution of the operation
                 base_universe: None,
                 region_constraints: None,
diff --git a/compiler/rustc_infer/src/infer/at.rs b/compiler/rustc_infer/src/infer/at.rs
index 046d908d148..8994739f5c7 100644
--- a/compiler/rustc_infer/src/infer/at.rs
+++ b/compiler/rustc_infer/src/infer/at.rs
@@ -31,6 +31,8 @@ use crate::infer::relate::{Relate, StructurallyRelateAliases, TypeRelation};
 use rustc_middle::bug;
 use rustc_middle::ty::{Const, ImplSubject};
 
+use crate::traits::Obligation;
+
 /// Whether we should define opaque types or just treat them opaquely.
 ///
 /// Currently only used to prevent predicate matching from matching anything
@@ -119,10 +121,8 @@ impl<'a, 'tcx> At<'a, 'tcx> {
             self.param_env,
             define_opaque_types,
         );
-        fields
-            .sup()
-            .relate(expected, actual)
-            .map(|_| InferOk { value: (), obligations: fields.obligations })
+        fields.sup().relate(expected, actual)?;
+        Ok(InferOk { value: (), obligations: fields.into_obligations() })
     }
 
     /// Makes `expected <: actual`.
@@ -141,10 +141,8 @@ impl<'a, 'tcx> At<'a, 'tcx> {
             self.param_env,
             define_opaque_types,
         );
-        fields
-            .sub()
-            .relate(expected, actual)
-            .map(|_| InferOk { value: (), obligations: fields.obligations })
+        fields.sub().relate(expected, actual)?;
+        Ok(InferOk { value: (), obligations: fields.into_obligations() })
     }
 
     /// Makes `expected == actual`.
@@ -163,10 +161,22 @@ impl<'a, 'tcx> At<'a, 'tcx> {
             self.param_env,
             define_opaque_types,
         );
-        fields
-            .equate(StructurallyRelateAliases::No)
-            .relate(expected, actual)
-            .map(|_| InferOk { value: (), obligations: fields.obligations })
+        fields.equate(StructurallyRelateAliases::No).relate(expected, actual)?;
+        Ok(InferOk {
+            value: (),
+            obligations: fields
+                .obligations
+                .into_iter()
+                .map(|goal| {
+                    Obligation::new(
+                        self.infcx.tcx,
+                        fields.trace.cause.clone(),
+                        goal.param_env,
+                        goal.predicate,
+                    )
+                })
+                .collect(),
+        })
     }
 
     /// Equates `expected` and `found` while structurally relating aliases.
@@ -187,10 +197,8 @@ impl<'a, 'tcx> At<'a, 'tcx> {
             self.param_env,
             DefineOpaqueTypes::Yes,
         );
-        fields
-            .equate(StructurallyRelateAliases::Yes)
-            .relate(expected, actual)
-            .map(|_| InferOk { value: (), obligations: fields.obligations })
+        fields.equate(StructurallyRelateAliases::Yes).relate(expected, actual)?;
+        Ok(InferOk { value: (), obligations: fields.into_obligations() })
     }
 
     pub fn relate<T>(
@@ -237,10 +245,8 @@ impl<'a, 'tcx> At<'a, 'tcx> {
             self.param_env,
             define_opaque_types,
         );
-        fields
-            .lub()
-            .relate(expected, actual)
-            .map(|value| InferOk { value, obligations: fields.obligations })
+        let value = fields.lub().relate(expected, actual)?;
+        Ok(InferOk { value, obligations: fields.into_obligations() })
     }
 
     /// Computes the greatest-lower-bound, or mutual subtype, of two
@@ -261,10 +267,8 @@ impl<'a, 'tcx> At<'a, 'tcx> {
             self.param_env,
             define_opaque_types,
         );
-        fields
-            .glb()
-            .relate(expected, actual)
-            .map(|value| InferOk { value, obligations: fields.obligations })
+        let value = fields.glb().relate(expected, actual)?;
+        Ok(InferOk { value, obligations: fields.into_obligations() })
     }
 }
 
diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs
index e62ef5d4ea4..1a0a0d10c6d 100644
--- a/compiler/rustc_infer/src/infer/relate/combine.rs
+++ b/compiler/rustc_infer/src/infer/relate/combine.rs
@@ -28,6 +28,7 @@ use crate::infer::{DefineOpaqueTypes, InferCtxt, TypeTrace};
 use crate::traits::{Obligation, PredicateObligation};
 use rustc_middle::bug;
 use rustc_middle::infer::unify_key::EffectVarValue;
+use rustc_middle::traits::solve::Goal;
 use rustc_middle::ty::error::{ExpectedFound, TypeError};
 use rustc_middle::ty::{self, InferConst, Ty, TyCtxt, TypeVisitableExt, Upcast};
 use rustc_middle::ty::{IntType, UintType};
@@ -38,7 +39,7 @@ pub struct CombineFields<'infcx, 'tcx> {
     pub infcx: &'infcx InferCtxt<'tcx>,
     pub trace: TypeTrace<'tcx>,
     pub param_env: ty::ParamEnv<'tcx>,
-    pub obligations: Vec<PredicateObligation<'tcx>>,
+    pub obligations: Vec<Goal<'tcx, ty::Predicate<'tcx>>>,
     pub define_opaque_types: DefineOpaqueTypes,
 }
 
@@ -51,6 +52,20 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
     ) -> Self {
         Self { infcx, trace, param_env, define_opaque_types, obligations: vec![] }
     }
+
+    pub(crate) fn into_obligations(self) -> Vec<PredicateObligation<'tcx>> {
+        self.obligations
+            .into_iter()
+            .map(|goal| {
+                Obligation::new(
+                    self.infcx.tcx,
+                    self.trace.cause.clone(),
+                    goal.param_env,
+                    goal.predicate,
+                )
+            })
+            .collect()
+    }
 }
 
 impl<'tcx> InferCtxt<'tcx> {
@@ -290,7 +305,10 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
         Glb::new(self)
     }
 
-    pub fn register_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
+    pub fn register_obligations(
+        &mut self,
+        obligations: impl IntoIterator<Item = Goal<'tcx, ty::Predicate<'tcx>>>,
+    ) {
         self.obligations.extend(obligations);
     }
 
@@ -298,9 +316,11 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
         &mut self,
         obligations: impl IntoIterator<Item: Upcast<TyCtxt<'tcx>, ty::Predicate<'tcx>>>,
     ) {
-        self.obligations.extend(obligations.into_iter().map(|to_pred| {
-            Obligation::new(self.infcx.tcx, self.trace.cause.clone(), self.param_env, to_pred)
-        }))
+        self.obligations.extend(
+            obligations
+                .into_iter()
+                .map(|to_pred| Goal::new(self.infcx.tcx, self.param_env, to_pred)),
+        )
     }
 }
 
@@ -315,7 +335,10 @@ pub trait ObligationEmittingRelation<'tcx>: TypeRelation<TyCtxt<'tcx>> {
     fn structurally_relate_aliases(&self) -> StructurallyRelateAliases;
 
     /// Register obligations that must hold in order for this relation to hold
-    fn register_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>);
+    fn register_obligations(
+        &mut self,
+        obligations: impl IntoIterator<Item = Goal<'tcx, ty::Predicate<'tcx>>>,
+    );
 
     /// Register predicates that must hold in order for this relation to hold. Uses
     /// a default obligation cause, [`ObligationEmittingRelation::register_obligations`] should
diff --git a/compiler/rustc_infer/src/infer/relate/glb.rs b/compiler/rustc_infer/src/infer/relate/glb.rs
index ca772b349d2..6f37995ac1e 100644
--- a/compiler/rustc_infer/src/infer/relate/glb.rs
+++ b/compiler/rustc_infer/src/infer/relate/glb.rs
@@ -1,6 +1,7 @@
 //! Greatest lower bound. See [`lattice`].
 
-use super::{Relate, RelateResult, TypeRelation};
+use rustc_middle::traits::solve::Goal;
+use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation};
 use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
 use rustc_span::Span;
 
@@ -8,7 +9,7 @@ use super::combine::{CombineFields, ObligationEmittingRelation};
 use super::lattice::{self, LatticeDir};
 use super::StructurallyRelateAliases;
 use crate::infer::{DefineOpaqueTypes, InferCtxt, SubregionOrigin};
-use crate::traits::{ObligationCause, PredicateObligation};
+use crate::traits::ObligationCause;
 
 /// "Greatest lower bound" (common subtype)
 pub struct Glb<'combine, 'infcx, 'tcx> {
@@ -147,7 +148,10 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Glb<'_, '_, 'tcx> {
         self.fields.register_predicates(obligations);
     }
 
-    fn register_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
+    fn register_obligations(
+        &mut self,
+        obligations: impl IntoIterator<Item = Goal<'tcx, ty::Predicate<'tcx>>>,
+    ) {
         self.fields.register_obligations(obligations);
     }
 
diff --git a/compiler/rustc_infer/src/infer/relate/lattice.rs b/compiler/rustc_infer/src/infer/relate/lattice.rs
index f05b984142a..8c6f1690ade 100644
--- a/compiler/rustc_infer/src/infer/relate/lattice.rs
+++ b/compiler/rustc_infer/src/infer/relate/lattice.rs
@@ -21,7 +21,8 @@ use super::combine::ObligationEmittingRelation;
 use crate::infer::{DefineOpaqueTypes, InferCtxt};
 use crate::traits::ObligationCause;
 
-use super::RelateResult;
+use rustc_middle::traits::solve::Goal;
+use rustc_middle::ty::relate::RelateResult;
 use rustc_middle::ty::TyVar;
 use rustc_middle::ty::{self, Ty};
 
@@ -109,7 +110,11 @@ where
                 && !this.infcx().next_trait_solver() =>
         {
             this.register_obligations(
-                infcx.handle_opaque_type(a, b, this.cause(), this.param_env())?.obligations,
+                infcx
+                    .handle_opaque_type(a, b, this.cause(), this.param_env())?
+                    .obligations
+                    .into_iter()
+                    .map(Goal::from),
             );
             Ok(a)
         }
diff --git a/compiler/rustc_infer/src/infer/relate/lub.rs b/compiler/rustc_infer/src/infer/relate/lub.rs
index 0b9de8de001..625cc02115a 100644
--- a/compiler/rustc_infer/src/infer/relate/lub.rs
+++ b/compiler/rustc_infer/src/infer/relate/lub.rs
@@ -4,9 +4,10 @@ use super::combine::{CombineFields, ObligationEmittingRelation};
 use super::lattice::{self, LatticeDir};
 use super::StructurallyRelateAliases;
 use crate::infer::{DefineOpaqueTypes, InferCtxt, SubregionOrigin};
-use crate::traits::{ObligationCause, PredicateObligation};
+use crate::traits::ObligationCause;
 
-use super::{Relate, RelateResult, TypeRelation};
+use rustc_middle::traits::solve::Goal;
+use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation};
 use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
 use rustc_span::Span;
 
@@ -147,7 +148,10 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Lub<'_, '_, 'tcx> {
         self.fields.register_predicates(obligations);
     }
 
-    fn register_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
+    fn register_obligations(
+        &mut self,
+        obligations: impl IntoIterator<Item = Goal<'tcx, ty::Predicate<'tcx>>>,
+    ) {
         self.fields.register_obligations(obligations)
     }
 
diff --git a/compiler/rustc_infer/src/infer/relate/type_relating.rs b/compiler/rustc_infer/src/infer/relate/type_relating.rs
index 447e4d6bfd8..328e4d8902f 100644
--- a/compiler/rustc_infer/src/infer/relate/type_relating.rs
+++ b/compiler/rustc_infer/src/infer/relate/type_relating.rs
@@ -1,11 +1,10 @@
 use super::combine::CombineFields;
+use crate::infer::relate::{ObligationEmittingRelation, StructurallyRelateAliases};
 use crate::infer::BoundRegionConversionTime::HigherRankedType;
 use crate::infer::{DefineOpaqueTypes, SubregionOrigin};
-use crate::traits::{Obligation, PredicateObligation};
-
-use super::{
-    relate_args_invariantly, relate_args_with_variances, ObligationEmittingRelation, Relate,
-    RelateResult, StructurallyRelateAliases, TypeRelation,
+use rustc_middle::traits::solve::Goal;
+use rustc_middle::ty::relate::{
+    relate_args_invariantly, relate_args_with_variances, Relate, RelateResult, TypeRelation,
 };
 use rustc_middle::ty::TyVar;
 use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -88,9 +87,8 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
                     ty::Covariant => {
                         // can't make progress on `A <: B` if both A and B are
                         // type variables, so record an obligation.
-                        self.fields.obligations.push(Obligation::new(
+                        self.fields.obligations.push(Goal::new(
                             self.tcx(),
-                            self.fields.trace.cause.clone(),
                             self.fields.param_env,
                             ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate {
                                 a_is_expected: true,
@@ -102,9 +100,8 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
                     ty::Contravariant => {
                         // can't make progress on `B <: A` if both A and B are
                         // type variables, so record an obligation.
-                        self.fields.obligations.push(Obligation::new(
+                        self.fields.obligations.push(Goal::new(
                             self.tcx(),
-                            self.fields.trace.cause.clone(),
                             self.fields.param_env,
                             ty::Binder::dummy(ty::PredicateKind::Subtype(ty::SubtypePredicate {
                                 a_is_expected: false,
@@ -153,10 +150,13 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, '_, 'tcx> {
                     && def_id.is_local()
                     && !infcx.next_trait_solver() =>
             {
+                // FIXME: Don't shuttle between Goal and Obligation
                 self.fields.obligations.extend(
                     infcx
                         .handle_opaque_type(a, b, &self.fields.trace.cause, self.param_env())?
-                        .obligations,
+                        .obligations
+                        .into_iter()
+                        .map(Goal::from),
                 );
             }
 
@@ -318,7 +318,10 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for TypeRelating<'_, '_, 'tcx> {
         self.fields.register_predicates(obligations);
     }
 
-    fn register_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
+    fn register_obligations(
+        &mut self,
+        obligations: impl IntoIterator<Item = Goal<'tcx, ty::Predicate<'tcx>>>,
+    ) {
         self.fields.register_obligations(obligations);
     }