about summary refs log tree commit diff
diff options
context:
space:
mode:
authorrainy-me <github@yue.coffee>2021-11-25 00:21:29 +0900
committerrainy-me <github@yue.coffee>2021-11-25 00:21:29 +0900
commit0bb08ccb8ff4c82d5df068adcb4799f2804eb1e2 (patch)
treeffcbd7b30d2a8696ba5f9a4ad78d32497d0328ae
parent3e4ac8a2c9136052c6394014048095e5c2468859 (diff)
downloadrust-0bb08ccb8ff4c82d5df068adcb4799f2804eb1e2.tar.gz
rust-0bb08ccb8ff4c82d5df068adcb4799f2804eb1e2.zip
fix: derive path handling
-rw-r--r--crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs115
-rw-r--r--crates/ide_completion/src/completions/attribute.rs30
-rw-r--r--crates/ide_db/src/helpers.rs26
3 files changed, 118 insertions, 53 deletions
diff --git a/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
index 5ba045d3c8f..3e33c62144e 100644
--- a/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
+++ b/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
@@ -1,10 +1,13 @@
 use hir::ModuleDef;
-use ide_db::helpers::{import_assets::NameToImport, mod_path_to_ast};
+use ide_db::helpers::{
+    get_path_at_cursor_in_tt, import_assets::NameToImport, mod_path_to_ast,
+    parse_tt_as_comma_sep_paths,
+};
 use ide_db::items_locator;
 use itertools::Itertools;
 use syntax::{
-    ast::{self, make, AstNode, HasName},
-    SyntaxKind::{IDENT, WHITESPACE},
+    ast::{self, AstNode, AstToken, HasName},
+    SyntaxKind::WHITESPACE,
 };
 
 use crate::{
@@ -52,9 +55,8 @@ pub(crate) fn replace_derive_with_manual_impl(
         return None;
     }
 
-    let trait_token = args.syntax().token_at_offset(ctx.offset()).find(|t| t.kind() == IDENT)?;
-    let trait_name = trait_token.text();
-
+    let ident = args.syntax().token_at_offset(ctx.offset()).find_map(ast::Ident::cast)?;
+    let trait_path = get_path_at_cursor_in_tt(&ident)?;
     let adt = attr.syntax().parent().and_then(ast::Adt::cast)?;
 
     let current_module = ctx.sema.scope(adt.syntax()).module()?;
@@ -63,7 +65,7 @@ pub(crate) fn replace_derive_with_manual_impl(
     let found_traits = items_locator::items_with_name(
         &ctx.sema,
         current_crate,
-        NameToImport::Exact(trait_name.to_string()),
+        NameToImport::Exact(trait_path.segments().last()?.to_string()),
         items_locator::AssocItemSearch::Exclude,
         Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT.inner()),
     )
@@ -80,12 +82,23 @@ pub(crate) fn replace_derive_with_manual_impl(
     });
 
     let mut no_traits_found = true;
-    for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
-        add_assist(acc, ctx, &attr, &args, &trait_path, Some(trait_), &adt)?;
+    let current_derives = parse_tt_as_comma_sep_paths(args.clone())?;
+    let current_derives = current_derives.as_slice();
+    for (replace_trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
+        add_assist(
+            acc,
+            ctx,
+            &attr,
+            &current_derives,
+            &args,
+            &trait_path,
+            &replace_trait_path,
+            Some(trait_),
+            &adt,
+        )?;
     }
     if no_traits_found {
-        let trait_path = make::ext::ident_path(trait_name);
-        add_assist(acc, ctx, &attr, &args, &trait_path, None, &adt)?;
+        add_assist(acc, ctx, &attr, &current_derives, &args, &trait_path, &trait_path, None, &adt)?;
     }
     Some(())
 }
