about summary refs log tree commit diff
diff options
context:
space:
mode:
authorSantiago Pastorino <spastorino@gmail.com>2020-11-03 17:07:18 -0300
committerSantiago Pastorino <spastorino@gmail.com>2020-11-27 11:23:47 -0300
commit24dcf6f7a29d7577a3c3448046d2d48b2fee59de (patch)
treede5838a155bbb04e5e9a23f3ddeb1a19a1a6404b
parent361543d776d832b42f022f5b3aa1ab77263bc4a9 (diff)
downloadrust-24dcf6f7a29d7577a3c3448046d2d48b2fee59de.tar.gz
rust-24dcf6f7a29d7577a3c3448046d2d48b2fee59de.zip
Allow to use super trait bounds in where clauses
-rw-r--r--compiler/rustc_middle/src/query/mod.rs2
-rw-r--r--compiler/rustc_middle/src/ty/query/keys.rs13
-rw-r--r--compiler/rustc_typeck/src/astconv/mod.rs19
-rw-r--r--compiler/rustc_typeck/src/check/fn_ctxt/mod.rs28
-rw-r--r--compiler/rustc_typeck/src/collect.rs113
-rw-r--r--src/test/ui/associated-type-bounds/super-trait-referencing.rs15
6 files changed, 167 insertions, 23 deletions
diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs
index 7822ecc2c1f..ed032220b54 100644
--- a/compiler/rustc_middle/src/query/mod.rs
+++ b/compiler/rustc_middle/src/query/mod.rs
@@ -438,7 +438,7 @@ rustc_queries! {
 
         /// To avoid cycles within the predicates of a single item we compute
         /// per-type-parameter predicates for resolving `T::AssocTy`.
-        query type_param_predicates(key: (DefId, LocalDefId)) -> ty::GenericPredicates<'tcx> {
+        query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
             desc { |tcx| "computing the bounds for type parameter `{}`", {
                 let id = tcx.hir().local_def_id_to_hir_id(key.1);
                 tcx.hir().ty_param_name(id)
diff --git a/compiler/rustc_middle/src/ty/query/keys.rs b/compiler/rustc_middle/src/ty/query/keys.rs
index a005990264c..339a068205c 100644
--- a/compiler/rustc_middle/src/ty/query/keys.rs
+++ b/compiler/rustc_middle/src/ty/query/keys.rs
@@ -7,7 +7,7 @@ use crate::ty::subst::{GenericArg, SubstsRef};
 use crate::ty::{self, Ty, TyCtxt};
 use rustc_hir::def_id::{CrateNum, DefId, LocalDefId, LOCAL_CRATE};
 use rustc_query_system::query::DefaultCacheSelector;
-use rustc_span::symbol::Symbol;
+use rustc_span::symbol::{Ident, Symbol};
 use rustc_span::{Span, DUMMY_SP};
 
 /// The `Key` trait controls what types can legally be used as the key
@@ -149,6 +149,17 @@ impl Key for (LocalDefId, DefId) {
     }
 }
 
+impl Key for (DefId, LocalDefId, Ident) {
+    type CacheSelector = DefaultCacheSelector;
+
+    fn query_crate(&self) -> CrateNum {
+        self.0.krate
+    }
+    fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
+        self.1.default_span(tcx)
+    }
+}
+
 impl Key for (CrateNum, DefId) {
     type CacheSelector = DefaultCacheSelector;
 
diff --git a/compiler/rustc_typeck/src/astconv/mod.rs b/compiler/rustc_typeck/src/astconv/mod.rs
index 2f64597a510..e891ea3403f 100644
--- a/compiler/rustc_typeck/src/astconv/mod.rs
+++ b/compiler/rustc_typeck/src/astconv/mod.rs
@@ -49,9 +49,10 @@ pub trait AstConv<'tcx> {
 
     fn default_constness_for_trait_bounds(&self) -> Constness;
 
-    /// Returns predicates in scope of the form `X: Foo`, where `X` is
-    /// a type parameter `X` with the given id `def_id`. This is a
-    /// subset of the full set of predicates.
+    /// Returns predicates in scope of the form `X: Foo<T>`, where `X`
+    /// is a type parameter `X` with the given id `def_id` and T
+    /// matches assoc_name. This is a subset of the full set of
+    /// predicates.
     ///
     /// This is used for one specific purpose: resolving "short-hand"
     /// associated type references like `T::Item`. In principle, we
@@ -60,7 +61,12 @@ pub trait AstConv<'tcx> {
     /// but this can lead to cycle errors. The problem is that we have
     /// to do this resolution *in order to create the predicates in
     /// the first place*. Hence, we have this "special pass".
-    fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx>;
+    fn get_type_parameter_bounds(
+        &self,
+        span: Span,
+        def_id: DefId,
+        assoc_name: Ident,
+    ) -> ty::GenericPredicates<'tcx>;
 
     /// Returns the lifetime to use when a lifetime is omitted (and not elided).
     fn re_infer(&self, param: Option<&ty::GenericParamDef>, span: Span)
@@ -1361,8 +1367,9 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
             ty_param_def_id, assoc_name, span,
         );
 
-        let predicates =
-            &self.get_type_parameter_bounds(span, ty_param_def_id.to_def_id()).predicates;
+        let predicates = &self
+            .get_type_parameter_bounds(span, ty_param_def_id.to_def_id(), assoc_name)
+            .predicates;
 
         debug!("find_bound_for_assoc_item: predicates={:#?}", predicates);
 
diff --git a/compiler/rustc_typeck/src/check/fn_ctxt/mod.rs b/compiler/rustc_typeck/src/check/fn_ctxt/mod.rs
index f635e0b6f93..96ebb781f0b 100644
--- a/compiler/rustc_typeck/src/check/fn_ctxt/mod.rs
+++ b/compiler/rustc_typeck/src/check/fn_ctxt/mod.rs
@@ -20,6 +20,7 @@ use rustc_middle::ty::fold::TypeFoldable;
 use rustc_middle::ty::subst::GenericArgKind;
 use rustc_middle::ty::{self, Const, Ty, TyCtxt};
 use rustc_session::Session;
+use rustc_span::symbol::Ident;
 use rustc_span::{self, Span};
 use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};
 
