about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMatthias Krüger <matthias.krueger@famsik.de>2024-07-31 15:36:29 +0200
committerGitHub <noreply@github.com>2024-07-31 15:36:29 +0200
commit563f938ab30ce2d3dbebadbde9c8ad32d2452f78 (patch)
tree00009780d3c03b4cf2cc01e3e0a1a0cac071f12f
parent579eb684b9830c26852e5ba2acba2fa5eea3a6c5 (diff)
parente7f89a7eea766af95604405d666e089a42cd4b48 (diff)
downloadrust-563f938ab30ce2d3dbebadbde9c8ad32d2452f78.tar.gz
rust-563f938ab30ce2d3dbebadbde9c8ad32d2452f78.zip
Rollup merge of #127681 - dingxiangfei2009:smart-ptr-bounds, r=compiler-errors
derive(SmartPointer): rewrite bounds in where and generic bounds

Fix #127647

Due to the `Unsize` bounds, we need to commute the bounds on the pointee type to the new self type.

cc ```@Darksonn```
-rw-r--r--compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs244
-rw-r--r--tests/ui/deriving/deriving-smart-pointer-expanded.rs22
-rw-r--r--tests/ui/deriving/deriving-smart-pointer-expanded.stdout44
-rw-r--r--tests/ui/deriving/smart-pointer-bounds-issue-127647.rs78
4 files changed, 371 insertions, 17 deletions
diff --git a/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs b/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs
index bbc7cd39627..02555bd799c 100644
--- a/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs
+++ b/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs
@@ -1,14 +1,18 @@
 use std::mem::swap;
 
+use ast::ptr::P;
 use ast::HasAttrs;
+use rustc_ast::mut_visit::MutVisitor;
+use rustc_ast::visit::BoundKind;
 use rustc_ast::{
     self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem,
-    TraitBoundModifiers, VariantData,
+    TraitBoundModifiers, VariantData, WherePredicate,
 };
 use rustc_attr as attr;
+use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
 use rustc_expand::base::{Annotatable, ExtCtxt};
 use rustc_span::symbol::{sym, Ident};
-use rustc_span::Span;
+use rustc_span::{Span, Symbol};
 use smallvec::{smallvec, SmallVec};
 use thin_vec::{thin_vec, ThinVec};
 
