about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRyo Yoshida <low.ryoshida@gmail.com>2023-01-15 02:36:26 +0900
committerRyo Yoshida <low.ryoshida@gmail.com>2023-01-16 20:55:56 +0900
commitfc56cacfc1c0e94e5fee6d876fdcdcf3e01f6b66 (patch)
tree99a498b74b28d6d129790f7a210c304e311a8c25
parent8afaaa54b0fc920780011d74723b44e2e8a760a4 (diff)
downloadrust-fc56cacfc1c0e94e5fee6d876fdcdcf3e01f6b66.tar.gz
rust-fc56cacfc1c0e94e5fee6d876fdcdcf3e01f6b66.zip
Test `TraitRef` equality before generating missing impl method body
-rw-r--r--crates/ide-assists/src/handlers/add_missing_impl_members.rs60
-rw-r--r--crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs2
-rw-r--r--crates/ide-assists/src/utils/gen_trait_fn_body.rs36
3 files changed, 80 insertions, 18 deletions
diff --git a/crates/ide-assists/src/handlers/add_missing_impl_members.rs b/crates/ide-assists/src/handlers/add_missing_impl_members.rs
index 161bcc5c8da..627a9852fc8 100644
--- a/crates/ide-assists/src/handlers/add_missing_impl_members.rs
+++ b/crates/ide-assists/src/handlers/add_missing_impl_members.rs
@@ -1,7 +1,5 @@
 use hir::HasSource;
