about summary refs log tree commit diff
diff options
context:
space:
mode:
authoraustaras <austaras@outlook.com>2024-02-04 11:35:27 +0800
committeraustaras <austaras@outlook.com>2024-02-08 22:44:39 +0800
commitdad0fdb13f84326326ca15b116dda62c0c482858 (patch)
tree61f8ddee5fe171dfccb9f3250ee6265b00b68ec1
parente9d3565cd11127f9df52b6824e4edb42a5bbd8bf (diff)
downloadrust-dad0fdb13f84326326ca15b116dda62c0c482858.tar.gz
rust-dad0fdb13f84326326ca15b116dda62c0c482858.zip
fix: preserve where clause when builtin derive
-rw-r--r--crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs2
-rw-r--r--crates/hir-expand/src/builtin_derive_macro.rs22
-rw-r--r--crates/hir-ty/src/tests/macros.rs31
3 files changed, 50 insertions, 5 deletions
diff --git a/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs b/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs
index 553c0b79533..86b4466153a 100644
--- a/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs
+++ b/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs
@@ -157,7 +157,7 @@ where
     generic: Vec<T::InGenericArg>,
 }
 
-impl <T: $crate::clone::Clone, > $crate::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: $crate::clone::Clone, T::InGenericArg: $crate::clone::Clone, {
+impl <T: $crate::clone::Clone, > $crate::clone::Clone for Foo<T, > where <T as Trait>::InWc: Marker, T: Trait, T::InFieldShorthand: $crate::clone::Clone, T::InGenericArg: $crate::clone::Clone, {
     fn clone(&self ) -> Self {
         match self {
             Foo {
diff --git a/crates/hir-expand/src/builtin_derive_macro.rs b/crates/hir-expand/src/builtin_derive_macro.rs
index 024fb8c1f61..27954875143 100644
--- a/crates/hir-expand/src/builtin_derive_macro.rs
+++ b/crates/hir-expand/src/builtin_derive_macro.rs
@@ -194,6 +194,7 @@ struct BasicAdtInfo {
     /// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
     /// third fields is where bounds, if any
     param_types: Vec<(tt::Subtree, Option<tt::Subtree>, Option<tt::Subtree>)>,
+    where_clause: Vec<tt::Subtree>,
     associated_types: Vec<tt::Subtree>,
 }
 
@@ -202,10 +203,11 @@ fn parse_adt(
     adt: &ast::Adt,
     call_site: Span,
 ) -> Result<BasicAdtInfo, ExpandError> {
-    let (name, generic_param_list, shape) = match adt {
+    let (name, generic_param_list, where_clause, shape) = match adt {
         ast::Adt::Struct(it) => (
             it.name(),
             it.generic_param_list(),
+            it.where_clause(),
             AdtShape::Struct(VariantShape::from(tm, it.field_list())?),
         ),
         ast::Adt::Enum(it) => {
@@ -217,6 +219,7 @@ fn parse_adt(
             (
                 it.name(),
                 it.generic_param_list(),
+                it.where_clause(),
                 AdtShape::Enum {
                     default_variant,
                     variants: it
@@ -233,7 +236,9 @@ fn parse_adt(
                 },
             )
         }
-        ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
+        ast::Adt::Union(it) => {
+            (it.name(), it.generic_param_list(), it.where_clause(), AdtShape::Union)
+        }
     };
 
     let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
@@ -274,6 +279,14 @@ fn parse_adt(
         })
         .collect();
 
+    let where_clause = if let Some(w) = where_clause {
+        w.predicates()
+            .map(|it| mbe::syntax_node_to_token_tree(it.syntax(), tm, call_site))
+            .collect()
+    } else {
+        vec![]
+    };
+
     // For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
     // types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
     // also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
@@ -301,7 +314,7 @@ fn parse_adt(
         .map(|it| mbe::syntax_node_to_token_tree(it.syntax(), tm, call_site))
         .collect();
     let name_token = name_to_token(tm, name)?;
-    Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
+    Ok(BasicAdtInfo { name: name_token, shape, param_types, where_clause, associated_types })
 }
 
 fn name_to_token(
@@ -366,7 +379,8 @@ fn expand_simple_derive(
         }
     };
     let trait_body = make_trait_body(&info);
-    let mut where_block = vec![];
+    let mut where_block: Vec<_> =
+        info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect();
     let (params, args): (Vec<_>, Vec<_>) = info
         .param_types
         .into_iter()
diff --git a/crates/hir-ty/src/tests/macros.rs b/crates/hir-ty/src/tests/macros.rs
index b0a9361f1c5..2f75338f994 100644
--- a/crates/hir-ty/src/tests/macros.rs
+++ b/crates/hir-ty/src/tests/macros.rs
@@ -1373,3 +1373,34 @@ pub fn attr_macro() {}
 "#,
     );
 }
+
+#[test]
+fn clone_with_type_bound() {
+    check_types(
+        r#"
+//- minicore: derive, clone, builtin_impls
+#[derive(Clone)]
+struct Float;
+
+trait TensorKind: Clone {
+    /// The primitive type of the tensor.
+    type Primitive: Clone;
+}
+
+impl TensorKind for Float {
+    type Primitive = f64;
+}
+
+#[derive(Clone)]
+struct Tensor<K = Float> where K: TensorKind
+{
+    primitive: K::Primitive,
+}
+
+fn foo(t: Tensor) {
+    let x = t.clone();
+      //^ Tensor<Float>
+}
+"#,
+    );
+}