@@ -183,7 +184,12 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
         }
     }
 
-    fn get_type_parameter_bounds(&self, _: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
+    fn get_type_parameter_bounds(
+        &self,
+        _: Span,
+        def_id: DefId,
+        assoc_name: Ident,
+    ) -> ty::GenericPredicates<'tcx> {
         let tcx = self.tcx;
         let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
         let item_id = tcx.hir().ty_param_owner(hir_id);
@@ -196,9 +202,23 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
                 self.param_env.caller_bounds().iter().filter_map(|predicate| {
                     match predicate.skip_binders() {
                         ty::PredicateAtom::Trait(data, _) if data.self_ty().is_param(index) => {
-                            // HACK(eddyb) should get the original `Span`.
-                            let span = tcx.def_span(def_id);
-                            Some((predicate, span))
+                            let trait_did = data.def_id();
+                            if tcx
+                                .associated_items(trait_did)
+                                .find_by_name_and_kind(
+                                    tcx,
+                                    assoc_name,
+                                    ty::AssocKind::Type,
+                                    trait_did,
+                                )
+                                .is_some()
+                            {
+                                // HACK(eddyb) should get the original `Span`.
+                                let span = tcx.def_span(def_id);
+                                Some((predicate, span))
+                            } else {
+                                None
+                            }
                         }
                         _ => None,
                     }
diff --git a/compiler/rustc_typeck/src/collect.rs b/compiler/rustc_typeck/src/collect.rs
index dee0e6c2ebb..756147ca54c 100644
--- a/compiler/rustc_typeck/src/collect.rs
+++ b/compiler/rustc_typeck/src/collect.rs
@@ -310,8 +310,17 @@ impl AstConv<'tcx> for ItemCtxt<'tcx> {
         }
     }
 
-    fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
-        self.tcx.at(span).type_param_predicates((self.item_def_id, def_id.expect_local()))
+    fn get_type_parameter_bounds(
+        &self,
+        span: Span,
+        def_id: DefId,
+        assoc_name: Ident,
+    ) -> ty::GenericPredicates<'tcx> {
+        self.tcx.at(span).type_param_predicates((
+            self.item_def_id,
+            def_id.expect_local(),
+            assoc_name,
+        ))
     }
 
     fn re_infer(&self, _: Option<&ty::GenericParamDef>, _: Span) -> Option<ty::Region<'tcx>> {
@@ -492,7 +501,7 @@ fn get_new_lifetime_name<'tcx>(
 /// `X: Foo` where `X` is the type parameter `def_id`.
 fn type_param_predicates(
     tcx: TyCtxt<'_>,
-    (item_def_id, def_id): (DefId, LocalDefId),
+    (item_def_id, def_id, assoc_name): (DefId, LocalDefId, Ident),
 ) -> ty::GenericPredicates<'_> {
     use rustc_hir::*;
 
@@ -517,7 +526,7 @@ fn type_param_predicates(
     let mut result = parent
         .map(|parent| {
             let icx = ItemCtxt::new(tcx, parent);
-            icx.get_type_parameter_bounds(DUMMY_SP, def_id.to_def_id())
+            icx.get_type_parameter_bounds(DUMMY_SP, def_id.to_def_id(), assoc_name)
         })
         .unwrap_or_default();
     let mut extend = None;
@@ -560,12 +569,18 @@ fn type_param_predicates(
 
     let icx = ItemCtxt::new(tcx, item_def_id);
     let extra_predicates = extend.into_iter().chain(
-        icx.type_parameter_bounds_in_generics(ast_generics, param_id, ty, OnlySelfBounds(true))
-            .into_iter()
-            .filter(|(predicate, _)| match predicate.skip_binders() {
-                ty::PredicateAtom::Trait(data, _) => data.self_ty().is_param(index),
-                _ => false,
-            }),
+        icx.type_parameter_bounds_in_generics(
+            ast_generics,
+            param_id,
+            ty,
+            OnlySelfBounds(true),
+            Some(assoc_name),
+        )
+        .into_iter()
+        .filter(|(predicate, _)| match predicate.skip_binders() {
+            ty::PredicateAtom::Trait(data, _) => data.self_ty().is_param(index),
+            _ => false,
+        }),
     );
     result.predicates =
         tcx.arena.alloc_from_iter(result.predicates.iter().copied().chain(extra_predicates));
@@ -583,6 +598,7 @@ impl ItemCtxt<'tcx> {
         param_id: hir::HirId,
         ty: Ty<'tcx>,
         only_self_bounds: OnlySelfBounds,
+        assoc_name: Option<Ident>,
     ) -> Vec<(ty::Predicate<'tcx>, Span)> {
         let constness = self.default_constness_for_trait_bounds();
         let from_ty_params = ast_generics
@@ -593,6 +609,10 @@ impl ItemCtxt<'tcx> {
                 _ => None,
             })
             .flat_map(|bounds| bounds.iter())
+            .filter(|b| match assoc_name {
+                Some(assoc_name) => self.bound_defines_assoc_item(b, assoc_name),
+                None => true,
+            })
             .flat_map(|b| predicates_from_bound(self, ty, b, constness));
 
         let from_where_clauses = ast_generics
@@ -611,12 +631,43 @@ impl ItemCtxt<'tcx> {
                 } else {
                     None
                 };
