about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_middle/src/ty/predicate.rs9
-rw-r--r--compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs77
2 files changed, 78 insertions, 8 deletions
diff --git a/compiler/rustc_middle/src/ty/predicate.rs b/compiler/rustc_middle/src/ty/predicate.rs
index efb6cf25546..c730f5117c5 100644
--- a/compiler/rustc_middle/src/ty/predicate.rs
+++ b/compiler/rustc_middle/src/ty/predicate.rs
@@ -121,17 +121,14 @@ impl<'tcx> Predicate<'tcx> {
     #[inline]
     pub fn allow_normalization(self) -> bool {
         match self.kind().skip_binder() {
-            PredicateKind::Clause(ClauseKind::WellFormed(_)) => false,
-            // `NormalizesTo` is only used in the new solver, so this shouldn't
-            // matter. Normalizing `term` would be 'wrong' however, as it changes whether
-            // `normalizes-to(<T as Trait>::Assoc, <T as Trait>::Assoc)` holds.
-            PredicateKind::NormalizesTo(..) => false,
+            PredicateKind::Clause(ClauseKind::WellFormed(_))
+            | PredicateKind::AliasRelate(..)
+            | PredicateKind::NormalizesTo(..) => false,
             PredicateKind::Clause(ClauseKind::Trait(_))
             | PredicateKind::Clause(ClauseKind::RegionOutlives(_))
             | PredicateKind::Clause(ClauseKind::TypeOutlives(_))
             | PredicateKind::Clause(ClauseKind::Projection(_))
             | PredicateKind::Clause(ClauseKind::ConstArgHasType(..))
-            | PredicateKind::AliasRelate(..)
             | PredicateKind::ObjectSafe(_)
             | PredicateKind::Subtype(_)
             | PredicateKind::Coerce(_)
diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
index ce408ddea37..4cf0af94811 100644
--- a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
@@ -13,11 +13,14 @@ use rustc_middle::traits::solve::{
     inspect, CanonicalInput, CanonicalResponse, Certainty, PredefinedOpaquesData, QueryResult,
 };
 use rustc_middle::traits::specialization_graph;
+use rustc_middle::ty::AliasRelationDirection;
+use rustc_middle::ty::TypeFolder;
 use rustc_middle::ty::{
     self, InferCtxtLike, OpaqueTypeKey, Ty, TyCtxt, TypeFoldable, TypeSuperVisitable,
     TypeVisitable, TypeVisitableExt, TypeVisitor,
 };
 use rustc_span::DUMMY_SP;
+use rustc_type_ir::fold::TypeSuperFoldable;
 use rustc_type_ir::{self as ir, CanonicalVarValues, Interner};
 use rustc_type_ir_macros::{Lift_Generic, TypeFoldable_Generic, TypeVisitable_Generic};
 use std::ops::ControlFlow;
@@ -455,13 +458,23 @@ impl<'a, 'tcx> EvalCtxt<'a, InferCtxt<'tcx>> {
     }
 
     #[instrument(level = "trace", skip(self))]
-    pub(super) fn add_normalizes_to_goal(&mut self, goal: Goal<'tcx, ty::NormalizesTo<'tcx>>) {
+    pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<'tcx, ty::NormalizesTo<'tcx>>) {
+        goal.predicate = goal
+            .predicate
+            .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
         self.inspect.add_normalizes_to_goal(self.infcx, self.max_input_universe, goal);
         self.nested_goals.normalizes_to_goals.push(goal);
     }
 
     #[instrument(level = "debug", skip(self))]
-    pub(super) fn add_goal(&mut self, source: GoalSource, goal: Goal<'tcx, ty::Predicate<'tcx>>) {
+    pub(super) fn add_goal(
+        &mut self,
+        source: GoalSource,
+        mut goal: Goal<'tcx, ty::Predicate<'tcx>>,
+    ) {
+        goal.predicate = goal
+            .predicate
+            .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
         self.inspect.add_goal(self.infcx, self.max_input_universe, source, goal);
         self.nested_goals.goals.push((source, goal));
     }
@@ -1084,3 +1097,63 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
         });
     }
 }
+
+/// Eagerly replace aliases with inference variables, emitting `AliasRelate`
+/// goals, used when adding goals to the `EvalCtxt`. We compute the
+/// `AliasRelate` goals before evaluating the actual goal to get all the
+/// constraints we can.
+///
+/// This is a performance optimization to more eagerly detect cycles during trait
+/// solving. See tests/ui/traits/next-solver/cycles/cycle-modulo-ambig-aliases.rs.
+struct ReplaceAliasWithInfer<'me, 'a, 'tcx> {
+    ecx: &'me mut EvalCtxt<'a, InferCtxt<'tcx>>,
+    param_env: ty::ParamEnv<'tcx>,
+}
+
+impl<'tcx> TypeFolder<TyCtxt<'tcx>> for ReplaceAliasWithInfer<'_, '_, 'tcx> {
+    fn interner(&self) -> TyCtxt<'tcx> {
+        self.ecx.tcx()
+    }
+
+    fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
+        match *ty.kind() {
+            ty::Alias(..) if !ty.has_escaping_bound_vars() => {
+                let infer_ty = self.ecx.next_ty_infer();
+                let normalizes_to = ty::PredicateKind::AliasRelate(
+                    ty.into(),
+                    infer_ty.into(),
+                    AliasRelationDirection::Equate,
+                );
+                self.ecx.add_goal(
+                    GoalSource::Misc,
+                    Goal::new(self.interner(), self.param_env, normalizes_to),
+                );
+                infer_ty
+            }
+            _ => ty.super_fold_with(self),
+        }
+    }
+
+    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
+        match ct.kind() {
+            ty::ConstKind::Unevaluated(..) if !ct.has_escaping_bound_vars() => {
+                let infer_ct = self.ecx.next_const_infer(ct.ty());
+                let normalizes_to = ty::PredicateKind::AliasRelate(
+                    ct.into(),
+                    infer_ct.into(),
+                    AliasRelationDirection::Equate,
+                );
+                self.ecx.add_goal(
+                    GoalSource::Misc,
+                    Goal::new(self.interner(), self.param_env, normalizes_to),
+                );
+                infer_ct
+            }
+            _ => ct.super_fold_with(self),
+        }
+    }
+
+    fn fold_predicate(&mut self, predicate: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
+        if predicate.allow_normalization() { predicate.super_fold_with(self) } else { predicate }
+    }
+}