about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs12
-rw-r--r--crates/hir-expand/src/builtin_derive_macro.rs131
-rw-r--r--crates/hir-ty/src/tests/traits.rs60
-rw-r--r--crates/ide/src/expand_macro.rs8
-rw-r--r--crates/test-utils/src/minicore.rs6
5 files changed, 187 insertions, 30 deletions
diff --git a/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs b/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs
index fafcde25ae7..37cf348c92d 100644
--- a/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs
+++ b/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs
@@ -16,7 +16,7 @@ struct Foo;
 #[derive(Copy)]
 struct Foo;
 
-impl < > core::marker::Copy for Foo< > {}"#]],
+impl < > core::marker::Copy for Foo< > where {}"#]],
     );
 }
 
@@ -41,7 +41,7 @@ macro Copy {}
 #[derive(Copy)]
 struct Foo;
 
-impl < > crate ::marker::Copy for Foo< > {}"#]],
+impl < > crate ::marker::Copy for Foo< > where {}"#]],
     );
 }
 
@@ -57,7 +57,7 @@ struct Foo<A, B>;
 #[derive(Copy)]
 struct Foo<A, B>;
 
-impl <T0: core::marker::Copy, T1: core::marker::Copy, > core::marker::Copy for Foo<T0, T1, > {}"#]],
+impl <A: core::marker::Copy, B: core::marker::Copy, > core::marker::Copy for Foo<A, B, > where {}"#]],
     );
 }
 
@@ -74,7 +74,7 @@ struct Foo<A, B, 'a, 'b>;
 #[derive(Copy)]
 struct Foo<A, B, 'a, 'b>;
 
-impl <T0: core::marker::Copy, T1: core::marker::Copy, > core::marker::Copy for Foo<T0, T1, > {}"#]],
+impl <A: core::marker::Copy, B: core::marker::Copy, > core::marker::Copy for Foo<A, B, > where {}"#]],
     );
 }
 
@@ -90,7 +90,7 @@ struct Foo<A, B>;
 #[derive(Clone)]
 struct Foo<A, B>;
 
-impl <T0: core::clone::Clone, T1: core::clone::Clone, > core::clone::Clone for Foo<T0, T1, > {}"#]],
+impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Foo<A, B, > where {}"#]],
     );
 }
 
@@ -106,6 +106,6 @@ struct Foo<const X: usize, T>(u32);
 #[derive(Clone)]
 struct Foo<const X: usize, T>(u32);
 
-impl <const T0: usize, T1: core::clone::Clone, > core::clone::Clone for Foo<T0, T1, > {}"#]],
+impl <const X: usize, T: core::clone::Clone, > core::clone::Clone for Foo<X, T, > where {}"#]],
     );
 }
diff --git a/crates/hir-expand/src/builtin_derive_macro.rs b/crates/hir-expand/src/builtin_derive_macro.rs
index 5c1a75132ee..7e753663c01 100644
--- a/crates/hir-expand/src/builtin_derive_macro.rs
+++ b/crates/hir-expand/src/builtin_derive_macro.rs
@@ -1,11 +1,12 @@
 //! Builtin derives.
 
 use base_db::{CrateOrigin, LangCrateOrigin};
+use std::collections::HashSet;
 use tracing::debug;
 
 use crate::tt::{self, TokenId};
 use syntax::{
-    ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName},
+    ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName, HasTypeBounds, PathType},
     match_ast,
 };
 
@@ -60,8 +61,11 @@ pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander>
 
 struct BasicAdtInfo {
     name: tt::Ident,
-    /// `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
-    param_types: Vec<Option<tt::Subtree>>,
+    /// first field is the name, and
+    /// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
+    /// third fields is where bounds, if any
+    param_types: Vec<(tt::Subtree, Option<tt::Subtree>, Option<tt::Subtree>)>,
+    associated_types: Vec<tt::Subtree>,
 }
 
 fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
@@ -86,18 +90,28 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
             },
         }
     };
