about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMichael Goulet <michael@errs.io>2024-03-04 19:59:09 +0000
committerMichael Goulet <michael@errs.io>2024-09-25 13:13:04 -0400
commit2dacf7ac61bdea8321ff623a28dd2b56cf54701c (patch)
tree38ce8d25d95ff2a5cbe1d61f2c5f36f2c11e172a
parentb5117538e934f81e39eb9c326fdcc6574d144cb7 (diff)
downloadrust-2dacf7ac61bdea8321ff623a28dd2b56cf54701c.tar.gz
rust-2dacf7ac61bdea8321ff623a28dd2b56cf54701c.zip
Collect relevant item bounds from trait clauses for nested rigid projections, GATs
-rw-r--r--compiler/rustc_hir_analysis/src/collect/item_bounds.rs226
-rw-r--r--tests/ui/associated-types/imply-relevant-nested-item-bounds-2.rs28
-rw-r--r--tests/ui/associated-types/imply-relevant-nested-item-bounds-for-gat.rs19
-rw-r--r--tests/ui/associated-types/imply-relevant-nested-item-bounds.rs23
4 files changed, 286 insertions, 10 deletions
diff --git a/compiler/rustc_hir_analysis/src/collect/item_bounds.rs b/compiler/rustc_hir_analysis/src/collect/item_bounds.rs
index c64741625a4..3e2adaad370 100644
--- a/compiler/rustc_hir_analysis/src/collect/item_bounds.rs
+++ b/compiler/rustc_hir_analysis/src/collect/item_bounds.rs
@@ -1,8 +1,9 @@
-use rustc_data_structures::fx::FxIndexSet;
+use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
 use rustc_hir as hir;
 use rustc_infer::traits::util;
+use rustc_middle::ty::fold::shift_vars;
 use rustc_middle::ty::{
-    self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable,
+    self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
 };
 use rustc_middle::{bug, span_bug};
 use rustc_span::Span;
@@ -42,14 +43,110 @@ fn associated_type_bounds<'tcx>(
     let trait_def_id = tcx.local_parent(assoc_item_def_id);
     let trait_predicates = tcx.trait_explicit_predicates_and_bounds(trait_def_id);
 
-    let bounds_from_parent = trait_predicates.predicates.iter().copied().filter(|(pred, _)| {
-        match pred.kind().skip_binder() {
-            ty::ClauseKind::Trait(tr) => tr.self_ty() == item_ty,
-            ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty() == item_ty,
-            ty::ClauseKind::TypeOutlives(outlives) => outlives.0 == item_ty,
-            _ => false,
-        }
-    });
+    let item_trait_ref = ty::TraitRef::identity(tcx, tcx.parent(assoc_item_def_id.to_def_id()));
+    let bounds_from_parent =
+        trait_predicates.predicates.iter().copied().filter_map(|(pred, span)| {
+            let mut clause_ty = match pred.kind().skip_binder() {
+                ty::ClauseKind::Trait(tr) => tr.self_ty(),
+                ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty(),
+                ty::ClauseKind::TypeOutlives(outlives) => outlives.0,
+                _ => return None,
+            };
+
+            // The code below is quite involved, so let me explain.
+            //
+            // We loop here, because we also want to collect vars for nested associated items as
+            // well. For example, given a clause like `Self::A::B`, we want to add that to the
+            // item bounds for `A`, so that we may use that bound in the case that `Self::A::B` is
+            // rigid.
+            //
+            // Secondly, regarding bound vars, when we see a where clause that mentions a GAT
+            // like `for<'a, ...> Self::Assoc<'a, ...>: Bound<'b, ...>`, we want to turn that into
+            // an item bound on the GAT, where all of the GAT args are substituted with the GAT's
+            // param regions, and then keep all of the other late-bound vars in the bound around.
+            // We need to "compress" the binder so that it doesn't mention any of those vars that
+            // were mapped to params.
+            let gat_vars = loop {
+                if let ty::Alias(ty::Projection, alias_ty) = *clause_ty.kind() {
+                    if alias_ty.trait_ref(tcx) == item_trait_ref
+                        && alias_ty.def_id == assoc_item_def_id.to_def_id()
+                    {
+                        break &alias_ty.args[item_trait_ref.args.len()..];
+                    } else {
+                        // Only collect *self* type bounds if the filter is for self.
+                        match filter {
+                            PredicateFilter::SelfOnly | PredicateFilter::SelfThatDefines(_) => {
+                                return None;
+                            }
+                            PredicateFilter::All | PredicateFilter::SelfAndAssociatedTypeBounds => {
+                            }
+                        }
+
+                        clause_ty = alias_ty.self_ty();
+                        continue;
+                    }
+                }
+
+                return None;
+            };
+            // Special-case: No GAT vars, no mapping needed.
+            if gat_vars.is_empty() {
+                return Some((pred, span));
+            }
+
+            // First, check that all of the GAT args are substituted with a unique late-bound arg.
+            // If we find a duplicate, then it can't be mapped to the definition's params.
+            let mut mapping = FxIndexMap::default();
+            let generics = tcx.generics_of(assoc_item_def_id);
+            for (param, var) in std::iter::zip(&generics.own_params, gat_vars) {
+                let existing = match var.unpack() {
+                    ty::GenericArgKind::Lifetime(re) => {
+                        if let ty::RegionKind::ReBound(ty::INNERMOST, bv) = re.kind() {
+                            mapping.insert(bv.var, tcx.mk_param_from_def(param))
+                        } else {
+                            return None;
+                        }
+                    }
+                    ty::GenericArgKind::Type(ty) => {
+                        if let ty::Bound(ty::INNERMOST, bv) = *ty.kind() {
+                            mapping.insert(bv.var, tcx.mk_param_from_def(param))
+                        } else {
+                            return None;
+                        }
+                    }
+                    ty::GenericArgKind::Const(ct) => {
+                        if let ty::ConstKind::Bound(ty::INNERMOST, bv) = ct.kind() {
+                            mapping.insert(bv, tcx.mk_param_from_def(param))
+                        } else {
+                            return None;
+                        }
+                    }
+                };
+
+                if existing.is_some() {
+                    return None;
+                }
+            }
+
+            // Finally, map all of the args in the GAT to the params we expect, and compress
+            // the remaining late-bound vars so that they count up from var 0.
+            let mut folder = MapAndCompressBoundVars {
+                tcx,
+                binder: ty::INNERMOST,
+                still_bound_vars: vec![],
+                mapping,
+            };
+            let pred = pred.kind().skip_binder().fold_with(&mut folder);
+
+            Some((
+                ty::Binder::bind_with_vars(
+                    pred,
+                    tcx.mk_bound_variable_kinds(&folder.still_bound_vars),
+                )
+                .upcast(tcx),
+                span,
+            ))
+        });
 
     let all_bounds = tcx.arena.alloc_from_iter(bounds.clauses(tcx).chain(bounds_from_parent));
     debug!(
@@ -63,6 +160,115 @@ fn associated_type_bounds<'tcx>(
     all_bounds
 }
 