-use ide_db::{
-    syntax_helpers::insert_whitespace_into_node::insert_ws_into, traits::resolve_target_trait,
-};
+use ide_db::syntax_helpers::insert_whitespace_into_node::insert_ws_into;
 use syntax::ast::{self, make, AstNode};
 
 use crate::{
@@ -107,6 +105,7 @@ fn add_missing_impl_members_inner(
 ) -> Option<()> {
     let _p = profile::span("add_missing_impl_members_inner");
     let impl_def = ctx.find_node_at_offset::<ast::Impl>()?;
+    let impl_ = ctx.sema.to_def(&impl_def)?;
 
     if ctx.token_at_offset().all(|t| {
         t.parent_ancestors()
@@ -116,7 +115,8 @@ fn add_missing_impl_members_inner(
     }
 
     let target_scope = ctx.sema.scope(impl_def.syntax())?;
-    let trait_ = resolve_target_trait(&ctx.sema, &impl_def)?;
+    let trait_ref = impl_.trait_ref(ctx.db())?;
+    let trait_ = trait_ref.trait_();
 
     let missing_items = filter_assoc_items(
         &ctx.sema,
@@ -155,7 +155,7 @@ fn add_missing_impl_members_inner(
                 let placeholder;
                 if let DefaultMethods::No = mode {
                     if let ast::AssocItem::Fn(func) = &first_new_item {
-                        if try_gen_trait_body(ctx, func, &trait_, &impl_def).is_none() {
+                        if try_gen_trait_body(ctx, func, trait_ref, &impl_def).is_none() {
                             if let Some(m) =
                                 func.syntax().descendants().find_map(ast::MacroCall::cast)
                             {
@@ -180,13 +180,13 @@ fn add_missing_impl_members_inner(
 fn try_gen_trait_body(
     ctx: &AssistContext<'_>,
     func: &ast::Fn,
-    trait_: &hir::Trait,
+    trait_ref: hir::TraitRef,
     impl_def: &ast::Impl,
 ) -> Option<()> {
-    let trait_path = make::ext::ident_path(&trait_.name(ctx.db()).to_string());
+    let trait_path = make::ext::ident_path(&trait_ref.trait_().name(ctx.db()).to_string());
     let hir_ty = ctx.sema.resolve_type(&impl_def.self_ty()?)?;
     let adt = hir_ty.as_adt()?.source(ctx.db())?;
-    gen_trait_fn_body(func, &trait_path, &adt.value)
+    gen_trait_fn_body(func, &trait_path, &adt.value, Some(trait_ref))
 }
 
 #[cfg(test)]
@@ -1353,6 +1353,50 @@ impl PartialEq for SomeStruct {
     }
 
     #[test]
+    fn test_partial_eq_body_when_types_semantically_match() {
+        check_assist(
+            add_missing_impl_members,
+            r#"
+//- minicore: eq
+struct S<T, U>(T, U);
+type Alias<T> = S<T, T>;
+impl<T> PartialEq<Alias<T>> for S<T, T> {$0}
+"#,
+            r#"
+struct S<T, U>(T, U);
+type Alias<T> = S<T, T>;
+impl<T> PartialEq<Alias<T>> for S<T, T> {
+    $0fn eq(&self, other: &Alias<T>) -> bool {
+        self.0 == other.0 && self.1 == other.1
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn test_partial_eq_body_when_types_dont_match() {
+        check_assist(
+            add_missing_impl_members,
+            r#"
+//- minicore: eq
+struct S<T, U>(T, U);
+type Alias<T> = S<T, T>;
+impl<T> PartialEq<Alias<T>> for S<T, i32> {$0}
+"#,
+            r#"
+struct S<T, U>(T, U);
+type Alias<T> = S<T, T>;
+impl<T> PartialEq<Alias<T>> for S<T, i32> {
+    fn eq(&self, other: &Alias<T>) -> bool {
+        ${0:todo!()}
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
     fn test_ignore_function_body() {
         check_assist_not_applicable(
             add_missing_default_members,
diff --git a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
index a6693d7d790..4cfae0c7212 100644
--- a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
+++ b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
@@ -214,7 +214,7 @@ fn impl_def_from_trait(
 
     // Generate a default `impl` function body for the derived trait.
     if let ast::AssocItem::Fn(ref func) = first_assoc_item {
-        let _ = gen_trait_fn_body(func, trait_path, adt);
+        let _ = gen_trait_fn_body(func, trait_path, adt, None);
     };
 
     Some((impl_def, first_assoc_item))
diff --git a/crates/ide-assists/src/utils/gen_trait_fn_body.rs b/crates/ide-assists/src/utils/gen_trait_fn_body.rs
index d4abb51259e..808b2340595 100644
--- a/crates/ide-assists/src/utils/gen_trait_fn_body.rs
+++ b/crates/ide-assists/src/utils/gen_trait_fn_body.rs
@@ -1,5 +1,6 @@
 //! This module contains functions to generate default trait impl function bodies where possible.
 
+use hir::TraitRef;
 use syntax::{
     ast::{self, edit::AstNodeEdit, make, AstNode, BinaryOp, CmpOp, HasName, LogicOp},
     ted,
@@ -7,6 +8,8 @@ use syntax::{
 
 /// Generate custom trait bodies without default implementation where possible.
 ///
+/// If `func` is defined within an existing impl block, pass [`TraitRef`]. Otherwise pass `None`.
+///
 /// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
 /// `None` means that generating a custom trait body failed, and the body will remain
 /// as `todo!` instead.
@@ -14,14 +17,15 @@ pub(crate) fn gen_trait_fn_body(
     func: &ast::Fn,
     trait_path: &ast::Path,
     adt: &ast::Adt,
+    trait_ref: Option<TraitRef>,
 ) -> Option<()> {
     match trait_path.segment()?.name_ref()?.text().as_str() {
         "Clone" => gen_clone_impl(adt, func),
         "Debug" => gen_debug_impl(adt, func),
         "Default" => gen_default_impl(adt, func),
         "Hash" => gen_hash_impl(adt, func),
-        "PartialEq" => gen_partial_eq(adt, func),
-        "PartialOrd" => gen_partial_ord(adt, func),
+        "PartialEq" => gen_partial_eq(adt, func, trait_ref),
+        "PartialOrd" => gen_partial_ord(adt, func, trait_ref),
         _ => None,
     }
 }
@@ -395,7 +399,7 @@ fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
 }
 
 /// Generate a `PartialEq` impl based on the fields and members of the target type.
-fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
+fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn, trait_ref: Option<TraitRef>) -> Option<()> {
     stdx::always!(func.name().map_or(false, |name| name.text() == "eq"));
     fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
         match expr {
@@ -423,8 +427,15 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
         ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name)))
     }
 
-    // FIXME: return `None` if the trait carries a generic type; we can only
-    // generate this code `Self` for the time being.
+    // Check that self type and rhs type match. We don't know how to implement the method
+    // automatically otherwise.
+    if let Some(trait_ref) = trait_ref {
+        let self_ty = trait_ref.self_ty();
+        let rhs_ty = trait_ref.get_type_argument(1)?;
+        if self_ty != rhs_ty {
+            return None;
+        }
+    }
 
     let body = match adt {
         // `PartialEq` cannot be derived for unions, so no default impl can be provided.
@@ -568,7 +579,7 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
                 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
             }
 
-            // No fields in the body means there's nothing to hash.
+            // No fields in the body means there's nothing to compare.
             None => {
                 let expr = make::expr_literal("true").into();
                 make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
@@ -580,7 +591,7 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
     Some(())
 }
 
-fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
+fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn, trait_ref: Option<TraitRef>) -> Option<()> {
     stdx::always!(func.name().map_or(false, |name| name.text() == "partial_cmp"));
     fn gen_partial_eq_match(match_target: ast::Expr) -> Option<ast::Stmt> {
         let mut arms = vec![];
@@ -605,8 +616,15 @@ fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
         make::expr_method_call(lhs, method, make::arg_list(Some(rhs)))
     }
 
-    // FIXME: return `None` if the trait carries a generic type; we can only
-    // generate this code `Self` for the time being.
+    // Check that self type and rhs type match. We don't know how to implement the method
+    // automatically otherwise.
+    if let Some(trait_ref) = trait_ref {
+        let self_ty = trait_ref.self_ty();
+        let rhs_ty = trait_ref.get_type_argument(1)?;
+        if self_ty != rhs_ty {
+            return None;
+        }
+    }
 
     let body = match adt {
         // `PartialOrd` cannot be derived for unions, so no default impl can be provided.