-    let name = name.ok_or_else(|| {
-        debug!("parsed item has no name");
-        ExpandError::Other("missing name".into())
-    })?;
-    let name_token_id =
-        token_map.token_by_range(name.syntax().text_range()).unwrap_or_else(TokenId::unspecified);
-    let name_token = tt::Ident { span: name_token_id, text: name.text().into() };
+    let mut param_type_set: HashSet<String> = HashSet::new();
     let param_types = params
         .into_iter()
         .flat_map(|param_list| param_list.type_or_const_params())
         .map(|param| {
-            if let ast::TypeOrConstParam::Const(param) = param {
+            let name = {
+                let this = param.name();
+                match this {
+                    Some(x) => {
+                        param_type_set.insert(x.to_string());
+                        mbe::syntax_node_to_token_tree(x.syntax()).0
+                    }
+                    None => tt::Subtree::empty(),
+                }
+            };
+            let bounds = match &param {
+                ast::TypeOrConstParam::Type(x) => {
+                    x.type_bound_list().map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
+                }
+                ast::TypeOrConstParam::Const(_) => None,
+            };
+            let ty = if let ast::TypeOrConstParam::Const(param) = param {
                 let ty = param
                     .ty()
                     .map(|ty| mbe::syntax_node_to_token_tree(ty.syntax()).0)
@@ -105,27 +119,97 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
                 Some(ty)
             } else {
                 None
-            }
+            };
+            (name, ty, bounds)
         })
         .collect();
-    Ok(BasicAdtInfo { name: name_token, param_types })
+    let is_associated_type = |p: &PathType| {
+        if let Some(p) = p.path() {
+            if let Some(parent) = p.qualifier() {
+                if let Some(x) = parent.segment() {
+                    if let Some(x) = x.path_type() {
+                        if let Some(x) = x.path() {
+                            if let Some(pname) = x.as_single_name_ref() {
+                                if param_type_set.contains(&pname.to_string()) {
+                                    // <T as Trait>::Assoc
+                                    return true;
+                                }
+                            }
+                        }
+                    }
+                }
+                if let Some(pname) = parent.as_single_name_ref() {
+                    if param_type_set.contains(&pname.to_string()) {
+                        // T::Assoc
+                        return true;
+                    }
+                }
+            }
+        }
+        false
+    };
+    let associated_types = node
+        .descendants()
+        .filter_map(PathType::cast)
+        .filter(is_associated_type)
+        .map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
+        .collect::<Vec<_>>();
+    let name = name.ok_or_else(|| {
+        debug!("parsed item has no name");
+        ExpandError::Other("missing name".into())
+    })?;
+    let name_token_id =
+        token_map.token_by_range(name.syntax().text_range()).unwrap_or_else(TokenId::unspecified);
+    let name_token = tt::Ident { span: name_token_id, text: name.text().into() };
+    Ok(BasicAdtInfo { name: name_token, param_types, associated_types })
 }
 