-                bp.bounds.iter().filter_map(move |b| bt.map(|bt| (bt, b)))
+                bp.bounds
+                    .iter()
+                    .filter(|b| match assoc_name {
+                        Some(assoc_name) => self.bound_defines_assoc_item(b, assoc_name),
+                        None => true,
+                    })
+                    .filter_map(move |b| bt.map(|bt| (bt, b)))
             })
             .flat_map(|(bt, b)| predicates_from_bound(self, bt, b, constness));
 
         from_ty_params.chain(from_where_clauses).collect()
     }
+
+    fn bound_defines_assoc_item(&self, b: &hir::GenericBound<'_>, assoc_name: Ident) -> bool {
+        debug!("bound_defines_assoc_item(b={:?}, assoc_name={:?})", b, assoc_name);
+
+        match b {
+            hir::GenericBound::Trait(poly_trait_ref, _) => {
+                let trait_ref = &poly_trait_ref.trait_ref;
+                let trait_did = trait_ref.trait_def_id().unwrap();
+                let traits_did = super_traits_of(self.tcx, trait_did);
+
+                traits_did.iter().any(|trait_did| {
+                    self.tcx
+                        .associated_items(*trait_did)
+                        .find_by_name_and_kind(
+                            self.tcx,
+                            assoc_name,
+                            ty::AssocKind::Type,
+                            *trait_did,
+                        )
+                        .is_some()
+                })
+            }
+            _ => false,
+        }
+    }
 }
 
 /// Tests whether this is the AST for a reference to the type
