about summary refs log tree commit diff
path: root/compiler/rustc_trait_selection/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_trait_selection/src')
-rw-r--r--compiler/rustc_trait_selection/src/solve/alias_relate.rs1
-rw-r--r--compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs77
-rw-r--r--compiler/rustc_trait_selection/src/solve/inspect/analyse.rs86
3 files changed, 117 insertions, 47 deletions
diff --git a/compiler/rustc_trait_selection/src/solve/alias_relate.rs b/compiler/rustc_trait_selection/src/solve/alias_relate.rs
index 43e61de955a..33b30bef683 100644
--- a/compiler/rustc_trait_selection/src/solve/alias_relate.rs
+++ b/compiler/rustc_trait_selection/src/solve/alias_relate.rs
@@ -28,6 +28,7 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
     ) -> QueryResult<'tcx> {
         let tcx = self.tcx();
         let Goal { param_env, predicate: (lhs, rhs, direction) } = goal;
+        debug_assert!(lhs.to_alias_term().is_some() || rhs.to_alias_term().is_some());
 
         // Structurally normalize the lhs.
         let lhs = if let Some(alias) = lhs.to_alias_term() {
diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
index ce408ddea37..4cf0af94811 100644
--- a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
@@ -13,11 +13,14 @@ use rustc_middle::traits::solve::{
     inspect, CanonicalInput, CanonicalResponse, Certainty, PredefinedOpaquesData, QueryResult,
 };
 use rustc_middle::traits::specialization_graph;
+use rustc_middle::ty::AliasRelationDirection;
+use rustc_middle::ty::TypeFolder;
 use rustc_middle::ty::{
     self, InferCtxtLike, OpaqueTypeKey, Ty, TyCtxt, TypeFoldable, TypeSuperVisitable,
     TypeVisitable, TypeVisitableExt, TypeVisitor,
 };
 use rustc_span::DUMMY_SP;
+use rustc_type_ir::fold::TypeSuperFoldable;
 use rustc_type_ir::{self as ir, CanonicalVarValues, Interner};
 use rustc_type_ir_macros::{Lift_Generic, TypeFoldable_Generic, TypeVisitable_Generic};
 use std::ops::ControlFlow;
@@ -455,13 +458,23 @@ impl<'a, 'tcx> EvalCtxt<'a, InferCtxt<'tcx>> {
     }
 
     #[instrument(level = "trace", skip(self))]
-    pub(super) fn add_normalizes_to_goal(&mut self, goal: Goal<'tcx, ty::NormalizesTo<'tcx>>) {
+    pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<'tcx, ty::NormalizesTo<'tcx>>) {
+        goal.predicate = goal
+            .predicate
+            .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
         self.inspect.add_normalizes_to_goal(self.infcx, self.max_input_universe, goal);
         self.nested_goals.normalizes_to_goals.push(goal);
     }
 
     #[instrument(level = "debug", skip(self))]
-    pub(super) fn add_goal(&mut self, source: GoalSource, goal: Goal<'tcx, ty::Predicate<'tcx>>) {
+    pub(super) fn add_goal(
+        &mut self,
+        source: GoalSource,
+        mut goal: Goal<'tcx, ty::Predicate<'tcx>>,
+    ) {
+        goal.predicate = goal
+            .predicate
+            .fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
         self.inspect.add_goal(self.infcx, self.max_input_universe, source, goal);
         self.nested_goals.goals.push((source, goal));
     }
@@ -1084,3 +1097,63 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
         });
     }
 }
