about summary refs log tree commit diff
diff options
context:
space:
mode:
authorSantiago Pastorino <spastorino@gmail.com>2020-11-13 14:01:16 -0300
committerSantiago Pastorino <spastorino@gmail.com>2020-11-27 11:23:47 -0300
commit2ca4964db5d263a8f9222846bd70a7f26cf414cf (patch)
tree29aa746f704a7e8b8a45a63ed1832b17c11b40e2
parent24dcf6f7a29d7577a3c3448046d2d48b2fee59de (diff)
downloadrust-2ca4964db5d263a8f9222846bd70a7f26cf414cf.tar.gz
rust-2ca4964db5d263a8f9222846bd70a7f26cf414cf.zip
Allow to self reference associated types in where clauses
-rw-r--r--compiler/rustc_infer/src/traits/util.rs35
-rw-r--r--compiler/rustc_middle/src/query/mod.rs10
-rw-r--r--compiler/rustc_middle/src/ty/query/keys.rs11
-rw-r--r--compiler/rustc_trait_selection/src/traits/mod.rs3
-rw-r--r--compiler/rustc_typeck/src/astconv/mod.rs49
-rw-r--r--compiler/rustc_typeck/src/collect.rs141
-rw-r--r--compiler/rustc_typeck/src/collect/item_bounds.rs6
-rw-r--r--src/test/ui/associated-type-bounds/super-trait-referencing-self.rs12
8 files changed, 204 insertions, 63 deletions
diff --git a/compiler/rustc_infer/src/traits/util.rs b/compiler/rustc_infer/src/traits/util.rs
index 8273c2d291d..26f6fcd5dab 100644
--- a/compiler/rustc_infer/src/traits/util.rs
+++ b/compiler/rustc_infer/src/traits/util.rs
@@ -4,6 +4,7 @@ use crate::traits::{Obligation, ObligationCause, PredicateObligation};
 use rustc_data_structures::fx::FxHashSet;
 use rustc_middle::ty::outlives::Component;
 use rustc_middle::ty::{self, ToPredicate, TyCtxt, WithConstness};
+use rustc_span::symbol::Ident;
 
 pub fn anonymize_predicate<'tcx>(
     tcx: TyCtxt<'tcx>,
@@ -89,6 +90,32 @@ pub fn elaborate_trait_refs<'tcx>(
     elaborate_predicates(tcx, predicates)
 }
 
+pub fn elaborate_trait_refs_that_define_assoc_type<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    trait_refs: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
+    assoc_name: Ident,
+) -> FxHashSet<ty::PolyTraitRef<'tcx>> {
+    let mut stack: Vec<_> = trait_refs.collect();
+    let mut trait_refs = FxHashSet::default();
+
+    while let Some(trait_ref) = stack.pop() {
+        if trait_refs.insert(trait_ref) {
+            let super_predicates =
+                tcx.super_predicates_that_define_assoc_type((trait_ref.def_id(), Some(assoc_name)));
+            for (super_predicate, _) in super_predicates.predicates {
+                let bound_predicate = super_predicate.bound_atom();
+                let subst_predicate = super_predicate
+                    .subst_supertrait(tcx, &bound_predicate.rebind(trait_ref.skip_binder()));
+                if let Some(binder) = subst_predicate.to_opt_poly_trait_ref() {
+                    stack.push(binder.value);
+                }
+            }
+        }
+    }
+
+    trait_refs
+}
+
 pub fn elaborate_predicates<'tcx>(
     tcx: TyCtxt<'tcx>,
     predicates: impl Iterator<Item = ty::Predicate<'tcx>>,
@@ -287,6 +314,14 @@ pub fn transitive_bounds<'tcx>(
     elaborate_trait_refs(tcx, bounds).filter_to_traits()
 }
 
+pub fn transitive_bounds_that_define_assoc_type<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
+    assoc_name: Ident,
+) -> FxHashSet<ty::PolyTraitRef<'tcx>> {
+    elaborate_trait_refs_that_define_assoc_type(tcx, bounds, assoc_name)
+}
+
 ///////////////////////////////////////////////////////////////////////////
 // Other
 ///////////////////////////////////////////////////////////////////////////
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index ed032220b54..c19943e1f8c 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -436,6 +436,16 @@ rustc_queries! {
             desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key) }
         }
 