+/// Given that we are deriving a trait `DerivedTrait` for a type like:
+///
+/// ```ignore (only-for-syntax-highlight)
+/// struct Struct<'a, ..., 'z, A, B: DeclaredTrait, C, ..., Z> where C: WhereTrait {
+///     a: A,
+///     b: B::Item,
+///     b1: <B as DeclaredTrait>::Item,
+///     c1: <C as WhereTrait>::Item,
+///     c2: Option<<C as WhereTrait>::Item>,
+///     ...
+/// }
+/// ```
+///
+/// create an impl like:
+///
+/// ```ignore (only-for-syntax-highlight)
+/// impl<'a, ..., 'z, A, B: DeclaredTrait, C, ... Z> where
+///     C:                       WhereTrait,
+///     A: DerivedTrait + B1 + ... + BN,
+///     B: DerivedTrait + B1 + ... + BN,
+///     C: DerivedTrait + B1 + ... + BN,
+///     B::Item:                 DerivedTrait + B1 + ... + BN,
+///     <C as WhereTrait>::Item: DerivedTrait + B1 + ... + BN,
+///     ...
+/// {
+///     ...
+/// }
+/// ```
+///
+/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
+/// therefore does not get bound by the derived trait.
 fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResult<tt::Subtree> {
     let info = match parse_adt(tt) {
         Ok(info) => info,
         Err(e) => return ExpandResult::with_err(tt::Subtree::empty(), e),
     };
+    let mut where_block = vec![];
     let (params, args): (Vec<_>, Vec<_>) = info
         .param_types
         .into_iter()
-        .enumerate()
-        .map(|(idx, param_ty)| {
-            let ident = tt::Leaf::Ident(tt::Ident {
-                span: tt::TokenId::unspecified(),
-                text: format!("T{idx}").into(),
-            });
+        .map(|(ident, param_ty, bound)| {
             let ident_ = ident.clone();
+            if let Some(b) = bound {
+                let ident = ident.clone();
+                where_block.push(quote! { #ident : #b , });
+            }
             if let Some(ty) = param_ty {
                 (quote! { const #ident : #ty , }, quote! { #ident_ , })
             } else {
@@ -134,9 +218,16 @@ fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResu
             }
         })
         .unzip();
+
+    where_block.extend(info.associated_types.iter().map(|x| {
+        let x = x.clone();
+        let bound = trait_path.clone();
+        quote! { #x : #bound , }
+    }));
+
     let name = info.name;
     let expanded = quote! {
-        impl < ##params > #trait_path for #name < ##args > {}
+        impl < ##params > #trait_path for #name < ##args > where ##where_block {}
     };
     ExpandResult::ok(expanded)
 }
diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs
index 161e8385ec7..3564ed41334 100644
--- a/crates/hir-ty/src/tests/traits.rs
+++ b/crates/hir-ty/src/tests/traits.rs
@@ -4315,3 +4315,63 @@ impl Trait for () {
     "#,
     );
 }
+
+#[test]
+fn derive_macro_bounds() {
+    check_types(
+        r#"
+        //- minicore: clone, derive
+        #[derive(Clone)]
+        struct Copy;
+        struct NotCopy;
+        #[derive(Clone)]
+        struct Generic<T>(T);
+        trait Tr {
+            type Assoc;
+        }
+        impl Tr for Copy {
+            type Assoc = NotCopy;
+        }
+        #[derive(Clone)]
+        struct AssocGeneric<T: Tr>(T::Assoc);
+
+        #[derive(Clone)]
+        struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
+
+        #[derive(Clone)]
+        struct AssocGeneric3<T: Tr>(Generic<T::Assoc>);
+
+        #[derive(Clone)]
+        struct Vec<T>();
+
+        #[derive(Clone)]
+        struct R1(Vec<R2>);
+        #[derive(Clone)]
+        struct R2(R1);
+
+        fn f() {
+            let x = (&Copy).clone();
+              //^ Copy
+            let x = (&NotCopy).clone();
+              //^ &NotCopy
+            let x = (&Generic(Copy)).clone();
+              //^ Generic<Copy>
+            let x = (&Generic(NotCopy)).clone();
+              //^ &Generic<NotCopy>
+            let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy);
+            let x = x.clone();
+              //^ &AssocGeneric<Copy>
+            let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
+            let x = x.clone();
+              //^ &AssocGeneric2<Copy>
+            let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy));
+            let x = x.clone();
+              //^ &AssocGeneric3<Copy>
+            let x = (&R1(Vec())).clone();
+              //^ R1
+            let x = (&R2(R1(Vec()))).clone();
+              //^ R2
+        }
+        "#,
+    );
+}
diff --git a/crates/ide/src/expand_macro.rs b/crates/ide/src/expand_macro.rs
index 4382af43438..91af5716ca5 100644
--- a/crates/ide/src/expand_macro.rs
+++ b/crates/ide/src/expand_macro.rs
@@ -471,7 +471,7 @@ struct Foo {}
 "#,
             expect![[r#"
                 Clone
-                impl < >core::clone::Clone for Foo< >{}
+                impl < >core::clone::Clone for Foo< >where{}
             "#]],
         );
     }
@@ -488,7 +488,7 @@ struct Foo {}
 "#,
             expect![[r#"
                 Copy
-                impl < >core::marker::Copy for Foo< >{}
+                impl < >core::marker::Copy for Foo< >where{}
             "#]],
         );
     }
@@ -504,7 +504,7 @@ struct Foo {}
 "#,
             expect![[r#"
                 Copy
-                impl < >core::marker::Copy for Foo< >{}
+                impl < >core::marker::Copy for Foo< >where{}
             "#]],
         );
         check(
@@ -516,7 +516,7 @@ struct Foo {}
 "#,
             expect![[r#"
                 Clone
-                impl < >core::clone::Clone for Foo< >{}
+                impl < >core::clone::Clone for Foo< >where{}
             "#]],
         );
     }
diff --git a/crates/test-utils/src/minicore.rs b/crates/test-utils/src/minicore.rs
index 118b9ad631b..167af32a2ea 100644
--- a/crates/test-utils/src/minicore.rs
+++ b/crates/test-utils/src/minicore.rs
@@ -143,6 +143,12 @@ pub mod clone {
     pub trait Clone: Sized {
         fn clone(&self) -> Self;
     }
+
+    impl<T> Clone for &T {
+        fn clone(&self) -> Self {
+            *self
+        }
+    }
     // region:derive
     #[rustc_builtin_macro]
     pub macro Clone($item:item) {}