@@ -141,33 +145,239 @@ pub fn expand_deriving_smart_ptr(
     alt_self_params[pointee_param_idx] = GenericArg::Type(s_ty.clone());
     let alt_self_type = cx.ty_path(cx.path_all(span, false, vec![name_ident], alt_self_params));
 
+    // # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location
+    //
     // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
     let mut impl_generics = generics.clone();
+    let pointee_ty_ident = generics.params[pointee_param_idx].ident;
+    let mut self_bounds;
     {
-        let p = &mut impl_generics.params[pointee_param_idx];
+        let pointee = &mut impl_generics.params[pointee_param_idx];
+        self_bounds = pointee.bounds.clone();
         let arg = GenericArg::Type(s_ty.clone());
         let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
-        p.bounds.push(cx.trait_bound(unsize, false));
+        pointee.bounds.push(cx.trait_bound(unsize, false));
         let mut attrs = thin_vec![];
-        swap(&mut p.attrs, &mut attrs);
-        p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect();
+        swap(&mut pointee.attrs, &mut attrs);
+        // Drop `#[pointee]` attribute since it should not be recognized outside `derive(SmartPointer)`
+        pointee.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect();
     }
 
-    // Add the `__S: ?Sized` extra parameter to the impl block.
+    // # Rewrite generic parameter bounds
+    // For each bound `U: ..` in `struct<U: ..>`, make a new bound with `__S` in place of `#[pointee]`
+    // Example:
+    // ```
+    // struct<
+    //     U: Trait<T>,
+    //     #[pointee] T: Trait<T>,
+    //     V: Trait<T>> ...
+    // ```
+    // ... generates this `impl` generic parameters
+    // ```
+    // impl<
+    //     U: Trait<T> + Trait<__S>,
+    //     T: Trait<T> + Unsize<__S>, // (**)
+    //     __S: Trait<__S> + ?Sized, // (*)
+    //     V: Trait<T> + Trait<__S>> ...
+    // ```
+    // The new bound marked with (*) has to be done separately.
+    // See next section
+    for (idx, (params, orig_params)) in
+        impl_generics.params.iter_mut().zip(&generics.params).enumerate()
+    {
+        // Default type parameters are rejected for `impl` block.
+        // We should drop them now.
+        match &mut params.kind {
+            ast::GenericParamKind::Const { default, .. } => *default = None,
+            ast::GenericParamKind::Type { default } => *default = None,
+            ast::GenericParamKind::Lifetime => {}
+        }
+        // We CANNOT rewrite `#[pointee]` type parameter bounds.
+        // This has been set in stone. (**)
+        // So we skip over it.
+        // Otherwise, we push extra bounds involving `__S`.
+        if idx != pointee_param_idx {
+            for bound in &orig_params.bounds {
+                let mut bound = bound.clone();
+                let mut substitution = TypeSubstitution {
+                    from_name: pointee_ty_ident.name,
+                    to_ty: &s_ty,
+                    rewritten: false,
+                };
+                substitution.visit_param_bound(&mut bound, BoundKind::Bound);
+                if substitution.rewritten {
+                    // We found use of `#[pointee]` somewhere,
+                    // so we make a new bound using `__S` in place of `#[pointee]`
+                    params.bounds.push(bound);
+                }
+            }
+        }
+    }
+
+    // # Insert `__S` type parameter
+    //
+    // We now insert `__S` with the missing bounds marked with (*) above.
+    // We should also write the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`.
     let sized = cx.path_global(span, path!(span, core::marker::Sized));
-    let bound = GenericBound::Trait(
-        cx.poly_trait_ref(span, sized),
-        TraitBoundModifiers {
-            polarity: ast::BoundPolarity::Maybe(span),
-            constness: ast::BoundConstness::Never,
-            asyncness: ast::BoundAsyncness::Normal,
-        },
-    );
-    let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None);
-    impl_generics.params.push(extra_param);
+    // For some reason, we are not allowed to write `?Sized` bound twice like `__S: ?Sized + ?Sized`.
+    if !contains_maybe_sized_bound(&self_bounds)
+        && !contains_maybe_sized_bound_on_pointee(
+            &generics.where_clause.predicates,
+            pointee_ty_ident.name,
+        )
+    {
+        self_bounds.push(GenericBound::Trait(
+            cx.poly_trait_ref(span, sized),
+            TraitBoundModifiers {
+                polarity: ast::BoundPolarity::Maybe(span),
+                constness: ast::BoundConstness::Never,
+                asyncness: ast::BoundAsyncness::Normal,
+            },
+        ));
+    }
+    {
+        let mut substitution =
+            TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false };
+        for bound in &mut self_bounds {
+            substitution.visit_param_bound(bound, BoundKind::Bound);
+        }
+    }
+
+    // # Rewrite `where` clauses
+    //
+    // Move on to `where` clauses.
+    // Example:
+    // ```
+    // struct MyPointer<#[pointee] T, ..>
+    // where
+    //   U: Trait<V> + Trait<T>,
+    //   Companion<T>: Trait<T>,
+    //   T: Trait<T>,
+    // { .. }
+    // ```
+    // ... will have a impl prelude like so
+    // ```
+    // impl<..> ..
+    // where
+    //   U: Trait<V> + Trait<T>,
+    //   U: Trait<__S>,
+    //   Companion<T>: Trait<T>,
+    //   Companion<__S>: Trait<__S>,
+    //   T: Trait<T>,
+    //   __S: Trait<__S>,
+    // ```
+    //
+    // We should also write a few new `where` bounds from `#[pointee] T` to `__S`
+    // as well as any bound that indirectly involves the `#[pointee] T` type.
+    for bound in &generics.where_clause.predicates {
+        if let ast::WherePredicate::BoundPredicate(bound) = bound {
+            let mut substitution = TypeSubstitution {
+                from_name: pointee_ty_ident.name,
+                to_ty: &s_ty,
+                rewritten: false,
+            };
+            let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate {
+                span: bound.span,
+                bound_generic_params: bound.bound_generic_params.clone(),
+                bounded_ty: bound.bounded_ty.clone(),
+                bounds: bound.bounds.clone(),
+            });
+            substitution.visit_where_predicate(&mut predicate);
+            if substitution.rewritten {
+                impl_generics.where_clause.predicates.push(predicate);
+            }
+        }
+    }
+
+    let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None);
+    impl_generics.params.insert(pointee_param_idx + 1, extra_param);
 
     // Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`.
     let gen_args = vec![GenericArg::Type(alt_self_type.clone())];
     add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
     add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone());
 }
