about summary refs log tree commit diff
path: root/compiler/rustc_infer/src/traits/util.rs
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2023-04-06 23:11:19 +0000
committerMichael Goulet <michael@errs.io>2023-04-06 23:30:22 +0000
commit758bedc10454b1d101613f4b363fe54c7a405b00 (patch)
tree8a1f8b0b5ebd9ffcc11cadb68d56d0c1b45adb2e /compiler/rustc_infer/src/traits/util.rs
parentde74dab880447f5227030b261dbd0f2bc4f32dba (diff)
downloadrust-758bedc10454b1d101613f4b363fe54c7a405b00.tar.gz
rust-758bedc10454b1d101613f4b363fe54c7a405b00.zip
Make elaborator generic
Diffstat (limited to 'compiler/rustc_infer/src/traits/util.rs')
-rw-r--r--compiler/rustc_infer/src/traits/util.rs209
1 files changed, 117 insertions, 92 deletions
diff --git a/compiler/rustc_infer/src/traits/util.rs b/compiler/rustc_infer/src/traits/util.rs
index f3797499866..c5095a76f42 100644
--- a/compiler/rustc_infer/src/traits/util.rs
+++ b/compiler/rustc_infer/src/traits/util.rs
@@ -1,7 +1,7 @@
 use smallvec::smallvec;
 
 use crate::infer::outlives::components::{push_outlives_components, Component};
-use crate::traits::{self, Obligation, ObligationCause, PredicateObligation};
+use crate::traits::{self, Obligation, PredicateObligation};
 use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
 use rustc_middle::ty::{self, ToPredicate, TyCtxt};
 use rustc_span::symbol::Ident;
@@ -66,99 +66,143 @@ impl<'tcx> Extend<ty::Predicate<'tcx>> for PredicateSet<'tcx> {
 /// if we know that `T: Ord`, the elaborator would deduce that `T: PartialOrd`
 /// holds as well. Similarly, if we have `trait Foo: 'static`, and we know that
 /// `T: Foo`, then we know that `T: 'static`.
-pub struct Elaborator<'tcx> {
-    stack: Vec<PredicateObligation<'tcx>>,
+pub struct Elaborator<'tcx, O> {
+    stack: Vec<O>,
     visited: PredicateSet<'tcx>,
 }
 
-pub fn elaborate_trait_ref<'tcx>(
-    tcx: TyCtxt<'tcx>,
-    trait_ref: ty::PolyTraitRef<'tcx>,
-) -> impl Iterator<Item = ty::Predicate<'tcx>> {
-    elaborate_predicates(tcx, std::iter::once(trait_ref.without_const().to_predicate(tcx)))
+/// Describes how to elaborate an obligation into a sub-obligation.
+///
+/// For [`Obligation`], a sub-obligation is combined with the current obligation's
+/// param-env and cause code. For [`ty::Predicate`], none of this is needed, since
+/// there is no param-env or cause code to copy over.
+pub trait Elaboratable<'tcx> {
+    fn predicate(&self) -> ty::Predicate<'tcx>;
+
+    // Makes a new `Self` but with a different predicate.
+    fn child(&self, predicate: ty::Predicate<'tcx>) -> Self;
+
+    // Makes a new `Self` but with a different predicate and a different cause
+    // code (if `Self` has one).
+    fn child_with_derived_cause(
+        &self,
+        predicate: ty::Predicate<'tcx>,
+        span: Span,
+        parent_trait_pred: ty::PolyTraitPredicate<'tcx>,
+        index: usize,
+    ) -> Self;
 }
 