+struct MapAndCompressBoundVars<'tcx> {
+    tcx: TyCtxt<'tcx>,
+    /// How deep are we? Makes sure we don't touch the vars of nested binders.
+    binder: ty::DebruijnIndex,
+    /// List of bound vars that remain unsubstituted because they were not
+    /// mentioned in the GAT's args.
+    still_bound_vars: Vec<ty::BoundVariableKind>,
+    /// Subtle invariant: If the `GenericArg` is bound, then it should be
+    /// stored with the debruijn index of `INNERMOST` so it can be shifted
+    /// correctly during substitution.
+    mapping: FxIndexMap<ty::BoundVar, ty::GenericArg<'tcx>>,
+}
+
+impl<'tcx> TypeFolder<TyCtxt<'tcx>> for MapAndCompressBoundVars<'tcx> {
+    fn cx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn fold_binder<T>(&mut self, t: ty::Binder<'tcx, T>) -> ty::Binder<'tcx, T>
+    where
+        ty::Binder<'tcx, T>: TypeSuperFoldable<TyCtxt<'tcx>>,
+    {
+        self.binder.shift_in(1);
+        let out = t.super_fold_with(self);
+        self.binder.shift_out(1);
+        out
+    }
+
+    fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
+        if !ty.has_bound_vars() {
+            return ty;
+        }
+
+        if let ty::Bound(binder, old_bound) = *ty.kind()
+            && self.binder == binder
+        {
+            let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
+                mapped.expect_ty()
+            } else {
+                // If we didn't find a mapped generic, then make a new one.
+                // Allocate a new var idx, and insert a new bound ty.
+                let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
+                self.still_bound_vars.push(ty::BoundVariableKind::Ty(old_bound.kind));
+                let mapped = Ty::new_bound(self.tcx, ty::INNERMOST, ty::BoundTy {
+                    var,
+                    kind: old_bound.kind,
+                });
+                self.mapping.insert(old_bound.var, mapped.into());
+                mapped
+            };
+
+            shift_vars(self.tcx, mapped, self.binder.as_u32())
+        } else {
+            ty.super_fold_with(self)
+        }
+    }
+
+    fn fold_region(&mut self, re: ty::Region<'tcx>) -> ty::Region<'tcx> {
+        if let ty::ReBound(binder, old_bound) = re.kind()
+            && self.binder == binder
+        {
+            let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
+                mapped.expect_region()
+            } else {
+                let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
+                self.still_bound_vars.push(ty::BoundVariableKind::Region(old_bound.kind));
+                let mapped = ty::Region::new_bound(self.tcx, ty::INNERMOST, ty::BoundRegion {
+                    var,
+                    kind: old_bound.kind,
+                });
+                self.mapping.insert(old_bound.var, mapped.into());
+                mapped
+            };
+
+            shift_vars(self.tcx, mapped, self.binder.as_u32())
+        } else {
+            re
+        }
+    }
+
+    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
+        if !ct.has_bound_vars() {
+            return ct;
+        }
+
+        if let ty::ConstKind::Bound(binder, old_var) = ct.kind()
+            && self.binder == binder
+        {
+            let mapped = if let Some(mapped) = self.mapping.get(&old_var) {
+                mapped.expect_const()
+            } else {
+                let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
+                self.still_bound_vars.push(ty::BoundVariableKind::Const);
+                let mapped = ty::Const::new_bound(self.tcx, ty::INNERMOST, var);
+                self.mapping.insert(old_var, mapped.into());
+                mapped
+            };
+
+            shift_vars(self.tcx, mapped, self.binder.as_u32())
+        } else {
+            ct.super_fold_with(self)
+        }
+    }
+
+    fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
+        if !p.has_bound_vars() { p } else { p.super_fold_with(self) }
+    }
+}
+
 /// Opaque types don't inherit bounds from their parent: for return position
 /// impl trait it isn't possible to write a suitable predicate on the
 /// containing function and for type-alias impl trait we don't have a backwards
