about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-01-09 13:02:09 +0000
committerbors <bors@rust-lang.org>2023-01-09 13:02:09 +0000
commitae659125a509967f09665b96d06f6ce6bf1ddd1b (patch)
tree5f4a84703df302cdc8a4d0b6a5547f963d6d0c4c
parentfe8ee9c43a52042e2150de8024524ee7a2296692 (diff)
parentcfa914958c5eead0d8d84e0c7b8298ef7ac5530b (diff)
downloadrust-ae659125a509967f09665b96d06f6ce6bf1ddd1b.tar.gz
rust-ae659125a509967f09665b96d06f6ce6bf1ddd1b.zip
Auto merge of #13763 - rami3l:fix/gen-partial-eq-generic, r=Veykril
fix: add generic `TypeBoundList` in generated derivable impl

Potentially fixes #13727.

Continuing with the work in #13732, this fix tries to add correct type bounds in the generated `impl` block:

```diff
  enum Either<T, U> {
      Left(T),
      Right(U),
  }

- impl<T, U> PartialEq for Either<T, U> {
+ impl<T: PartialEq, U: PartialEq> PartialEq for Either<T, U> {
      fn eq(&self, other: &Self) -> bool {
          match (self, other) {
              (Self::Left(l0), Self::Left(r0)) => l0 == r0,
              (Self::Right(l0), Self::Right(r0)) => l0 == r0,
              _ => false,
          }
      }
  }
```
-rw-r--r--crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs6
-rw-r--r--crates/ide-assists/src/handlers/generate_impl.rs6
-rw-r--r--crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs64
-rw-r--r--crates/ide-assists/src/utils.rs54
-rw-r--r--crates/syntax/src/ast/make.rs23
5 files changed, 130 insertions, 23 deletions
diff --git a/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs b/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs
index 7c81d2c6a6c..742f1f78c2e 100644
--- a/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs
+++ b/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs
@@ -1,7 +1,9 @@
 use ide_db::{famous_defs::FamousDefs, RootDatabase};
 use syntax::ast::{self, AstNode, HasName};
 
-use crate::{utils::generate_trait_impl_text, AssistContext, AssistId, AssistKind, Assists};
+use crate::{
+    utils::generate_trait_impl_text_intransitive, AssistContext, AssistId, AssistKind, Assists,
+};
 
 // Assist: generate_from_impl_for_enum
 //
@@ -70,7 +72,7 @@ pub(crate) fn generate_from_impl_for_enum(
     }}"#
                 )
             };
-            let from_impl = generate_trait_impl_text(&enum_, &from_trait, &impl_code);
+            let from_impl = generate_trait_impl_text_intransitive(&enum_, &from_trait, &impl_code);
             edit.insert(start_offset, from_impl);
         },
     )
diff --git a/crates/ide-assists/src/handlers/generate_impl.rs b/crates/ide-assists/src/handlers/generate_impl.rs
index 690c97e26d8..9ad14a819d9 100644
--- a/crates/ide-assists/src/handlers/generate_impl.rs
+++ b/crates/ide-assists/src/handlers/generate_impl.rs
@@ -1,7 +1,7 @@
 use syntax::ast::{self, AstNode, HasName};
 
 use crate::{
-    utils::{generate_impl_text, generate_trait_impl_text},
+    utils::{generate_impl_text, generate_trait_impl_text_intransitive},
     AssistContext, AssistId, AssistKind, Assists,
 };
 
@@ -89,11 +89,11 @@ pub(crate) fn generate_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_>) ->
             let start_offset = nominal.syntax().text_range().end();
             match ctx.config.snippet_cap {
                 Some(cap) => {
-                    let snippet = generate_trait_impl_text(&nominal, "$0", "");
+                    let snippet = generate_trait_impl_text_intransitive(&nominal, "$0", "");
                     edit.insert_snippet(cap, start_offset, snippet);
                 }
                 None => {
-                    let text = generate_trait_impl_text(&nominal, "", "");
+                    let text = generate_trait_impl_text_intransitive(&nominal, "", "");
                     edit.insert(start_offset, text);
                 }
             }
diff --git a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
index 2854701c088..a6693d7d790 100644
--- a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
+++ b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
@@ -995,6 +995,68 @@ impl PartialEq for Foo {
     }
 
     #[test]