-pub fn elaborate_trait_refs<'tcx>(
-    tcx: TyCtxt<'tcx>,
-    trait_refs: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
-) -> impl Iterator<Item = ty::Predicate<'tcx>> {
-    let predicates = trait_refs.map(move |trait_ref| trait_ref.without_const().to_predicate(tcx));
-    elaborate_predicates(tcx, predicates)
+impl<'tcx> Elaboratable<'tcx> for PredicateObligation<'tcx> {
+    fn predicate(&self) -> ty::Predicate<'tcx> {
+        self.predicate
+    }
+
+    fn child(&self, predicate: ty::Predicate<'tcx>) -> Self {
+        Obligation {
+            cause: self.cause.clone(),
+            param_env: self.param_env,
+            recursion_depth: 0,
+            predicate,
+        }
+    }
+
+    fn child_with_derived_cause(
+        &self,
+        predicate: ty::Predicate<'tcx>,
+        span: Span,
+        parent_trait_pred: ty::PolyTraitPredicate<'tcx>,
+        index: usize,
+    ) -> Self {
+        let cause = self.cause.clone().derived_cause(parent_trait_pred, |derived| {
+            traits::ImplDerivedObligation(Box::new(traits::ImplDerivedObligationCause {
+                derived,
+                impl_or_alias_def_id: parent_trait_pred.def_id(),
+                impl_def_predicate_index: Some(index),
+                span,
+            }))
+        });
+        Obligation { cause, param_env: self.param_env, recursion_depth: 0, predicate }
+    }
 }
 
-pub fn elaborate_predicates<'tcx>(
+impl<'tcx> Elaboratable<'tcx> for ty::Predicate<'tcx> {
+    fn predicate(&self) -> ty::Predicate<'tcx> {
+        *self
+    }
+
+    fn child(&self, predicate: ty::Predicate<'tcx>) -> Self {
+        predicate
+    }
+
+    fn child_with_derived_cause(
+        &self,
+        predicate: ty::Predicate<'tcx>,
+        _span: Span,
+        _parent_trait_pred: ty::PolyTraitPredicate<'tcx>,
+        _index: usize,
+    ) -> Self {
+        predicate
+    }
+}
+
+impl<'tcx> Elaboratable<'tcx> for (ty::Predicate<'tcx>, Span) {
+    fn predicate(&self) -> ty::Predicate<'tcx> {
+        self.0
+    }
+
+    fn child(&self, predicate: ty::Predicate<'tcx>) -> Self {
+        (predicate, self.1)
+    }
+
+    fn child_with_derived_cause(
+        &self,
+        predicate: ty::Predicate<'tcx>,
+        _span: Span,
+        _parent_trait_pred: ty::PolyTraitPredicate<'tcx>,
+        _index: usize,
+    ) -> Self {
+        (predicate, self.1)
+    }
+}
+
+pub fn elaborate_trait_ref<'tcx>(
     tcx: TyCtxt<'tcx>,
-    predicates: impl Iterator<Item = ty::Predicate<'tcx>>,
-) -> impl Iterator<Item = ty::Predicate<'tcx>> {
-    elaborate_obligations(
-        tcx,
-        predicates
-            .map(|predicate| {
-                Obligation::new(
-                    tcx,
-                    // We'll dump the cause/param-env later
-                    ObligationCause::dummy(),
-                    ty::ParamEnv::empty(),
-                    predicate,
-                )
-            })
-            .collect(),
-    )
-    .map(|obl| obl.predicate)
+    trait_ref: ty::PolyTraitRef<'tcx>,
+) -> Elaborator<'tcx, ty::Predicate<'tcx>> {
+    elaborate(tcx, std::iter::once(trait_ref.without_const().to_predicate(tcx)))
 }
 
-pub fn elaborate_predicates_with_span<'tcx>(
+pub fn elaborate_trait_refs<'tcx>(
     tcx: TyCtxt<'tcx>,
-    predicates: impl Iterator<Item = (ty::Predicate<'tcx>, Span)>,
-) -> impl Iterator<Item = (ty::Predicate<'tcx>, Span)> {
-    elaborate_obligations(
-        tcx,
-        predicates
-            .map(|(predicate, span)| {
-                Obligation::new(
-                    tcx,
-                    // We'll dump the cause/param-env later
-                    ObligationCause::dummy_with_span(span),
-                    ty::ParamEnv::empty(),
-                    predicate,
-                )
-            })
-            .collect(),
-    )
-    .map(|obl| (obl.predicate, obl.cause.span))
+    trait_refs: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
+) -> Elaborator<'tcx, ty::Predicate<'tcx>> {
+    elaborate(tcx, trait_refs.map(|trait_ref| trait_ref.to_predicate(tcx)))
 }
 