diff --git a/tests/ui/associated-types/imply-relevant-nested-item-bounds-2.rs b/tests/ui/associated-types/imply-relevant-nested-item-bounds-2.rs
new file mode 100644
index 00000000000..864c3189350
--- /dev/null
+++ b/tests/ui/associated-types/imply-relevant-nested-item-bounds-2.rs
@@ -0,0 +1,28 @@
+//@ check-pass
+//@ revisions: current next
+//@[next] compile-flags: -Znext-solver
+
+trait Trait
+where
+    Self::Assoc: Clone,
+{
+    type Assoc;
+}
+
+fn foo<T: Trait>(x: &T::Assoc) -> T::Assoc {
+    x.clone()
+}
+
+trait Trait2
+where
+    Self::Assoc: Iterator,
+    <Self::Assoc as Iterator>::Item: Clone,
+{
+    type Assoc;
+}
+
+fn foo2<T: Trait2>(x: &<T::Assoc as Iterator>::Item) -> <T::Assoc as Iterator>::Item {
+    x.clone()
+}
+
+fn main() {}
diff --git a/tests/ui/associated-types/imply-relevant-nested-item-bounds-for-gat.rs b/tests/ui/associated-types/imply-relevant-nested-item-bounds-for-gat.rs
new file mode 100644
index 00000000000..4e3b0b3b148
--- /dev/null
+++ b/tests/ui/associated-types/imply-relevant-nested-item-bounds-for-gat.rs
@@ -0,0 +1,19 @@
+//@ check-pass
+
+// Test that `for<'a> Self::Gat<'a>: Debug` is implied in the definition of `Foo`,
+// just as it would be if it weren't a GAT but just a regular associated type.
+
+use std::fmt::Debug;
+
+trait Foo
+where
+    for<'a> Self::Gat<'a>: Debug,
+{
+    type Gat<'a>;
+}
+
+fn test<T: Foo>(x: T::Gat<'static>) {
+    println!("{:?}", x);
+}
+
+fn main() {}
diff --git a/tests/ui/associated-types/imply-relevant-nested-item-bounds.rs b/tests/ui/associated-types/imply-relevant-nested-item-bounds.rs
new file mode 100644
index 00000000000..5a477a5b349
--- /dev/null
+++ b/tests/ui/associated-types/imply-relevant-nested-item-bounds.rs
@@ -0,0 +1,23 @@
+//@ check-pass
+//@ revisions: current next
+//@[next] compile-flags: -Znext-solver
+
+trait Foo
+where
+    Self::Iterator: Iterator,
+    <Self::Iterator as Iterator>::Item: Bar,
+{
+    type Iterator;
+
+    fn iter() -> Self::Iterator;
+}
+
+trait Bar {
+    fn bar(&self);
+}
+
+fn x<T: Foo>() {
+    T::iter().next().unwrap().bar();
+}
+
+fn main() {}