+
+/// Eagerly replace aliases with inference variables, emitting `AliasRelate`
+/// goals, used when adding goals to the `EvalCtxt`. We compute the
+/// `AliasRelate` goals before evaluating the actual goal to get all the
+/// constraints we can.
+///
+/// This is a performance optimization to more eagerly detect cycles during trait
+/// solving. See tests/ui/traits/next-solver/cycles/cycle-modulo-ambig-aliases.rs.
+struct ReplaceAliasWithInfer<'me, 'a, 'tcx> {
+    ecx: &'me mut EvalCtxt<'a, InferCtxt<'tcx>>,
+    param_env: ty::ParamEnv<'tcx>,
+}
+
+impl<'tcx> TypeFolder<TyCtxt<'tcx>> for ReplaceAliasWithInfer<'_, '_, 'tcx> {
+    fn interner(&self) -> TyCtxt<'tcx> {
+        self.ecx.tcx()
+    }
+
+    fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
+        match *ty.kind() {
+            ty::Alias(..) if !ty.has_escaping_bound_vars() => {
+                let infer_ty = self.ecx.next_ty_infer();
+                let normalizes_to = ty::PredicateKind::AliasRelate(
+                    ty.into(),
+                    infer_ty.into(),
+                    AliasRelationDirection::Equate,
+                );
+                self.ecx.add_goal(
+                    GoalSource::Misc,
+                    Goal::new(self.interner(), self.param_env, normalizes_to),
+                );
+                infer_ty
+            }
+            _ => ty.super_fold_with(self),
+        }
+    }
+
+    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
+        match ct.kind() {
+            ty::ConstKind::Unevaluated(..) if !ct.has_escaping_bound_vars() => {
+                let infer_ct = self.ecx.next_const_infer(ct.ty());
+                let normalizes_to = ty::PredicateKind::AliasRelate(
+                    ct.into(),
+                    infer_ct.into(),
+                    AliasRelationDirection::Equate,
+                );
+                self.ecx.add_goal(
+                    GoalSource::Misc,
+                    Goal::new(self.interner(), self.param_env, normalizes_to),
+                );
+                infer_ct
+            }
+            _ => ct.super_fold_with(self),
+        }
+    }
+
+    fn fold_predicate(&mut self, predicate: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
+        if predicate.allow_normalization() { predicate.super_fold_with(self) } else { predicate }
+    }
+}
diff --git a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
index 447357f8b3f..1f27978e5a6 100644
--- a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
+++ b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
@@ -89,10 +89,8 @@ impl<'tcx> NormalizesToTermHack<'tcx> {
 pub struct InspectCandidate<'a, 'tcx> {
     goal: &'a InspectGoal<'a, 'tcx>,
     kind: inspect::ProbeKind<TyCtxt<'tcx>>,
-    nested_goals:
-        Vec<(GoalSource, inspect::CanonicalState<TyCtxt<'tcx>, Goal<'tcx, ty::Predicate<'tcx>>>)>,
+    steps: Vec<&'a inspect::ProbeStep<TyCtxt<'tcx>>>,
     final_state: inspect::CanonicalState<TyCtxt<'tcx>, ()>,
-    impl_args: Option<inspect::CanonicalState<TyCtxt<'tcx>, ty::GenericArgsRef<'tcx>>>,
     result: QueryResult<'tcx>,
     shallow_certainty: Certainty,
 }
@@ -148,7 +146,7 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
     #[instrument(
         level = "debug",
         skip_all,
-        fields(goal = ?self.goal.goal, nested_goals = ?self.nested_goals)
+        fields(goal = ?self.goal.goal, steps = ?self.steps)
     )]
     pub fn instantiate_nested_goals_and_opt_impl_args(
         &self,
@@ -157,22 +155,34 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
         let infcx = self.goal.infcx;
         let param_env = self.goal.goal.param_env;
         let mut orig_values = self.goal.orig_values.to_vec();
-        let instantiated_goals: Vec<_> = self
-            .nested_goals
-            .iter()
-            .map(|(source, goal)| {
-                (
-                    *source,
+
+        let mut instantiated_goals = vec![];
+        let mut opt_impl_args = None;
+        for step in &self.steps {
+            match **step {
+                inspect::ProbeStep::AddGoal(source, goal) => instantiated_goals.push((
+                    source,
                     canonical::instantiate_canonical_state(
                         infcx,
                         span,
                         param_env,
                         &mut orig_values,
-                        *goal,
+                        goal,
                     ),
-                )
-            })
-            .collect();
+                )),
+                inspect::ProbeStep::RecordImplArgs { impl_args } => {
+                    opt_impl_args = Some(canonical::instantiate_canonical_state(
+                        infcx,
+                        span,
+                        param_env,
+                        &mut orig_values,
+                        impl_args,
+                    ));
+                }
+                inspect::ProbeStep::MakeCanonicalResponse { .. }
+                | inspect::ProbeStep::NestedProbe(_) => unreachable!(),
+            }
+        }
 
         let () = canonical::instantiate_canonical_state(
             infcx,
@@ -182,17 +192,6 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
             self.final_state,
         );
 
-        let impl_args = self.impl_args.map(|impl_args| {
-            canonical::instantiate_canonical_state(
-                infcx,
-                span,
-                param_env,
-                &mut orig_values,
-                impl_args,
-            )
-            .fold_with(&mut EagerResolver::new(infcx))
-        });
-
         if let Some(term_hack) = self.goal.normalizes_to_term_hack {
             // FIXME: We ignore the expected term of `NormalizesTo` goals
             // when computing the result of its candidates. This is
@@ -200,6 +199,9 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
             let _ = term_hack.constrain(infcx, span, param_env);
         }
 
+        let opt_impl_args =
+            opt_impl_args.map(|impl_args| impl_args.fold_with(&mut EagerResolver::new(infcx)));
+
         let goals = instantiated_goals
             .into_iter()
             .map(|(source, goal)| match goal.predicate.kind().no_bound_vars() {
@@ -249,7 +251,7 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
             })
             .collect();
 
-        (goals, impl_args)
+        (goals, opt_impl_args)
     }
 
     /// Visit all nested goals of this candidate, rolling back
