about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection/src/solve
diff options
context:
space:
mode:
authorlcnr <rust@lcnr.de>2024-05-21 19:56:53 +0000
committerlcnr <rust@lcnr.de>2024-05-28 04:54:05 +0000
commit98bfd54b0abd938d6b424d45b57766e6f2f41ea2 (patch)
treedfc09fbe055e7c12ec5df2e713aacbc8ea20be77 /compiler/rustc_trait_selection/src/solve
parent13ce22904265dbb7ab9d5bd8d508d8ea0ca4c4de (diff)
downloadrust-98bfd54b0abd938d6b424d45b57766e6f2f41ea2.tar.gz
rust-98bfd54b0abd938d6b424d45b57766e6f2f41ea2.zip
eagerly normalize when adding goals
Diffstat (limited to 'compiler/rustc_trait_selection/src/solve')
-rw-r--r--compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs77
1 files changed, 75 insertions, 2 deletions
diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
index ce408ddea37..4cf0af94811 100644
--- a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
@@ -13,11 +13,14 @@ use rustc_middle::traits::solve::{
     inspect, CanonicalInput, CanonicalResponse, Certainty, PredefinedOpaquesData, QueryResult,
 };
 use rustc_middle::traits::specialization_graph;
+use rustc_middle::ty::AliasRelationDirection;
+use rustc_middle::ty::TypeFolder;
 use rustc_middle::ty::{
     self, InferCtxtLike, OpaqueTypeKey, Ty, TyCtxt, TypeFoldable, TypeSuperVisitable,
     TypeVisitable, TypeVisitableExt, TypeVisitor,
 };
 use rustc_span::DUMMY_SP;
+use rustc_type_ir::fold::TypeSuperFoldable;
 use rustc_type_ir::{self as ir, CanonicalVarValues, Interner};
 use rustc_type_ir_macros::{Lift_Generic, TypeFoldable_Generic, TypeVisitable_Generic};
 use std::ops::ControlFlow;
@@ -455,13 +458,23 @@ impl<'a, 'tcx> EvalCtxt<'a, InferCtxt<'tcx>> {
     }
 
     #[instrument(level = "trace", skip(self))]
-    pub(super) fn add_normalizes_to_goal(&mut self, goal: Goal<'tcx, ty::NormalizesTo<'tcx>>) {
+    pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<'tcx, ty::NormalizesTo<'tcx>>) {
+        goal.predicate = goal
+            .predicate
+            .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
         self.inspect.add_normalizes_to_goal(self.infcx, self.max_input_universe, goal);
         self.nested_goals.normalizes_to_goals.push(goal);
     }
 
     #[instrument(level = "debug", skip(self))]
-    pub(super) fn add_goal(&mut self, source: GoalSource, goal: Goal<'tcx, ty::Predicate<'tcx>>) {
+    pub(super) fn add_goal(
+        &mut self,
+        source: GoalSource,
+        mut goal: Goal<'tcx, ty::Predicate<'tcx>>,
+    ) {
+        goal.predicate = goal
+            .predicate
+            .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
         self.inspect.add_goal(self.infcx, self.max_input_universe, source, goal);
         self.nested_goals.goals.push((source, goal));
     }
@@ -1084,3 +1097,63 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
         });
     }
 }
+
+/// Eagerly replace aliases with inference variables, emitting `AliasRelate`
+/// goals, used when adding goals to the `EvalCtxt`. We compute the
+/// `AliasRelate` goals before evaluating the actual goal to get all the
+/// constraints we can.
+///
+/// This is a performance optimization to more eagerly detect cycles during trait
+/// solving. See tests/ui/traits/next-solver/cycles/cycle-modulo-ambig-aliases.rs.
+struct ReplaceAliasWithInfer<'me, 'a, 'tcx> {
+    ecx: &'me mut EvalCtxt<'a, InferCtxt<'tcx>>,
+    param_env: ty::ParamEnv<'tcx>,
+}
+
+impl<'tcx> TypeFolder<TyCtxt<'tcx>> for ReplaceAliasWithInfer<'_, '_, 'tcx> {
+    fn interner(&self) -> TyCtxt<'tcx> {
+        self.ecx.tcx()
+    }
+
+    fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
+        match *ty.kind() {
+            ty::Alias(..) if !ty.has_escaping_bound_vars() => {
+                let infer_ty = self.ecx.next_ty_infer();
+                let normalizes_to = ty::PredicateKind::AliasRelate(
+                    ty.into(),
+                    infer_ty.into(),
+                    AliasRelationDirection::Equate,
+                );
+                self.ecx.add_goal(
+                    GoalSource::Misc,
+                    Goal::new(self.interner(), self.param_env, normalizes_to),
+                );
+                infer_ty
+            }
+            _ => ty.super_fold_with(self),
+        }
+    }
+
+    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
+        match ct.kind() {
+            ty::ConstKind::Unevaluated(..) if !ct.has_escaping_bound_vars() => {
+                let infer_ct = self.ecx.next_const_infer(ct.ty());
+                let normalizes_to = ty::PredicateKind::AliasRelate(
+                    ct.into(),
+                    infer_ct.into(),
+                    AliasRelationDirection::Equate,
+                );
+                self.ecx.add_goal(
+                    GoalSource::Misc,
+                    Goal::new(self.interner(), self.param_env, normalizes_to),
+                );
+                infer_ct
+            }
+            _ => ct.super_fold_with(self),
+        }
+    }
+
+    fn fold_predicate(&mut self, predicate: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
+        if predicate.allow_normalization() { predicate.super_fold_with(self) } else { predicate }
+    }
+}