@@ -94,15 +107,16 @@ fn add_assist(
     acc: &mut Assists,
     ctx: &AssistContext,
     attr: &ast::Attr,
-    input: &ast::TokenTree,
-    trait_path: &ast::Path,
+    old_derives: &[ast::Path],
+    old_tree: &ast::TokenTree,
+    old_trait_path: &ast::Path,
+    replace_trait_path: &ast::Path,
     trait_: Option<hir::Trait>,
     adt: &ast::Adt,
 ) -> Option<()> {
     let target = attr.syntax().text_range();
     let annotated_name = adt.name()?;
-    let label = format!("Convert to manual `impl {} for {}`", trait_path, annotated_name);
-    let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
+    let label = format!("Convert to manual `impl {} for {}`", replace_trait_path, annotated_name);
 
     acc.add(
         AssistId("replace_derive_with_manual_impl", AssistKind::Refactor),
@@ -111,9 +125,9 @@ fn add_assist(
         |builder| {
             let insert_pos = adt.syntax().text_range().end();
             let impl_def_with_items =
-                impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, trait_path);
-            update_attribute(builder, input, &trait_name, attr);
-            let trait_path = format!("{}", trait_path);
+                impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, replace_trait_path);
+            update_attribute(builder, old_derives, old_tree, old_trait_path, attr);
+            let trait_path = format!("{}", replace_trait_path);
             match (ctx.config.snippet_cap, impl_def_with_items) {
                 (None, _) => {
                     builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, ""))
@@ -192,23 +206,20 @@ fn impl_def_from_trait(
 
 fn update_attribute(
     builder: &mut AssistBuilder,
-    input: &ast::TokenTree,
-    trait_name: &ast::NameRef,
+    old_derives: &[ast::Path],
+    old_tree: &ast::TokenTree,
+    old_trait_path: &ast::Path,
     attr: &ast::Attr,
 ) {
-    let trait_name = trait_name.text();
-    let new_attr_input = input
-        .syntax()
-        .descendants_with_tokens()
-        .filter(|t| t.kind() == IDENT)
-        .filter_map(|t| t.into_token().map(|t| t.text().to_string()))
-        .filter(|t| t != &trait_name)
+    let new_derives = old_derives
+        .iter()
+        .filter(|t| t.to_string() != old_trait_path.to_string())
         .collect::<Vec<_>>();
-    let has_more_derives = !new_attr_input.is_empty();
+    let has_more_derives = !new_derives.is_empty();
 
     if has_more_derives {
-        let new_attr_input = format!("({})", new_attr_input.iter().format(", "));
-        builder.replace(input.syntax().text_range(), new_attr_input);
+        let new_derives = format!("({})", new_derives.iter().format(", "));
+        builder.replace(old_tree.syntax().text_range(), new_derives);
     } else {
         let attr_range = attr.syntax().text_range();
         builder.delete(attr_range);
@@ -1165,4 +1176,48 @@ struct S;
             "#,
         );
     }
+
+    #[test]
+    fn add_custom_impl_keep_path() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: clone
+#[derive(std::fmt::Debug, Clo$0ne)]
+pub struct Foo;
+"#,
+            r#"
+#[derive(std::fmt::Debug)]
+pub struct Foo;
+
+impl Clone for Foo {
+    $0fn clone(&self) -> Self {
+        Self {  }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_replace_path() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: fmt
+#[derive(core::fmt::Deb$0ug, Clone)]
+pub struct Foo;
+"#,
+            r#"
+#[derive(Clone)]
+pub struct Foo;
+
+impl core::fmt::Debug for Foo {
+    $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.debug_struct("Foo").finish()
+    }
+}
+"#,
+        )
+    }
 }
diff --git a/crates/ide_completion/src/completions/attribute.rs b/crates/ide_completion/src/completions/attribute.rs
index e86f38aaa41..8740db326b0 100644
--- a/crates/ide_completion/src/completions/attribute.rs
+++ b/crates/ide_completion/src/completions/attribute.rs
@@ -4,7 +4,10 @@
 //! for built-in attributes.
 
 use hir::HasAttrs;
-use ide_db::helpers::generated_lints::{CLIPPY_LINTS, DEFAULT_LINTS, FEATURES, RUSTDOC_LINTS};
+use ide_db::helpers::{
+    generated_lints::{CLIPPY_LINTS, DEFAULT_LINTS, FEATURES, RUSTDOC_LINTS},
+    parse_tt_as_comma_sep_paths,
+};
 use itertools::Itertools;
 use once_cell::sync::Lazy;
 use rustc_hash::FxHashMap;
@@ -30,12 +33,14 @@ pub(crate) fn complete_attribute(acc: &mut Completions, ctx: &CompletionContext)
     match (name_ref, attribute.token_tree()) {
         (Some(path), Some(token_tree)) => match path.text().as_str() {
             "repr" => repr::complete_repr(acc, ctx, token_tree),
-            "derive" => derive::complete_derive(acc, ctx, &parse_comma_sep_paths(token_tree)?),
+            "derive" => {
+                derive::complete_derive(acc, ctx, &parse_tt_as_comma_sep_paths(token_tree)?)
+            }
             "feature" => {
-                lint::complete_lint(acc, ctx, &parse_comma_sep_paths(token_tree)?, FEATURES)
+                lint::complete_lint(acc, ctx, &parse_tt_as_comma_sep_paths(token_tree)?, FEATURES)
             }
             "allow" | "warn" | "deny" | "forbid" => {
-                let existing_lints = parse_comma_sep_paths(token_tree)?;
+                let existing_lints = parse_tt_as_comma_sep_paths(token_tree)?;
                 lint::complete_lint(acc, ctx, &existing_lints, DEFAULT_LINTS);
                 lint::complete_lint(acc, ctx, &existing_lints, CLIPPY_LINTS);
                 lint::complete_lint(acc, ctx, &existing_lints, RUSTDOC_LINTS);
@@ -307,23 +312,6 @@ const ATTRIBUTES: &[AttrCompletion] = &[
     .prefer_inner(),
 ];
 
-fn parse_comma_sep_paths(input: ast::TokenTree) -> Option<Vec<ast::Path>> {
-    let r_paren = input.r_paren_token()?;
-    let tokens = input
-        .syntax()
-        .children_with_tokens()
-        .skip(1)
-        .take_while(|it| it.as_token() != Some(&r_paren));
-    let input_expressions = tokens.into_iter().group_by(|tok| tok.kind() == T![,]);
-    Some(
-        input_expressions
-            .into_iter()
-            .filter_map(|(is_sep, group)| (!is_sep).then(|| group))
-            .filter_map(|mut tokens| ast::Path::parse(&tokens.join("")).ok())
-            .collect::<Vec<ast::Path>>(),
-    )
-}
-
 fn parse_comma_sep_expr(input: ast::TokenTree) -> Option<Vec<ast::Expr>> {
     let r_paren = input.r_paren_token()?;
     let tokens = input
diff --git a/crates/ide_db/src/helpers.rs b/crates/ide_db/src/helpers.rs
index 97aff0970a8..1b9cb7ff51c 100644
--- a/crates/ide_db/src/helpers.rs
+++ b/crates/ide_db/src/helpers.rs
@@ -39,10 +39,9 @@ pub fn get_path_in_derive_attr(
     attr: &ast::Attr,
     cursor: &Ident,
 ) -> Option<ast::Path> {
-    let cursor = cursor.syntax();
     let path = attr.path()?;
     let tt = attr.token_tree()?;
-    if !tt.syntax().text_range().contains_range(cursor.text_range()) {
+    if !tt.syntax().text_range().contains_range(cursor.syntax().text_range()) {
         return None;
     }
     let scope = sema.scope(attr.syntax());
@@ -51,7 +50,12 @@ pub fn get_path_in_derive_attr(
     if PathResolution::Macro(derive) != resolved_attr {
         return None;
     }
+    get_path_at_cursor_in_tt(cursor)
+}
 
+/// Parses the path the identifier is part of inside a token tree.
+pub fn get_path_at_cursor_in_tt(cursor: &Ident) -> Option<ast::Path> {
+    let cursor = cursor.syntax();
     let first = cursor
         .siblings_with_tokens(Direction::Prev)
         .filter_map(SyntaxElement::into_token)
@@ -300,3 +304,21 @@ pub fn lint_eq_or_in_group(lint: &str, lint_is: &str) -> bool {
         false
     }
 }
+
+/// Parses the input token tree as comma separated paths.
+pub fn parse_tt_as_comma_sep_paths(input: ast::TokenTree) -> Option<Vec<ast::Path>> {
+    let r_paren = input.r_paren_token()?;
+    let tokens = input
+        .syntax()
+        .children_with_tokens()
+        .skip(1)
+        .take_while(|it| it.as_token() != Some(&r_paren));
+    let input_expressions = tokens.into_iter().group_by(|tok| tok.kind() == T![,]);
+    Some(
+        input_expressions
+            .into_iter()
+            .filter_map(|(is_sep, group)| (!is_sep).then(|| group))
+            .filter_map(|mut tokens| ast::Path::parse(&tokens.join("")).ok())
+            .collect::<Vec<ast::Path>>(),
+    )
+}