+    fn add_custom_impl_partial_eq_tuple_enum_generic() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: eq, derive
+#[derive(Partial$0Eq)]
+enum Either<T, U> {
+    Left(T),
+    Right(U),
+}
+"#,
+            r#"
+enum Either<T, U> {
+    Left(T),
+    Right(U),
+}
+
+impl<T: PartialEq, U: PartialEq> PartialEq for Either<T, U> {
+    $0fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::Left(l0), Self::Left(r0)) => l0 == r0,
+            (Self::Right(l0), Self::Right(r0)) => l0 == r0,
+            _ => false,
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_eq_tuple_enum_generic_existing_bounds() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: eq, derive
+#[derive(Partial$0Eq)]
+enum Either<T: PartialEq + Error, U: Clone> {
+    Left(T),
+    Right(U),
+}
+"#,
+            r#"
+enum Either<T: PartialEq + Error, U: Clone> {
+    Left(T),
+    Right(U),
+}
+
+impl<T: PartialEq + Error, U: Clone + PartialEq> PartialEq for Either<T, U> {
+    $0fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::Left(l0), Self::Left(r0)) => l0 == r0,
+            (Self::Right(l0), Self::Right(r0)) => l0 == r0,
+            _ => false,
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
     fn add_custom_impl_partial_eq_record_enum() {
         check_assist(
             replace_derive_with_manual_impl,
@@ -1170,7 +1232,7 @@ struct Foo<T, U> {
     bar: U,
 }
 
-impl<T, U> Default for Foo<T, U> {
+impl<T: Default, U: Default> Default for Foo<T, U> {
     $0fn default() -> Self {
         Self { foo: Default::default(), bar: Default::default() }
     }
diff --git a/crates/ide-assists/src/utils.rs b/crates/ide-assists/src/utils.rs
index 57c37e5b838..7add6606492 100644
--- a/crates/ide-assists/src/utils.rs
+++ b/crates/ide-assists/src/utils.rs
@@ -434,35 +434,67 @@ pub(crate) fn find_impl_block_end(impl_def: ast::Impl, buf: &mut String) -> Opti
     Some(end)
 }
 
-// Generates the surrounding `impl Type { <code> }` including type and lifetime
-// parameters
+/// Generates the surrounding `impl Type { <code> }` including type and lifetime
+/// parameters.
 pub(crate) fn generate_impl_text(adt: &ast::Adt, code: &str) -> String {
-    generate_impl_text_inner(adt, None, code)
+    generate_impl_text_inner(adt, None, true, code)
 }
 
-// Generates the surrounding `impl <trait> for Type { <code> }` including type
-// and lifetime parameters
+/// Generates the surrounding `impl <trait> for Type { <code> }` including type
+/// and lifetime parameters, with `<trait>` appended to `impl`'s generic parameters' bounds.
+///
+/// This is useful for traits like `PartialEq`, since `impl<T> PartialEq for U<T>` often requires `T: PartialEq`.
 pub(crate) fn generate_trait_impl_text(adt: &ast::Adt, trait_text: &str, code: &str) -> String {
-    generate_impl_text_inner(adt, Some(trait_text), code)
+    generate_impl_text_inner(adt, Some(trait_text), true, code)
+}
+
+/// Generates the surrounding `impl <trait> for Type { <code> }` including type
+/// and lifetime parameters, with `impl`'s generic parameters' bounds kept as-is.
+///
+/// This is useful for traits like `From<T>`, since `impl<T> From<T> for U<T>` doesn't require `T: From<T>`.
+pub(crate) fn generate_trait_impl_text_intransitive(
+    adt: &ast::Adt,
+    trait_text: &str,
+    code: &str,
+) -> String {
+    generate_impl_text_inner(adt, Some(trait_text), false, code)
 }
 
-fn generate_impl_text_inner(adt: &ast::Adt, trait_text: Option<&str>, code: &str) -> String {
+fn generate_impl_text_inner(
+    adt: &ast::Adt,
+    trait_text: Option<&str>,
+    trait_is_transitive: bool,
+    code: &str,
+) -> String {
     // Ensure lifetime params are before type & const params
     let generic_params = adt.generic_param_list().map(|generic_params| {
         let lifetime_params =
             generic_params.lifetime_params().map(ast::GenericParam::LifetimeParam);
-        let ty_or_const_params = generic_params.type_or_const_params().filter_map(|param| {
-            // remove defaults since they can't be specified in impls
+        let ty_or_const_params = generic_params.type_or_const_params().map(|param| {
             match param {
                 ast::TypeOrConstParam::Type(param) => {
                     let param = param.clone_for_update();
+                    // remove defaults since they can't be specified in impls
                     param.remove_default();
-                    Some(ast::GenericParam::TypeParam(param))
+                    let mut bounds =
+                        param.type_bound_list().map_or_else(Vec::new, |it| it.bounds().collect());
+                    if let Some(trait_) = trait_text {
+                        // Add the current trait to `bounds` if the trait is transitive,
+                        // meaning `impl<T> Trait for U<T>` requires `T: Trait`.
+                        if trait_is_transitive {
+                            bounds.push(make::type_bound(trait_));
+                        }
+                    };
+                    // `{ty_param}: {bounds}`
+                    let param =
+                        make::type_param(param.name().unwrap(), make::type_bound_list(bounds));
+                    ast::GenericParam::TypeParam(param)
                 }
                 ast::TypeOrConstParam::Const(param) => {
                     let param = param.clone_for_update();
+                    // remove defaults since they can't be specified in impls
                     param.remove_default();
-                    Some(ast::GenericParam::ConstParam(param))
+                    ast::GenericParam::ConstParam(param)
                 }
             }
         });
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index 8c26009add2..686bb40ecd3 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -719,12 +719,23 @@ pub fn param_list(
     ast_from_text(&list)
 }
 
-pub fn type_param(name: ast::Name, ty: Option<ast::TypeBoundList>) -> ast::TypeParam {
-    let bound = match ty {
-        Some(it) => format!(": {it}"),
-        None => String::new(),
-    };
-    ast_from_text(&format!("fn f<{name}{bound}>() {{ }}"))
+pub fn type_bound(bound: &str) -> ast::TypeBound {
+    ast_from_text(&format!("fn f<T: {bound}>() {{ }}"))
+}
+
+pub fn type_bound_list(
+    bounds: impl IntoIterator<Item = ast::TypeBound>,
+) -> Option<ast::TypeBoundList> {
+    let bounds = bounds.into_iter().map(|it| it.to_string()).unique().join(" + ");
+    if bounds.is_empty() {
+        return None;
+    }
+    Some(ast_from_text(&format!("fn f<T: {bounds}>() {{ }}")))
+}
+
+pub fn type_param(name: ast::Name, bounds: Option<ast::TypeBoundList>) -> ast::TypeParam {
+    let bounds = bounds.map_or_else(String::new, |it| format!(": {it}"));
+    ast_from_text(&format!("fn f<{name}{bounds}>() {{ }}"))
 }
 
 pub fn lifetime_param(lifetime: ast::Lifetime) -> ast::LifetimeParam {