+        /// Maps from the `DefId` of a trait to the list of
+        /// super-predicates. This is a subset of the full list of
+        /// predicates. We store these in a separate map because we must
+        /// evaluate them even during type conversion, often before the
+        /// full predicates are available (note that supertraits have
+        /// additional acyclicity requirements).
+        query super_predicates_that_define_assoc_type(key: (DefId, Option<rustc_span::symbol::Ident>)) -> ty::GenericPredicates<'tcx> {
+            desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key.0) }
+        }
+
         /// To avoid cycles within the predicates of a single item we compute
         /// per-type-parameter predicates for resolving `T::AssocTy`.
         query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
diff --git a/compiler/rustc_middle/src/ty/query/keys.rs b/compiler/rustc_middle/src/ty/query/keys.rs
index 339a068205c..3949c303f72 100644
--- a/compiler/rustc_middle/src/ty/query/keys.rs
+++ b/compiler/rustc_middle/src/ty/query/keys.rs
@@ -149,6 +149,17 @@ impl Key for (LocalDefId, DefId) {
     }
 }
 
+impl Key for (DefId, Option<Ident>) {
+    type CacheSelector = DefaultCacheSelector;
+
+    fn query_crate(&self) -> CrateNum {
+        self.0.krate
+    }
+    fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
+        tcx.def_span(self.0)
+    }
+}
+
 impl Key for (DefId, LocalDefId, Ident) {
     type CacheSelector = DefaultCacheSelector;
 
diff --git a/compiler/rustc_trait_selection/src/traits/mod.rs b/compiler/rustc_trait_selection/src/traits/mod.rs
index 2d7df2ddd11..d0276288284 100644
--- a/compiler/rustc_trait_selection/src/traits/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/mod.rs
@@ -65,7 +65,8 @@ pub use self::util::{
     get_vtable_index_of_object_method, impl_item_is_final, predicate_for_trait_def, upcast_choices,
 };
 pub use self::util::{
-    supertrait_def_ids, supertraits, transitive_bounds, SupertraitDefIds, Supertraits,
+    supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_type,
+    SupertraitDefIds, Supertraits,
 };
 
 pub use self::chalk_fulfill::FulfillmentContext as ChalkFulfillmentContext;
diff --git a/compiler/rustc_typeck/src/astconv/mod.rs b/compiler/rustc_typeck/src/astconv/mod.rs
index e891ea3403f..d8fbcb633cd 100644
--- a/compiler/rustc_typeck/src/astconv/mod.rs
+++ b/compiler/rustc_typeck/src/astconv/mod.rs
@@ -6,6 +6,7 @@ mod errors;
 mod generics;
 
 use crate::bounds::Bounds;
+use crate::collect::super_traits_of;
 use crate::collect::PlaceholderHirTyCollector;
 use crate::errors::{
     AmbiguousLifetimeBound, MultipleRelaxedDefaultBounds, TraitObjectDeclaredWithNoTraits,
@@ -768,7 +769,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
     }
 
     // Returns `true` if a bounds list includes `?Sized`.
-    pub fn is_unsized(&self, ast_bounds: &[hir::GenericBound<'_>], span: Span) -> bool {
+    pub fn is_unsized(&self, ast_bounds: &[&hir::GenericBound<'_>], span: Span) -> bool {
         let tcx = self.tcx();
 
         // Try to find an unbound in bounds.
@@ -826,7 +827,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
     fn add_bounds(
         &self,
         param_ty: Ty<'tcx>,
-        ast_bounds: &[hir::GenericBound<'_>],
+        ast_bounds: &[&hir::GenericBound<'_>],
         bounds: &mut Bounds<'tcx>,
     ) {
         let mut trait_bounds = Vec::new();
@@ -844,7 +845,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
                 hir::GenericBound::Trait(_, hir::TraitBoundModifier::Maybe) => {}
                 hir::GenericBound::LangItemTrait(lang_item, span, hir_id, args) => self
                     .instantiate_lang_item_trait_ref(
-                        lang_item, span, hir_id, args, param_ty, bounds,
+                        *lang_item, *span, *hir_id, args, param_ty, bounds,
                     ),
                 hir::GenericBound::Outlives(ref l) => region_bounds.push(l),
             }
@@ -878,7 +879,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
     pub fn compute_bounds(
         &self,
         param_ty: Ty<'tcx>,
-        ast_bounds: &[hir::GenericBound<'_>],
+        ast_bounds: &[&hir::GenericBound<'_>],
         sized_by_default: SizedByDefault,
         span: Span,
     ) -> Bounds<'tcx> {
@@ -896,6 +897,39 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
         bounds
     }
 
+    pub fn compute_bounds_that_match_assoc_type(
+        &self,
+        param_ty: Ty<'tcx>,
+        ast_bounds: &[hir::GenericBound<'_>],
+        sized_by_default: SizedByDefault,
+        span: Span,
+        assoc_name: Ident,
+    ) -> Bounds<'tcx> {
+        let mut result = Vec::new();
+
+        for ast_bound in ast_bounds {
+            if let Some(trait_ref) = ast_bound.trait_ref() {
+                if let Some(trait_did) = trait_ref.trait_def_id() {
+                    if super_traits_of(self.tcx(), trait_did).any(|trait_did| {
+                        self.tcx()
+                            .associated_items(trait_did)
+                            .find_by_name_and_kind(
+                                self.tcx(),
+                                assoc_name,
+                                ty::AssocKind::Type,
+                                trait_did,
+                            )
+                            .is_some()
+                    }) {
+                        result.push(ast_bound);
+                    }
+                }
+            }
+        }
+
+        self.compute_bounds(param_ty, &result, sized_by_default, span)
+    }
+
     /// Given an HIR binding like `Item = Foo` or `Item: Foo`, pushes the corresponding predicates
     /// onto `bounds`.
     ///
@@ -1050,7 +1084,8 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
                 // Calling `skip_binder` is okay, because `add_bounds` expects the `param_ty`
                 // parameter to have a skipped binder.
                 let param_ty = tcx.mk_projection(assoc_ty.def_id, candidate.skip_binder().substs);
-                self.add_bounds(param_ty, ast_bounds, bounds);
+                let ast_bounds: Vec<_> = ast_bounds.iter().collect();
+                self.add_bounds(param_ty, &ast_bounds, bounds);
             }
         }
         Ok(())
@@ -1377,12 +1412,14 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
         let param_name = tcx.hir().ty_param_name(param_hir_id);
         self.one_bound_for_assoc_type(
             || {
-                traits::transitive_bounds(
+                traits::transitive_bounds_that_define_assoc_type(
                     tcx,
                     predicates.iter().filter_map(|(p, _)| {
                         p.to_opt_poly_trait_ref().map(|trait_ref| trait_ref.value)
                     }),
+                    assoc_name,
                 )
+                .into_iter()
             },
             || param_name.to_string(),
             assoc_name,
diff --git a/compiler/rustc_typeck/src/collect.rs b/compiler/rustc_typeck/src/collect.rs
index 756147ca54c..8b457c7ceec 100644
--- a/compiler/rustc_typeck/src/collect.rs
+++ b/compiler/rustc_typeck/src/collect.rs
@@ -1,3 +1,4 @@
+// ignore-tidy-filelength
 //! "Collection" is the process of determining the type and other external
 //! details of each item in Rust. Collection is specifically concerned
 //! with *inter-procedural* things -- for example, for a function
@@ -79,6 +80,7 @@ pub fn provide(providers: &mut Providers) {
         projection_ty_from_predicates,
         explicit_predicates_of,
         super_predicates_of,
+        super_predicates_that_define_assoc_type,
         trait_explicit_predicates_and_bounds,
         type_param_predicates,
         trait_def,
@@ -651,17 +653,10 @@ impl ItemCtxt<'tcx> {
             hir::GenericBound::Trait(poly_trait_ref, _) => {
                 let trait_ref = &poly_trait_ref.trait_ref;
                 let trait_did = trait_ref.trait_def_id().unwrap();
-                let traits_did = super_traits_of(self.tcx, trait_did);
-
-                traits_did.iter().any(|trait_did| {
+                super_traits_of(self.tcx, trait_did).any(|trait_did| {
                     self.tcx
-                        .associated_items(*trait_did)
-                        .find_by_name_and_kind(
-                            self.tcx,
-                            assoc_name,
-                            ty::AssocKind::Type,
-                            *trait_did,
-                        )
+                        .associated_items(trait_did)
+                        .find_by_name_and_kind(self.tcx, assoc_name, ty::AssocKind::Type, trait_did)
                         .is_some()
                 })
             }
@@ -1035,55 +1030,91 @@ fn adt_def(tcx: TyCtxt<'_>, def_id: DefId) -> &ty::AdtDef {
 /// the transitive super-predicates are converted.
 fn super_predicates_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> ty::GenericPredicates<'_> {
     debug!("super_predicates(trait_def_id={:?})", trait_def_id);
-    let trait_hir_id = tcx.hir().local_def_id_to_hir_id(trait_def_id.expect_local());
+    tcx.super_predicates_that_define_assoc_type((trait_def_id, None))
+}
 
-    let item = match tcx.hir().get(trait_hir_id) {
-        Node::Item(item) => item,
-        _ => bug!("trait_node_id {} is not an item", trait_hir_id),
-    };
+/// Ensures that the super-predicates of the trait with a `DefId`
+/// of `trait_def_id` are converted and stored. This also ensures that
+/// the transitive super-predicates are converted.
+fn super_predicates_that_define_assoc_type(
+    tcx: TyCtxt<'_>,
+    (trait_def_id, assoc_name): (DefId, Option<Ident>),
+) -> ty::GenericPredicates<'_> {
+    debug!(
+        "super_predicates_that_define_assoc_type(trait_def_id={:?}, assoc_name={:?})",
+        trait_def_id, assoc_name
+    );
+    if trait_def_id.is_local() {
+        debug!("super_predicates_that_define_assoc_type: local trait_def_id={:?}", trait_def_id);
+        let trait_hir_id = tcx.hir().local_def_id_to_hir_id(trait_def_id.expect_local());
 
-    let (generics, bounds) = match item.kind {
-        hir::ItemKind::Trait(.., ref generics, ref supertraits, _) => (generics, supertraits),
-        hir::ItemKind::TraitAlias(ref generics, ref supertraits) => (generics, supertraits),
-        _ => span_bug!(item.span, "super_predicates invoked on non-trait"),
-    };
+        let item = match tcx.hir().get(trait_hir_id) {
+            Node::Item(item) => item,
+            _ => bug!("trait_node_id {} is not an item", trait_hir_id),
+        };
 
-    let icx = ItemCtxt::new(tcx, trait_def_id);
-
-    // Convert the bounds that follow the colon, e.g., `Bar + Zed` in `trait Foo: Bar + Zed`.
-    let self_param_ty = tcx.types.self_param;
-    let superbounds1 =
-        AstConv::compute_bounds(&icx, self_param_ty, bounds, SizedByDefault::No, item.span);
-
-    let superbounds1 = superbounds1.predicates(tcx, self_param_ty);
-
-    // Convert any explicit superbounds in the where-clause,
-    // e.g., `trait Foo where Self: Bar`.
-    // In the case of trait aliases, however, we include all bounds in the where-clause,
-    // so e.g., `trait Foo = where u32: PartialEq<Self>` would include `u32: PartialEq<Self>`
-    // as one of its "superpredicates".
-    let is_trait_alias = tcx.is_trait_alias(trait_def_id);
-    let superbounds2 = icx.type_parameter_bounds_in_generics(
-        generics,
-        item.hir_id,
-        self_param_ty,
-        OnlySelfBounds(!is_trait_alias),
-        None,
-    );
+        let (generics, bounds) = match item.kind {
+            hir::ItemKind::Trait(.., ref generics, ref supertraits, _) => (generics, supertraits),
+            hir::ItemKind::TraitAlias(ref generics, ref supertraits) => (generics, supertraits),
+            _ => span_bug!(item.span, "super_predicates invoked on non-trait"),
+        };
 
-    // Combine the two lists to form the complete set of superbounds:
-    let superbounds = &*tcx.arena.alloc_from_iter(superbounds1.into_iter().chain(superbounds2));
+        let icx = ItemCtxt::new(tcx, trait_def_id);
 
-    // Now require that immediate supertraits are converted,
-    // which will, in turn, reach indirect supertraits.
-    for &(pred, span) in superbounds {
-        debug!("superbound: {:?}", pred);
-        if let ty::PredicateAtom::Trait(bound, _) = pred.skip_binders() {
-            tcx.at(span).super_predicates_of(bound.def_id());
+        // Convert the bounds that follow the colon, e.g., `Bar + Zed` in `trait Foo: Bar + Zed`.
+        let self_param_ty = tcx.types.self_param;
+        let superbounds1 = if let Some(assoc_name) = assoc_name {
+            AstConv::compute_bounds_that_match_assoc_type(
+                &icx,
+                self_param_ty,
+                &bounds,
+                SizedByDefault::No,
+                item.span,
+                assoc_name,
+            )
+        } else {
+            let bounds: Vec<_> = bounds.iter().collect();
+            AstConv::compute_bounds(&icx, self_param_ty, &bounds, SizedByDefault::No, item.span)
+        };
+
+        let superbounds1 = superbounds1.predicates(tcx, self_param_ty);
+
+        // Convert any explicit superbounds in the where-clause,
+        // e.g., `trait Foo where Self: Bar`.
+        // In the case of trait aliases, however, we include all bounds in the where-clause,
+        // so e.g., `trait Foo = where u32: PartialEq<Self>` would include `u32: PartialEq<Self>`
+        // as one of its "superpredicates".
+        let is_trait_alias = tcx.is_trait_alias(trait_def_id);
+        let superbounds2 = icx.type_parameter_bounds_in_generics(
+            generics,
+            item.hir_id,
+            self_param_ty,
+            OnlySelfBounds(!is_trait_alias),
+            assoc_name,
+        );
+
+        // Combine the two lists to form the complete set of superbounds:
+        let superbounds = &*tcx.arena.alloc_from_iter(superbounds1.into_iter().chain(superbounds2));
+
+        // Now require that immediate supertraits are converted,
+        // which will, in turn, reach indirect supertraits.
+        if assoc_name.is_none() {
+            // FIXME: move this into the `super_predicates_of` query
+            for &(pred, span) in superbounds {
+                debug!("superbound: {:?}", pred);
+                if let ty::PredicateAtom::Trait(bound, _) = pred.skip_binders() {
+                    tcx.at(span).super_predicates_of(bound.def_id());
+                }
+            }
         }
-    }
 
-    ty::GenericPredicates { parent: None, predicates: superbounds }
+        ty::GenericPredicates { parent: None, predicates: superbounds }
+    } else {
+        // if `assoc_name` is None, then the query should've been redirected to an
+        // external provider
+        assert!(assoc_name.is_some());
+        tcx.super_predicates_of(trait_def_id)
+    }
 }
 
 pub fn super_traits_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> impl Iterator<Item = DefId> {
@@ -1123,6 +1154,8 @@ pub fn super_traits_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> impl Iterator<It
             }
         }
     }
+
+    set.into_iter()
 }
 
 fn trait_def(tcx: TyCtxt<'_>, def_id: DefId) -> ty::TraitDef {
@@ -1976,8 +2009,8 @@ fn gather_explicit_predicates_of(tcx: TyCtxt<'_>, def_id: DefId) -> ty::GenericP
                 index += 1;
 
                 let sized = SizedByDefault::Yes;
-                let bounds =
-                    AstConv::compute_bounds(&icx, param_ty, &param.bounds, sized, param.span);
+                let bounds: Vec<_> = param.bounds.iter().collect();
+                let bounds = AstConv::compute_bounds(&icx, param_ty, &bounds, sized, param.span);
                 predicates.extend(bounds.predicates(tcx, param_ty));
             }
             GenericParamKind::Const { .. } => {
diff --git a/compiler/rustc_typeck/src/collect/item_bounds.rs b/compiler/rustc_typeck/src/collect/item_bounds.rs
index e596dd1a396..62586d793b4 100644
--- a/compiler/rustc_typeck/src/collect/item_bounds.rs
+++ b/compiler/rustc_typeck/src/collect/item_bounds.rs
@@ -25,10 +25,11 @@ fn associated_type_bounds<'tcx>(
         InternalSubsts::identity_for_item(tcx, assoc_item_def_id),
     );
 
+    let bounds: Vec<_> = bounds.iter().collect();
     let bounds = AstConv::compute_bounds(
         &ItemCtxt::new(tcx, assoc_item_def_id),
         item_ty,
-        bounds,
+        &bounds,
         SizedByDefault::Yes,
         span,
     );
@@ -65,10 +66,11 @@ fn opaque_type_bounds<'tcx>(
         let item_ty =
             tcx.mk_opaque(opaque_def_id, InternalSubsts::identity_for_item(tcx, opaque_def_id));
 
+        let bounds: Vec<_> = bounds.iter().collect();
         let bounds = AstConv::compute_bounds(
             &ItemCtxt::new(tcx, opaque_def_id),
             item_ty,
-            bounds,
+            &bounds,
             SizedByDefault::Yes,
             span,
         )
diff --git a/src/test/ui/associated-type-bounds/super-trait-referencing-self.rs b/src/test/ui/associated-type-bounds/super-trait-referencing-self.rs
new file mode 100644
index 00000000000..c82ec01f4d6
--- /dev/null
+++ b/src/test/ui/associated-type-bounds/super-trait-referencing-self.rs
@@ -0,0 +1,12 @@
+// check-pass
+trait Foo {
+    type Bar;
+}
+trait Qux: Foo + AsRef<Self::Bar> {}
+trait Foo2 {}
+
+trait Qux2: Foo2 + AsRef<Self::Bar> {
+    type Bar;
+}
+
+fn main() {}