@@ -1017,6 +1068,7 @@ fn super_predicates_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> ty::GenericPredi
         item.hir_id,
         self_param_ty,
         OnlySelfBounds(!is_trait_alias),
+        None,
     );
 
     // Combine the two lists to form the complete set of superbounds:
@@ -1034,6 +1086,45 @@ fn super_predicates_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> ty::GenericPredi
     ty::GenericPredicates { parent: None, predicates: superbounds }
 }
 
+pub fn super_traits_of(tcx: TyCtxt<'_>, trait_def_id: DefId) -> impl Iterator<Item = DefId> {
+    let mut set = FxHashSet::default();
+    let mut stack = vec![trait_def_id];
+    while let Some(trait_did) = stack.pop() {
+        if !set.insert(trait_did) {
+            continue;
+        }
+
+        if trait_did.is_local() {
+            let trait_hir_id = tcx.hir().local_def_id_to_hir_id(trait_did.expect_local());
+
+            let item = match tcx.hir().get(trait_hir_id) {
+                Node::Item(item) => item,
+                _ => bug!("super_trait_of {} is not an item", trait_hir_id),
+            };
+
+            let supertraits = match item.kind {
+                hir::ItemKind::Trait(.., ref supertraits, _) => supertraits,
+                hir::ItemKind::TraitAlias(_, ref supertraits) => supertraits,
+                _ => span_bug!(item.span, "super_trait_of invoked on non-trait"),
+            };
+
+            for supertrait in supertraits.iter() {
+                let trait_ref = supertrait.trait_ref();
+                if let Some(trait_did) = trait_ref.and_then(|trait_ref| trait_ref.trait_def_id()) {
+                    stack.push(trait_did);
+                }
+            }
+        } else {
+            let generic_predicates = tcx.super_predicates_of(trait_did);
+            for (predicate, _) in generic_predicates.predicates {
+                if let ty::PredicateAtom::Trait(data, _) = predicate.skip_binders() {
+                    stack.push(data.def_id());
+                }
+            }
+        }
+    }
+}
+
 fn trait_def(tcx: TyCtxt<'_>, def_id: DefId) -> ty::TraitDef {
     let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
     let item = tcx.hir().expect_item(hir_id);
diff --git a/src/test/ui/associated-type-bounds/super-trait-referencing.rs b/src/test/ui/associated-type-bounds/super-trait-referencing.rs
new file mode 100644
index 00000000000..fde6b91e6c4
--- /dev/null
+++ b/src/test/ui/associated-type-bounds/super-trait-referencing.rs
@@ -0,0 +1,15 @@
+// check-pass
+trait Foo {
+    type Item;
+}
+
+trait Bar<T> {}
+
+fn baz<T>()
+where
+    T: Foo,
+    T: Bar<T::Item>,
+{
+}
+
+fn main() {}