about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_trait_selection/src/traits/select/mod.rs75
1 files changed, 51 insertions, 24 deletions
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index 6bf46bb9a1c..daa3f2775e0 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -662,32 +662,59 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
                         tcx.features().generic_const_exprs,
                         "`ConstEquate` without a feature gate: {c1:?} {c2:?}",
                     );
-                    debug!(?c1, ?c2, "evaluate_predicate_recursively: equating consts");
 
-                    // FIXME: we probably should only try to unify abstract constants
-                    // if the constants depend on generic parameters.
-                    //
-                    // Let's just see where this breaks :shrug:
-                    if let (ty::ConstKind::Unevaluated(_), ty::ConstKind::Unevaluated(_)) =
-                        (c1.kind(), c2.kind())
                     {
-                        if let Ok(Some(a)) = tcx.expand_abstract_consts(c1)
-                            && let Ok(Some(b)) = tcx.expand_abstract_consts(c2)
-                            && a.ty() == b.ty() 
-                            && let Ok(new_obligations) = self
-                                .infcx
-                                .at(&obligation.cause, obligation.param_env)
-                                .eq(a, b)
-                        {
-                            let mut obligations = new_obligations.obligations;
-                            self.add_depth(
-                                obligations.iter_mut(),
-                                obligation.recursion_depth,
-                            );
-                            return self.evaluate_predicates_recursively(
-                                previous_stack,
-                                obligations.into_iter(),
-                            );
+                        let c1 =
+                            if let Ok(Some(a)) = tcx.expand_abstract_consts(c1) { a } else { c1 };
+                        let c2 =
+                            if let Ok(Some(b)) = tcx.expand_abstract_consts(c2) { b } else { c2 };
+                        debug!(
+                            "evalaute_predicate_recursively: equating consts:\nc1= {:?}\nc2= {:?}",
+                            c1, c2
+                        );
+
+                        use rustc_hir::def::DefKind;
+                        use ty::ConstKind::Unevaluated;
+                        match (c1.kind(), c2.kind()) {
+                            (Unevaluated(a), Unevaluated(b))
+                                if a.def.did == b.def.did
+                                    && tcx.def_kind(a.def.did) == DefKind::AssocConst =>
+                            {
+                                if let Ok(new_obligations) = self
+                                    .infcx
+                                    .at(&obligation.cause, obligation.param_env)
+                                    .trace(c1, c2)
+                                    .eq(a.substs, b.substs)
+                                {
+                                    let mut obligations = new_obligations.obligations;
+                                    self.add_depth(
+                                        obligations.iter_mut(),
+                                        obligation.recursion_depth,
+                                    );
+                                    return self.evaluate_predicates_recursively(
+                                        previous_stack,
+                                        obligations.into_iter(),
+                                    );
+                                }
+                            }
+                            (_, Unevaluated(_)) | (Unevaluated(_), _) => (),
+                            (_, _) => {
+                                if let Ok(new_obligations) = self
+                                    .infcx
+                                    .at(&obligation.cause, obligation.param_env)
+                                    .eq(c1, c2)
+                                {
+                                    let mut obligations = new_obligations.obligations;
+                                    self.add_depth(
+                                        obligations.iter_mut(),
+                                        obligation.recursion_depth,
+                                    );
+                                    return self.evaluate_predicates_recursively(
+                                        previous_stack,
+                                        obligations.into_iter(),
+                                    );
+                                }
+                            }
                         }
                     }