about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs64
-rw-r--r--crates/ide-assists/src/utils.rs20
-rw-r--r--crates/syntax/src/ast/make.rs21
3 files changed, 98 insertions, 7 deletions
diff --git a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
index 2854701c088..a6693d7d790 100644
--- a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
+++ b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
@@ -995,6 +995,68 @@ impl PartialEq for Foo {
     }
 
     #[test]
+    fn add_custom_impl_partial_eq_tuple_enum_generic() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: eq, derive
+#[derive(Partial$0Eq)]
+enum Either<T, U> {
+    Left(T),
+    Right(U),
+}
+"#,
+            r#"
+enum Either<T, U> {
+    Left(T),
+    Right(U),
+}
+
+impl<T: PartialEq, U: PartialEq> PartialEq for Either<T, U> {
+    $0fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::Left(l0), Self::Left(r0)) => l0 == r0,
+            (Self::Right(l0), Self::Right(r0)) => l0 == r0,
+            _ => false,
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_eq_tuple_enum_generic_existing_bounds() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: eq, derive
+#[derive(Partial$0Eq)]
+enum Either<T: PartialEq + Error, U: Clone> {
+    Left(T),
+    Right(U),
+}
+"#,
+            r#"
+enum Either<T: PartialEq + Error, U: Clone> {
+    Left(T),
+    Right(U),
+}
+
+impl<T: PartialEq + Error, U: Clone + PartialEq> PartialEq for Either<T, U> {
+    $0fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::Left(l0), Self::Left(r0)) => l0 == r0,
+            (Self::Right(l0), Self::Right(r0)) => l0 == r0,
+            _ => false,
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
     fn add_custom_impl_partial_eq_record_enum() {
         check_assist(
             replace_derive_with_manual_impl,
@@ -1170,7 +1232,7 @@ struct Foo<T, U> {
     bar: U,
 }
 
-impl<T, U> Default for Foo<T, U> {
+impl<T: Default, U: Default> Default for Foo<T, U> {
     $0fn default() -> Self {
         Self { foo: Default::default(), bar: Default::default() }
     }
diff --git a/crates/ide-assists/src/utils.rs b/crates/ide-assists/src/utils.rs
index f38a2d04ff6..2dcf56501d6 100644
--- a/crates/ide-assists/src/utils.rs
+++ b/crates/ide-assists/src/utils.rs
@@ -5,6 +5,7 @@ use std::ops;
 pub(crate) use gen_trait_fn_body::gen_trait_fn_body;
 use hir::{db::HirDatabase, HirDisplay, Semantics};
 use ide_db::{famous_defs::FamousDefs, path_transform::PathTransform, RootDatabase, SnippetCap};
+use itertools::Itertools;
 use stdx::format_to;
 use syntax::{
     ast::{
@@ -452,15 +453,32 @@ fn generate_impl_text_inner(adt: &ast::Adt, trait_text: Option<&str>, code: &str
         let lifetime_params =
             generic_params.lifetime_params().map(ast::GenericParam::LifetimeParam);
         let ty_or_const_params = generic_params.type_or_const_params().filter_map(|param| {
-            // remove defaults since they can't be specified in impls
             match param {
                 ast::TypeOrConstParam::Type(param) => {
                     let param = param.clone_for_update();
+                    // remove defaults since they can't be specified in impls
                     param.remove_default();
+                    let mut bounds = param
+                        .type_bound_list()
+                        .map_or_else(Vec::new, |it| it.bounds().collect_vec());
+                    // `{ty_param}: {trait_text}`
+                    if let Some(trait_) = trait_text {
+                        // Defense against the following cases:
+                        // - The trait is undetermined, e.g. `$0`.
+                        // - The trait is a `From`, e.g. `From<T>`.
+                        if !trait_.starts_with('$')
+                            && !matches!(trait_.split_once('<'), Some((left, _right)) if left.trim() == "From")
+                        {
+                            bounds.push(make::type_bound(trait_));
+                        }
+                    };
+                    let param =
+                        make::type_param(param.name().unwrap(), make::type_bound_list(bounds));
                     Some(ast::GenericParam::TypeParam(param))
                 }
                 ast::TypeOrConstParam::Const(param) => {
                     let param = param.clone_for_update();
+                    // remove defaults since they can't be specified in impls
                     param.remove_default();
                     Some(ast::GenericParam::ConstParam(param))
                 }
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index 8c26009add2..11822361f40 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -719,11 +719,22 @@ pub fn param_list(
     ast_from_text(&list)
 }
 
-pub fn type_param(name: ast::Name, ty: Option<ast::TypeBoundList>) -> ast::TypeParam {
-    let bound = match ty {
-        Some(it) => format!(": {it}"),
-        None => String::new(),
-    };
+pub fn type_bound(bound: &str) -> ast::TypeBound {
+    ast_from_text(&format!("fn f<T: {bound}>() {{ }}"))
+}
+
+pub fn type_bound_list(
+    bounds: impl IntoIterator<Item = ast::TypeBound>,
+) -> Option<ast::TypeBoundList> {
+    let bounds = bounds.into_iter().map(|it| it.to_string()).unique().join(" + ");
+    if bounds.is_empty() {
+        return None;
+    }
+    Some(ast_from_text(&format!("fn f<T: {bounds}>() {{ }}")))
+}
+
+pub fn type_param(name: ast::Name, bounds: Option<ast::TypeBoundList>) -> ast::TypeParam {
+    let bound = bounds.map_or_else(String::new, |it| format!(": {it}"));
     ast_from_text(&format!("fn f<{name}{bound}>() {{ }}"))
 }