about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_infer/src/infer/combine.rs177
-rw-r--r--compiler/rustc_infer/src/infer/generalize.rs66
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs18
3 files changed, 81 insertions, 180 deletions
diff --git a/compiler/rustc_infer/src/infer/combine.rs b/compiler/rustc_infer/src/infer/combine.rs
index b13c9627bf7..4ed4164bd27 100644
--- a/compiler/rustc_infer/src/infer/combine.rs
+++ b/compiler/rustc_infer/src/infer/combine.rs
@@ -26,21 +26,17 @@ use super::equate::Equate;
 use super::glb::Glb;
 use super::lub::Lub;
 use super::sub::Sub;
-use super::type_variable::TypeVariableValue;
-use super::{DefineOpaqueTypes, InferCtxt, MiscVariable, TypeTrace};
-use crate::infer::generalize::{generalize, CombineDelegate, Generalization};
+use super::{DefineOpaqueTypes, InferCtxt, TypeTrace};
+use crate::infer::generalize::{self, CombineDelegate, Generalization};
 use crate::traits::{Obligation, PredicateObligations};
 use rustc_middle::infer::canonical::OriginalQueryValues;
 use rustc_middle::infer::unify_key::{ConstVarValue, ConstVariableValue};
 use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind};
 use rustc_middle::ty::error::{ExpectedFound, TypeError};
 use rustc_middle::ty::relate::{RelateResult, TypeRelation};
-use rustc_middle::ty::{
-    self, AliasKind, FallibleTypeFolder, InferConst, ToPredicate, Ty, TyCtxt, TypeFoldable,
-    TypeSuperFoldable, TypeVisitableExt,
-};
+use rustc_middle::ty::{self, AliasKind, InferConst, ToPredicate, Ty, TyCtxt, TypeVisitableExt};
 use rustc_middle::ty::{IntType, UintType};