@@ -279,17 +281,18 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
     fn candidates_recur(
         &'a self,
         candidates: &mut Vec<InspectCandidate<'a, 'tcx>>,
-        nested_goals: &mut Vec<(
-            GoalSource,
-            inspect::CanonicalState<TyCtxt<'tcx>, Goal<'tcx, ty::Predicate<'tcx>>>,
-        )>,
-        probe: &inspect::Probe<TyCtxt<'tcx>>,
+        steps: &mut Vec<&'a inspect::ProbeStep<TyCtxt<'tcx>>>,
+        probe: &'a inspect::Probe<TyCtxt<'tcx>>,
     ) {
         let mut shallow_certainty = None;
-        let mut impl_args = None;
         for step in &probe.steps {
             match *step {
-                inspect::ProbeStep::AddGoal(source, goal) => nested_goals.push((source, goal)),
+                inspect::ProbeStep::AddGoal(..) | inspect::ProbeStep::RecordImplArgs { .. } => {
+                    steps.push(step)
+                }
+                inspect::ProbeStep::MakeCanonicalResponse { shallow_certainty: c } => {
+                    assert_eq!(shallow_certainty.replace(c), None);
+                }
                 inspect::ProbeStep::NestedProbe(ref probe) => {
                     match probe.kind {
                         // These never assemble candidates for the goal we're trying to solve.
@@ -305,18 +308,12 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
                             // Nested probes have to prove goals added in their parent
                             // but do not leak them, so we truncate the added goals
                             // afterwards.
-                            let num_goals = nested_goals.len();
-                            self.candidates_recur(candidates, nested_goals, probe);
-                            nested_goals.truncate(num_goals);
+                            let num_steps = steps.len();
+                            self.candidates_recur(candidates, steps, probe);
+                            steps.truncate(num_steps);
                         }
                     }
                 }
-                inspect::ProbeStep::MakeCanonicalResponse { shallow_certainty: c } => {
-                    assert_eq!(shallow_certainty.replace(c), None);
-                }
-                inspect::ProbeStep::RecordImplArgs { impl_args: i } => {
-                    assert_eq!(impl_args.replace(i), None);
-                }
             }
         }
 
@@ -338,11 +335,10 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
                     candidates.push(InspectCandidate {
                         goal: self,
                         kind: probe.kind,
-                        nested_goals: nested_goals.clone(),
+                        steps: steps.clone(),
                         final_state: probe.final_state,
-                        result,
                         shallow_certainty,
-                        impl_args,
+                        result,
                     });
                 }
             }