+
+fn contains_maybe_sized_bound_on_pointee(predicates: &[WherePredicate], pointee: Symbol) -> bool {
+    for bound in predicates {
+        if let ast::WherePredicate::BoundPredicate(bound) = bound
+            && bound.bounded_ty.kind.is_simple_path().is_some_and(|name| name == pointee)
+        {
+            for bound in &bound.bounds {
+                if is_maybe_sized_bound(bound) {
+                    return true;
+                }
+            }
+        }
+    }
+    false
+}
+
+fn is_maybe_sized_bound(bound: &GenericBound) -> bool {
+    if let GenericBound::Trait(
+        trait_ref,
+        TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
+    ) = bound
+    {
+        is_sized_marker(&trait_ref.trait_ref.path)
+    } else {
+        false
+    }
+}
+
+fn contains_maybe_sized_bound(bounds: &[GenericBound]) -> bool {
+    bounds.iter().any(is_maybe_sized_bound)
+}
+
+fn path_segment_is_exact_match(path_segments: &[ast::PathSegment], syms: &[Symbol]) -> bool {
+    path_segments.iter().zip(syms).all(|(segment, &symbol)| segment.ident.name == symbol)
+}
+
+fn is_sized_marker(path: &ast::Path) -> bool {
+    const CORE_UNSIZE: [Symbol; 3] = [sym::core, sym::marker, sym::Sized];
+    const STD_UNSIZE: [Symbol; 3] = [sym::std, sym::marker, sym::Sized];
+    if path.segments.len() == 4 && path.is_global() {
+        path_segment_is_exact_match(&path.segments[1..], &CORE_UNSIZE)
+            || path_segment_is_exact_match(&path.segments[1..], &STD_UNSIZE)
+    } else if path.segments.len() == 3 {
+        path_segment_is_exact_match(&path.segments, &CORE_UNSIZE)
+            || path_segment_is_exact_match(&path.segments, &STD_UNSIZE)
+    } else {
+        *path == sym::Sized
+    }
+}
+
+struct TypeSubstitution<'a> {
+    from_name: Symbol,
+    to_ty: &'a ast::Ty,
+    rewritten: bool,
+}
+
+impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> {
+    fn visit_ty(&mut self, ty: &mut P<ast::Ty>) {
+        if let Some(name) = ty.kind.is_simple_path()
+            && name == self.from_name
+        {
+            **ty = self.to_ty.clone();
+            self.rewritten = true;
+        } else {
+            ast::mut_visit::walk_ty(self, ty);
+        }
+    }
+
+    fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) {
+        match where_predicate {
+            rustc_ast::WherePredicate::BoundPredicate(bound) => {
+                bound
+                    .bound_generic_params
+                    .flat_map_in_place(|param| self.flat_map_generic_param(param));
+                self.visit_ty(&mut bound.bounded_ty);
+                for bound in &mut bound.bounds {
+                    self.visit_param_bound(bound, BoundKind::Bound)
+                }
+            }
+            rustc_ast::WherePredicate::RegionPredicate(_)
+            | rustc_ast::WherePredicate::EqPredicate(_) => {}
+        }
+    }
+}
diff --git a/tests/ui/deriving/deriving-smart-pointer-expanded.rs b/tests/ui/deriving/deriving-smart-pointer-expanded.rs
new file mode 100644
index 00000000000..b78258c2529
--- /dev/null
+++ b/tests/ui/deriving/deriving-smart-pointer-expanded.rs
@@ -0,0 +1,22 @@
+//@ check-pass
+//@ compile-flags: -Zunpretty=expanded
+#![feature(derive_smart_pointer)]
+use std::marker::SmartPointer;
+
+pub trait MyTrait<T: ?Sized> {}
+
+#[derive(SmartPointer)]
+#[repr(transparent)]
+struct MyPointer<'a, #[pointee] T: ?Sized> {
+    ptr: &'a T,
+}
+
+#[derive(core::marker::SmartPointer)]
+#[repr(transparent)]
+pub struct MyPointer2<'a, Y, Z: MyTrait<T>, #[pointee] T: ?Sized + MyTrait<T>, X: MyTrait<T> = ()>
+where
+    Y: MyTrait<T>,
+{
+    data: &'a mut T,
+    x: core::marker::PhantomData<X>,
+}
diff --git a/tests/ui/deriving/deriving-smart-pointer-expanded.stdout b/tests/ui/deriving/deriving-smart-pointer-expanded.stdout
new file mode 100644
index 00000000000..3c7e7198180
--- /dev/null
+++ b/tests/ui/deriving/deriving-smart-pointer-expanded.stdout
@@ -0,0 +1,44 @@
+#![feature(prelude_import)]
+#![no_std]
+//@ check-pass
+//@ compile-flags: -Zunpretty=expanded
+#![feature(derive_smart_pointer)]
+#[prelude_import]
+use ::std::prelude::rust_2015::*;
+#[macro_use]
+extern crate std;
+use std::marker::SmartPointer;
+
+pub trait MyTrait<T: ?Sized> {}
+
+#[repr(transparent)]
+struct MyPointer<'a, #[pointee] T: ?Sized> {
+    ptr: &'a T,
+}
+#[automatically_derived]
+impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized>
+    ::core::ops::DispatchFromDyn<MyPointer<'a, __S>> for MyPointer<'a, T> {
+}
+#[automatically_derived]
+impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized>
+    ::core::ops::CoerceUnsized<MyPointer<'a, __S>> for MyPointer<'a, T> {
+}
+
+#[repr(transparent)]
+pub struct MyPointer2<'a, Y, Z: MyTrait<T>, #[pointee] T: ?Sized + MyTrait<T>,
+    X: MyTrait<T> = ()> where Y: MyTrait<T> {
+    data: &'a mut T,
+    x: core::marker::PhantomData<X>,
+}
+#[automatically_derived]
+impl<'a, Y, Z: MyTrait<T> + MyTrait<__S>, T: ?Sized + MyTrait<T> +
+    ::core::marker::Unsize<__S>, __S: ?Sized + MyTrait<__S>, X: MyTrait<T> +
+    MyTrait<__S>> ::core::ops::DispatchFromDyn<MyPointer2<'a, Y, Z, __S, X>>
+    for MyPointer2<'a, Y, Z, T, X> where Y: MyTrait<T>, Y: MyTrait<__S> {
+}
+#[automatically_derived]
+impl<'a, Y, Z: MyTrait<T> + MyTrait<__S>, T: ?Sized + MyTrait<T> +
+    ::core::marker::Unsize<__S>, __S: ?Sized + MyTrait<__S>, X: MyTrait<T> +
+    MyTrait<__S>> ::core::ops::CoerceUnsized<MyPointer2<'a, Y, Z, __S, X>> for
+    MyPointer2<'a, Y, Z, T, X> where Y: MyTrait<T>, Y: MyTrait<__S> {
+}
diff --git a/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs b/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs
new file mode 100644
index 00000000000..4cae1b32896
--- /dev/null
+++ b/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs
@@ -0,0 +1,78 @@
+//@ check-pass
+
+#![feature(derive_smart_pointer)]
+
+#[derive(core::marker::SmartPointer)]
+#[repr(transparent)]
+pub struct Ptr<'a, #[pointee] T: OnDrop + ?Sized, X> {
+    data: &'a mut T,
+    x: core::marker::PhantomData<X>,
+}
+
+pub trait OnDrop {
+    fn on_drop(&mut self);
+}
+
+#[derive(core::marker::SmartPointer)]
+#[repr(transparent)]
+pub struct Ptr2<'a, #[pointee] T: ?Sized, X>
+where
+    T: OnDrop,
+{
+    data: &'a mut T,
+    x: core::marker::PhantomData<X>,
+}
+
+pub trait MyTrait<T: ?Sized> {}
+
+#[derive(core::marker::SmartPointer)]
+#[repr(transparent)]
+pub struct Ptr3<'a, #[pointee] T: ?Sized, X>
+where
+    T: MyTrait<T>,
+{
+    data: &'a mut T,
+    x: core::marker::PhantomData<X>,
+}
+
+#[derive(core::marker::SmartPointer)]
+#[repr(transparent)]
+pub struct Ptr4<'a, #[pointee] T: MyTrait<T> + ?Sized, X> {
+    data: &'a mut T,
+    x: core::marker::PhantomData<X>,
+}
+
+#[derive(core::marker::SmartPointer)]
+#[repr(transparent)]
+pub struct Ptr5<'a, #[pointee] T: ?Sized, X>
+where
+    Ptr5Companion<T>: MyTrait<T>,
+    Ptr5Companion2: MyTrait<T>,
+{
+    data: &'a mut T,
+    x: core::marker::PhantomData<X>,
+}
+
+pub struct Ptr5Companion<T: ?Sized>(core::marker::PhantomData<T>);
+pub struct Ptr5Companion2;
+
+#[derive(core::marker::SmartPointer)]
+#[repr(transparent)]
+pub struct Ptr6<'a, #[pointee] T: ?Sized, X: MyTrait<T> = (), const PARAM: usize = 0> {
+    data: &'a mut T,
+    x: core::marker::PhantomData<X>,
+}
+
+// a reduced example from https://lore.kernel.org/all/20240402-linked-list-v1-1-b1c59ba7ae3b@google.com/
+#[repr(transparent)]
+#[derive(core::marker::SmartPointer)]
+pub struct ListArc<#[pointee] T, const ID: u64 = 0>
+where
+    T: ListArcSafe<ID> + ?Sized,
+{
+    arc: *const T,
+}
+
+pub trait ListArcSafe<const ID: u64> {}
+
+fn main() {}