-pub fn elaborate_obligations<'tcx>(
+pub fn elaborate<'tcx, O: Elaboratable<'tcx>>(
     tcx: TyCtxt<'tcx>,
-    obligations: Vec<PredicateObligation<'tcx>>,
-) -> Elaborator<'tcx> {
+    obligations: impl IntoIterator<Item = O>,
+) -> Elaborator<'tcx, O> {
     let mut elaborator = Elaborator { stack: Vec::new(), visited: PredicateSet::new(tcx) };
     elaborator.extend_deduped(obligations);
     elaborator
 }
 
-fn predicate_obligation<'tcx>(
-    predicate: ty::Predicate<'tcx>,
-    param_env: ty::ParamEnv<'tcx>,
-    cause: ObligationCause<'tcx>,
-) -> PredicateObligation<'tcx> {
-    Obligation { cause, param_env, recursion_depth: 0, predicate }
-}
-
-impl<'tcx> Elaborator<'tcx> {
-    fn extend_deduped(&mut self, obligations: impl IntoIterator<Item = PredicateObligation<'tcx>>) {
+impl<'tcx, O: Elaboratable<'tcx>> Elaborator<'tcx, O> {
+    fn extend_deduped(&mut self, obligations: impl IntoIterator<Item = O>) {
         // Only keep those bounds that we haven't already seen.
         // This is necessary to prevent infinite recursion in some
         // cases. One common case is when people define
         // `trait Sized: Sized { }` rather than `trait Sized { }`.
         // let visited = &mut self.visited;
-        self.stack.extend(obligations.into_iter().filter(|o| self.visited.insert(o.predicate)));
+        self.stack.extend(obligations.into_iter().filter(|o| self.visited.insert(o.predicate())));
     }
 
-    fn elaborate(&mut self, obligation: &PredicateObligation<'tcx>) {
+    fn elaborate(&mut self, elaboratable: &O) {
         let tcx = self.visited.tcx;
 
-        let bound_predicate = obligation.predicate.kind();
+        let bound_predicate = elaboratable.predicate().kind();
         match bound_predicate.skip_binder() {
             ty::PredicateKind::Clause(ty::Clause::Trait(data)) => {
                 // Get predicates declared on the trait.
@@ -170,24 +214,11 @@ impl<'tcx> Elaborator<'tcx> {
                         if data.constness == ty::BoundConstness::NotConst {
                             pred = pred.without_const(tcx);
                         }
-
-                        let cause = obligation.cause.clone().derived_cause(
-                            bound_predicate.rebind(data),
-                            |derived| {
-                                traits::ImplDerivedObligation(Box::new(
-                                    traits::ImplDerivedObligationCause {
-                                        derived,
-                                        impl_or_alias_def_id: data.def_id(),
-                                        impl_def_predicate_index: Some(index),
-                                        span,
-                                    },
-                                ))
-                            },
-                        );
-                        predicate_obligation(
+                        elaboratable.child_with_derived_cause(
                             pred.subst_supertrait(tcx, &bound_predicate.rebind(data.trait_ref)),
-                            obligation.param_env,
-                            cause,
+                            span,
+                            bound_predicate.rebind(data),
+                            index,
                         )
                     });
                 debug!(?data, ?obligations, "super_predicates");
@@ -290,13 +321,7 @@ impl<'tcx> Elaborator<'tcx> {
                         .map(|predicate_kind| {
                             bound_predicate.rebind(predicate_kind).to_predicate(tcx)
                         })
-                        .map(|predicate| {
-                            predicate_obligation(
-                                predicate,
-                                obligation.param_env,
-                                obligation.cause.clone(),
-                            )
-                        }),
+                        .map(|predicate| elaboratable.child(predicate)),
                 );
             }
             ty::PredicateKind::TypeWellFormedFromEnv(..) => {
@@ -313,8 +338,8 @@ impl<'tcx> Elaborator<'tcx> {
     }
 }
 
-impl<'tcx> Iterator for Elaborator<'tcx> {
-    type Item = PredicateObligation<'tcx>;
+impl<'tcx, O: Elaboratable<'tcx>> Iterator for Elaborator<'tcx, O> {
+    type Item = O;
 
     fn size_hint(&self) -> (usize, Option<usize>) {
         (self.stack.len(), None)