-use rustc_span::{Span, DUMMY_SP};
+use rustc_span::DUMMY_SP;
 
 #[derive(Clone)]
 pub struct CombineFields<'infcx, 'tcx> {
@@ -208,11 +204,11 @@ impl<'tcx> InferCtxt<'tcx> {
             // matching in the solver.
             let a_error = self.tcx.const_error(a.ty(), guar);
             if let ty::ConstKind::Infer(InferConst::Var(vid)) = a.kind() {
-                return self.unify_const_variable(vid, a_error);
+                return self.unify_const_variable(vid, a_error, relation.param_env());
             }
             let b_error = self.tcx.const_error(b.ty(), guar);
             if let ty::ConstKind::Infer(InferConst::Var(vid)) = b.kind() {
-                return self.unify_const_variable(vid, b_error);
+                return self.unify_const_variable(vid, b_error, relation.param_env());
             }
 
             return Ok(if relation.a_is_expected() { a_error } else { b_error });
@@ -234,11 +230,11 @@ impl<'tcx> InferCtxt<'tcx> {
             }
 
             (ty::ConstKind::Infer(InferConst::Var(vid)), _) => {
-                return self.unify_const_variable(vid, b);
+                return self.unify_const_variable(vid, b, relation.param_env());
             }
 
             (_, ty::ConstKind::Infer(InferConst::Var(vid))) => {
-                return self.unify_const_variable(vid, a);
+                return self.unify_const_variable(vid, a, relation.param_env());
             }
             (ty::ConstKind::Unevaluated(..), _) | (_, ty::ConstKind::Unevaluated(..))
                 if self.tcx.lazy_normalization() =>
@@ -291,24 +287,17 @@ impl<'tcx> InferCtxt<'tcx> {
         &self,
         target_vid: ty::ConstVid<'tcx>,
         ct: ty::Const<'tcx>,
+        param_env: ty::ParamEnv<'tcx>,
     ) -> RelateResult<'tcx, ty::Const<'tcx>> {
-        let (for_universe, span) = {
-            let mut inner = self.inner.borrow_mut();
-            let variable_table = &mut inner.const_unification_table();
-            let var_value = variable_table.probe_value(target_vid);
-            match var_value.val {
-                ConstVariableValue::Known { value } => {
-                    bug!("instantiating {:?} which has a known value {:?}", target_vid, value)
-                }
-                ConstVariableValue::Unknown { universe } => (universe, var_value.origin.span),
-            }
-        };
-        let value = ct.try_fold_with(&mut ConstInferUnifier {
-            infcx: self,
-            span,
-            for_universe,
+        let span =
+            self.inner.borrow_mut().const_unification_table().probe_value(target_vid).origin.span;
+        let Generalization { value, needs_wf: _ } = generalize::generalize(
+            self,
+            &mut CombineDelegate { infcx: self, span, param_env },
+            ct,
             target_vid,
-        })?;
+            ty::Variance::Invariant,
+        )?;
 
         self.inner.borrow_mut().const_unification_table().union_value(
             target_vid,
@@ -547,135 +536,3 @@ fn float_unification_error<'tcx>(
     let (ty::FloatVarValue(a), ty::FloatVarValue(b)) = v;
     TypeError::FloatMismatch(ExpectedFound::new(a_is_expected, a, b))
 }
-
-struct ConstInferUnifier<'cx, 'tcx> {
-    infcx: &'cx InferCtxt<'tcx>,
-
-    span: Span,
-
-    for_universe: ty::UniverseIndex,
-
-    /// The vid of the const variable that is in the process of being
-    /// instantiated; if we find this within the const we are folding,
-    /// that means we would have created a cyclic const.
-    target_vid: ty::ConstVid<'tcx>,
-}
-
-impl<'tcx> FallibleTypeFolder<TyCtxt<'tcx>> for ConstInferUnifier<'_, 'tcx> {
-    type Error = TypeError<'tcx>;
-
-    fn interner(&self) -> TyCtxt<'tcx> {
-        self.infcx.tcx
-    }
-
-    #[instrument(level = "debug", skip(self), ret)]
-    fn try_fold_ty(&mut self, t: Ty<'tcx>) -> Result<Ty<'tcx>, TypeError<'tcx>> {
-        match t.kind() {
-            &ty::Infer(ty::TyVar(vid)) => {
-                let vid = self.infcx.inner.borrow_mut().type_variables().root_var(vid);
-                let probe = self.infcx.inner.borrow_mut().type_variables().probe(vid);
-                match probe {
-                    TypeVariableValue::Known { value: u } => {
-                        debug!("ConstOccursChecker: known value {:?}", u);
-                        u.try_fold_with(self)
-                    }
-                    TypeVariableValue::Unknown { universe } => {
-                        if self.for_universe.can_name(universe) {
-                            return Ok(t);
-                        }
-
-                        let origin =
-                            *self.infcx.inner.borrow_mut().type_variables().var_origin(vid);
-                        let new_var_id = self
-                            .infcx
-                            .inner
-                            .borrow_mut()
-                            .type_variables()
-                            .new_var(self.for_universe, origin);
-                        Ok(self.interner().mk_ty_var(new_var_id))
-                    }
-                }
-            }
-            ty::Infer(ty::IntVar(_) | ty::FloatVar(_)) => Ok(t),
-            _ => t.try_super_fold_with(self),
-        }
-    }
-
-    #[instrument(level = "debug", skip(self), ret)]
-    fn try_fold_region(
-        &mut self,
-        r: ty::Region<'tcx>,
-    ) -> Result<ty::Region<'tcx>, TypeError<'tcx>> {
-        debug!("ConstInferUnifier: r={:?}", r);
-
-        match *r {
-            // Never make variables for regions bound within the type itself,
-            // nor for erased regions.
-            ty::ReLateBound(..) | ty::ReErased | ty::ReError(_) => {
-                return Ok(r);
-            }
-
-            ty::RePlaceholder(..)
-            | ty::ReVar(..)
-            | ty::ReStatic
-            | ty::ReEarlyBound(..)
-            | ty::ReFree(..) => {
-                // see common code below
-            }
-        }
-
-        let r_universe = self.infcx.universe_of_region(r);
-        if self.for_universe.can_name(r_universe) {
-            return Ok(r);
-        } else {
-            // FIXME: This is non-ideal because we don't give a
-            // very descriptive origin for this region variable.
-            Ok(self.infcx.next_region_var_in_universe(MiscVariable(self.span), self.for_universe))
-        }
-    }
-
-    #[instrument(level = "debug", skip(self), ret)]
-    fn try_fold_const(&mut self, c: ty::Const<'tcx>) -> Result<ty::Const<'tcx>, TypeError<'tcx>> {
-        match c.kind() {
-            ty::ConstKind::Infer(InferConst::Var(vid)) => {
-                // Check if the current unification would end up
-                // unifying `target_vid` with a const which contains
-                // an inference variable which is unioned with `target_vid`.
-                //
-                // Not doing so can easily result in stack overflows.
-                if self
-                    .infcx
-                    .inner
-                    .borrow_mut()
-                    .const_unification_table()
-                    .unioned(self.target_vid, vid)
-                {
-                    return Err(TypeError::CyclicConst(c));
-                }
-
-                let var_value =
-                    self.infcx.inner.borrow_mut().const_unification_table().probe_value(vid);
-                match var_value.val {
-                    ConstVariableValue::Known { value: u } => u.try_fold_with(self),
-                    ConstVariableValue::Unknown { universe } => {
-                        if self.for_universe.can_name(universe) {
-                            Ok(c)
-                        } else {
-                            let new_var_id =
-                                self.infcx.inner.borrow_mut().const_unification_table().new_key(
-                                    ConstVarValue {
-                                        origin: var_value.origin,
-                                        val: ConstVariableValue::Unknown {
-                                            universe: self.for_universe,
-                                        },
-                                    },
-                                );
-                            Ok(self.interner().mk_const(new_var_id, c.ty()))
-                        }
-                    }
-                }
-            }
-            _ => c.try_super_fold_with(self),
-        }
-    }
-}
diff --git a/compiler/rustc_infer/src/infer/generalize.rs b/compiler/rustc_infer/src/infer/generalize.rs
index c8562c84a17..8acfe638aa3 100644
--- a/compiler/rustc_infer/src/infer/generalize.rs
+++ b/compiler/rustc_infer/src/infer/generalize.rs
@@ -3,36 +3,44 @@ use rustc_hir::def_id::DefId;
 use rustc_middle::infer::unify_key::{ConstVarValue, ConstVariableValue};
 use rustc_middle::ty::error::TypeError;
 use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
-use rustc_middle::ty::{self, InferConst, Ty, TyCtxt, TypeVisitableExt};
+use rustc_middle::ty::{self, InferConst, Term, Ty, TyCtxt, TypeVisitableExt};
 use rustc_span::Span;
 
 use crate::infer::nll_relate::TypeRelatingDelegate;
 use crate::infer::type_variable::TypeVariableValue;
 use crate::infer::{InferCtxt, RegionVariableOrigin};
 
-pub(super) fn generalize<'tcx, D: GeneralizerDelegate<'tcx>>(
+pub(super) fn generalize<'tcx, D: GeneralizerDelegate<'tcx>, T: Into<Term<'tcx>> + Relate<'tcx>>(
     infcx: &InferCtxt<'tcx>,
     delegate: &mut D,
-    ty: Ty<'tcx>,
-    for_vid: ty::TyVid,
+    term: T,
+    for_vid: impl Into<ty::TermVid<'tcx>>,
     ambient_variance: ty::Variance,
-) -> RelateResult<'tcx, Generalization<Ty<'tcx>>> {
-    let for_universe = infcx.probe_ty_var(for_vid).unwrap_err();
-    let for_vid_sub_root = infcx.inner.borrow_mut().type_variables().sub_root_var(for_vid);
+) -> RelateResult<'tcx, Generalization<T>> {
+    let (for_universe, root_vid) = match for_vid.into() {
+        ty::TermVid::Ty(ty_vid) => (
+            infcx.probe_ty_var(ty_vid).unwrap_err(),
+            ty::TermVid::Ty(infcx.inner.borrow_mut().type_variables().sub_root_var(ty_vid)),
+        ),
+        ty::TermVid::Const(ct_vid) => (
+            infcx.probe_const_var(ct_vid).unwrap_err(),
+            ty::TermVid::Const(infcx.inner.borrow_mut().const_unification_table().find(ct_vid)),
+        ),
+    };
 
     let mut generalizer = Generalizer {
         infcx,
         delegate,
         ambient_variance,
-        for_vid_sub_root,
+        root_vid,
         for_universe,
-        root_ty: ty,
+        root_term: term.into(),
         needs_wf: false,
         cache: Default::default(),
     };
 
-    assert!(!ty.has_escaping_bound_vars());
-    let value = generalizer.relate(ty, ty)?;
+    assert!(!term.has_escaping_bound_vars());
+    let value = generalizer.relate(term, term)?;
     let needs_wf = generalizer.needs_wf;
     Ok(Generalization { value, needs_wf })
 }
@@ -99,11 +107,8 @@ where
 /// establishes `'0: 'x` as a constraint.
 ///
 /// [blog post]: https://is.gd/0hKvIr
-struct Generalizer<'me, 'tcx, D>
-where
-    D: GeneralizerDelegate<'tcx>,
-{
-    pub infcx: &'me InferCtxt<'tcx>,
+struct Generalizer<'me, 'tcx, D> {
+    infcx: &'me InferCtxt<'tcx>,
 
     // An delegate used to abstract the behaviors of the three previous
     // generalizer-like implementations.
@@ -116,14 +121,15 @@ where
     /// The vid of the type variable that is in the process of being
     /// instantiated. If we find this within the value we are folding,
     /// that means we would have created a cyclic value.
-    pub for_vid_sub_root: ty::TyVid,
+    root_vid: ty::TermVid<'tcx>,
 
     /// The universe of the type variable that is in the process of being
     /// instantiated. If we find anything that this universe cannot name,
     /// we reject the relation.
     for_universe: ty::UniverseIndex,
 
-    pub root_ty: Ty<'tcx>,
+    /// The root term (const or type) we're generalizing. Used for cycle errors.
+    root_term: Term<'tcx>,
 
     cache: SsoHashMap<Ty<'tcx>, Ty<'tcx>>,
 
@@ -131,6 +137,15 @@ where
     needs_wf: bool,
 }
 
+impl<'tcx, D> Generalizer<'_, 'tcx, D> {
+    fn cyclic_term_error(&self) -> TypeError<'tcx> {
+        match self.root_term.unpack() {
+            ty::TermKind::Ty(ty) => TypeError::CyclicTy(ty),
+            ty::TermKind::Const(ct) => TypeError::CyclicConst(ct),
+        }
+    }
+}
+
 impl<'tcx, D> TypeRelation<'tcx> for Generalizer<'_, 'tcx, D>
 where
     D: GeneralizerDelegate<'tcx>,
@@ -226,10 +241,10 @@ where
                 let mut inner = self.infcx.inner.borrow_mut();
                 let vid = inner.type_variables().root_var(vid);
                 let sub_vid = inner.type_variables().sub_root_var(vid);
-                if sub_vid == self.for_vid_sub_root {
+                if TermVid::Ty(sub_vid) == self.root_vid {
                     // If sub-roots are equal, then `for_vid` and
                     // `vid` are related via subtyping.
-                    Err(TypeError::CyclicTy(self.root_ty))
+                    Err(self.cyclic_term_error())
                 } else {
                     let probe = inner.type_variables().probe(vid);
                     match probe {
@@ -363,6 +378,17 @@ where
                 bug!("unexpected inference variable encountered in NLL generalization: {:?}", c);
             }
             ty::ConstKind::Infer(InferConst::Var(vid)) => {
+                // Check if the current unification would end up
+                // unifying `target_vid` with a const which contains
+                // an inference variable which is unioned with `target_vid`.
+                //
+                // Not doing so can easily result in stack overflows.
+                if TermVid::Const(self.infcx.inner.borrow_mut().const_unification_table().find(vid))
+                    == self.root_vid
+                {
+                    return Err(self.cyclic_term_error());
+                }
+
                 let mut inner = self.infcx.inner.borrow_mut();
                 let variable_table = &mut inner.const_unification_table();
                 let var_value = variable_table.probe_value(vid);
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index b414e1200cd..2fe0b2938ef 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -1070,6 +1070,24 @@ impl ParamTerm {
     }
 }
 
+#[derive(Copy, Clone, Eq, PartialEq, Debug)]
+pub enum TermVid<'tcx> {
+    Ty(ty::TyVid),
+    Const(ty::ConstVid<'tcx>),
+}
+
+impl From<ty::TyVid> for TermVid<'_> {
+    fn from(value: ty::TyVid) -> Self {
+        TermVid::Ty(value)
+    }
+}
+
+impl<'tcx> From<ty::ConstVid<'tcx>> for TermVid<'tcx> {
+    fn from(value: ty::ConstVid<'tcx>) -> Self {
+        TermVid::Const(value)
+    }
+}
+
 /// This kind of predicate has no *direct* correspondent in the
 /// syntax, but it roughly corresponds to the syntactic forms:
 ///