about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_infer/src/infer/combine.rs114
1 files changed, 26 insertions, 88 deletions
diff --git a/compiler/rustc_infer/src/infer/combine.rs b/compiler/rustc_infer/src/infer/combine.rs
index 72676b718fa..a567b6acdbe 100644
--- a/compiler/rustc_infer/src/infer/combine.rs
+++ b/compiler/rustc_infer/src/infer/combine.rs
@@ -37,7 +37,10 @@ use rustc_middle::traits::ObligationCause;
 use rustc_middle::ty::error::{ExpectedFound, TypeError};
 use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
 use rustc_middle::ty::subst::SubstsRef;
-use rustc_middle::ty::{self, InferConst, Ty, TyCtxt, TypeVisitable};
+use rustc_middle::ty::{
+    self, FallibleTypeFolder, InferConst, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable,
+    TypeVisitable,
+};
 use rustc_middle::ty::{IntType, UintType};
 use rustc_span::{Span, DUMMY_SP};
 
@@ -140,8 +143,6 @@ impl<'tcx> InferCtxt<'tcx> {
         let a = self.shallow_resolve(a);
         let b = self.shallow_resolve(b);
 
-        let a_is_expected = relation.a_is_expected();
-
         match (a.kind(), b.kind()) {
             (
                 ty::ConstKind::Infer(InferConst::Var(a_vid)),
@@ -158,11 +159,11 @@ impl<'tcx> InferCtxt<'tcx> {
             }
 
             (ty::ConstKind::Infer(InferConst::Var(vid)), _) => {
-                return self.unify_const_variable(relation.param_env(), vid, b, a_is_expected);
+                return self.unify_const_variable(vid, b);
             }
 
             (_, ty::ConstKind::Infer(InferConst::Var(vid))) => {
-                return self.unify_const_variable(relation.param_env(), vid, a, !a_is_expected);
+                return self.unify_const_variable(vid, a);
             }
             (ty::ConstKind::Unevaluated(..), _) if self.tcx.lazy_normalization() => {
                 // FIXME(#59490): Need to remove the leak check to accommodate
@@ -223,10 +224,8 @@ impl<'tcx> InferCtxt<'tcx> {
     #[instrument(level = "debug", skip(self))]
     fn unify_const_variable(
         &self,
-        param_env: ty::ParamEnv<'tcx>,
         target_vid: ty::ConstVid<'tcx>,
         ct: ty::Const<'tcx>,
-        vid_is_expected: bool,
     ) -> RelateResult<'tcx, ty::Const<'tcx>> {
         let (for_universe, span) = {
             let mut inner = self.inner.borrow_mut();
@@ -239,8 +238,12 @@ impl<'tcx> InferCtxt<'tcx> {
                 ConstVariableValue::Unknown { universe } => (universe, var_value.origin.span),
             }
         };
-        let value = ConstInferUnifier { infcx: self, span, param_env, for_universe, target_vid }
-            .relate(ct, ct)?;
+        let value = ct.try_fold_with(&mut ConstInferUnifier {
+            infcx: self,
+            span,
+            for_universe,
+            target_vid,
+        })?;
 
         self.inner.borrow_mut().const_unification_table().union_value(
             target_vid,
@@ -800,8 +803,6 @@ struct ConstInferUnifier<'cx, 'tcx> {
 
     span: Span,
 
-    param_env: ty::ParamEnv<'tcx>,
-
     for_universe: ty::UniverseIndex,
 
     /// The vid of the const variable that is in the process of being
@@ -810,61 +811,15 @@ struct ConstInferUnifier<'cx, 'tcx> {
     target_vid: ty::ConstVid<'tcx>,
 }
 
-// We use `TypeRelation` here to propagate `RelateResult` upwards.
-//
-// Both inputs are expected to be the same.
-impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
-    fn tcx(&self) -> TyCtxt<'tcx> {
-        self.infcx.tcx
-    }
-
-    fn intercrate(&self) -> bool {
-        assert!(!self.infcx.intercrate);
-        false
-    }
-
-    fn param_env(&self) -> ty::ParamEnv<'tcx> {
-        self.param_env
-    }
-
-    fn tag(&self) -> &'static str {
-        "ConstInferUnifier"
-    }
-
-    fn a_is_expected(&self) -> bool {
-        true
-    }
-
-    fn mark_ambiguous(&mut self) {
-        bug!()
-    }
-
-    fn relate_with_variance<T: Relate<'tcx>>(
-        &mut self,
-        _variance: ty::Variance,
-        _info: ty::VarianceDiagInfo<'tcx>,
-        a: T,
-        b: T,
-    ) -> RelateResult<'tcx, T> {
-        // We don't care about variance here.
-        self.relate(a, b)
-    }
+impl<'tcx> FallibleTypeFolder<'tcx> for ConstInferUnifier<'_, 'tcx> {
+    type Error = TypeError<'tcx>;
 
-    fn binders<T>(
-        &mut self,
-        a: ty::Binder<'tcx, T>,
-        b: ty::Binder<'tcx, T>,
-    ) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
-    where
-        T: Relate<'tcx>,
-    {
-        Ok(a.rebind(self.relate(a.skip_binder(), b.skip_binder())?))
+    fn tcx<'a>(&'a self) -> TyCtxt<'tcx> {
+        self.infcx.tcx
     }
 
     #[instrument(level = "debug", skip(self), ret)]
-    fn tys(&mut self, t: Ty<'tcx>, _t: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
-        debug_assert_eq!(t, _t);
-
+    fn try_fold_ty(&mut self, t: Ty<'tcx>) -> Result<Ty<'tcx>, TypeError<'tcx>> {
         match t.kind() {
             &ty::Infer(ty::TyVar(vid)) => {
                 let vid = self.infcx.inner.borrow_mut().type_variables().root_var(vid);
@@ -872,7 +827,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
                 match probe {
                     TypeVariableValue::Known { value: u } => {
                         debug!("ConstOccursChecker: known value {:?}", u);
-                        self.tys(u, u)
+                        u.try_fold_with(self)
                     }
                     TypeVariableValue::Unknown { universe } => {
                         if self.for_universe.can_name(universe) {
@@ -892,16 +847,15 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
                 }
             }
             ty::Infer(ty::IntVar(_) | ty::FloatVar(_)) => Ok(t),
-            _ => relate::super_relate_tys(self, t, t),
+            _ => t.try_super_fold_with(self),
         }
     }
 
-    fn regions(
+    #[instrument(level = "debug", skip(self), ret)]
+    fn try_fold_region(
         &mut self,
         r: ty::Region<'tcx>,
-        _r: ty::Region<'tcx>,
-    ) -> RelateResult<'tcx, ty::Region<'tcx>> {
-        debug_assert_eq!(r, _r);
+    ) -> Result<ty::Region<'tcx>, TypeError<'tcx>> {
         debug!("ConstInferUnifier: r={:?}", r);
 
         match *r {
@@ -930,14 +884,8 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
         }
     }
 
-    #[instrument(level = "debug", skip(self))]
-    fn consts(
-        &mut self,
-        c: ty::Const<'tcx>,
-        _c: ty::Const<'tcx>,
-    ) -> RelateResult<'tcx, ty::Const<'tcx>> {
-        debug_assert_eq!(c, _c);
-
+    #[instrument(level = "debug", skip(self), ret)]
+    fn try_fold_const(&mut self, c: ty::Const<'tcx>) -> Result<ty::Const<'tcx>, TypeError<'tcx>> {
         match c.kind() {
             ty::ConstKind::Infer(InferConst::Var(vid)) => {
                 // Check if the current unification would end up
@@ -958,7 +906,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
                 let var_value =
                     self.infcx.inner.borrow_mut().const_unification_table().probe_value(vid);
                 match var_value.val {
-                    ConstVariableValue::Known { value: u } => self.consts(u, u),
+                    ConstVariableValue::Known { value: u } => u.try_fold_with(self),
                     ConstVariableValue::Unknown { universe } => {
                         if self.for_universe.can_name(universe) {
                             Ok(c)
@@ -977,17 +925,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
                     }
                 }
             }
-            ty::ConstKind::Unevaluated(ty::UnevaluatedConst { def, substs }) => {
-                let substs = self.relate_with_variance(
-                    ty::Variance::Invariant,
-                    ty::VarianceDiagInfo::default(),
-                    substs,
-                    substs,
-                )?;
-
-                Ok(self.tcx().mk_const(ty::UnevaluatedConst { def, substs }, c.ty()))
-            }
-            _ => relate::super_relate_consts(self, c, c),
+            _ => c.try_super_fold_with(self),
         }
     }
 }