about summary refs log tree commit diff
diff options
context:
space:
mode:
authorAli Bektas <bektasali@protonmail.com>2023-07-04 17:13:07 +0200
committerAli Bektas <bektasali@protonmail.com>2023-07-04 17:33:45 +0200
commit03423116ad533e1e80d7095595429e7cb0af2af7 (patch)
tree76b5964ac23b531a7d8662be00090a014a5bb479
parent20c877a700264e01a85cbd27709f3644853e50cf (diff)
downloadrust-03423116ad533e1e80d7095595429e7cb0af2af7.tar.gz
rust-03423116ad533e1e80d7095595429e7cb0af2af7.zip
Generate trait from impl v2
-rw-r--r--crates/ide-assists/src/handlers/generate_trait_from_impl.rs108
-rw-r--r--crates/syntax/src/ast/make.rs2
2 files changed, 79 insertions, 31 deletions
diff --git a/crates/ide-assists/src/handlers/generate_trait_from_impl.rs b/crates/ide-assists/src/handlers/generate_trait_from_impl.rs
index ce9eeb82007..d3192ae4091 100644
--- a/crates/ide-assists/src/handlers/generate_trait_from_impl.rs
+++ b/crates/ide-assists/src/handlers/generate_trait_from_impl.rs
@@ -1,8 +1,8 @@
 use crate::assist_context::{AssistContext, Assists};
-use ide_db::{assists::AssistId, SnippetCap};
+use ide_db::assists::AssistId;
 use syntax::{
-    ast::{self, HasGenericParams, HasVisibility},
-    AstNode,
+    ast::{self, edit::IndentLevel, make, HasGenericParams, HasVisibility},
+    ted, AstNode, SyntaxKind,
 };
 
 // NOTES :
@@ -68,6 +68,16 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_
     // Get AST Node
     let impl_ast = ctx.find_node_at_offset::<ast::Impl>()?;
 
+    // Check if cursor is to the left of assoc item list's L_CURLY.
+    // if no L_CURLY then return.
+    let l_curly = impl_ast.assoc_item_list()?.l_curly_token()?;
+
+    let cursor_offset = ctx.offset();
+    let l_curly_offset = l_curly.text_range();
+    if cursor_offset >= l_curly_offset.start() {
+        return None;
+    }
+
     // If impl is not inherent then we don't really need to go any further.
     if impl_ast.for_token().is_some() {
         return None;
@@ -80,9 +90,11 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_
         return None;
     }
 
+    let impl_name = impl_ast.self_ty()?;
+
     acc.add(
         AssistId("generate_trait_from_impl", ide_db::assists::AssistKind::Generate),
-        "Generate trait from impl".to_owned(),
+        "Generate trait from impl",
         impl_ast.syntax().text_range(),
         |builder| {
             let trait_items = assoc_items.clone_for_update();
@@ -93,45 +105,43 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_
                 remove_items_visibility(&item);
             });
 
-            syntax::ted::replace(assoc_items.clone_for_update().syntax(), impl_items.syntax());
+            ted::replace(assoc_items.clone_for_update().syntax(), impl_items.syntax());
 
             impl_items.assoc_items().for_each(|item| {
                 remove_items_visibility(&item);
             });
 
-            let trait_ast = ast::make::trait_(
+            let trait_ast = make::trait_(
                 false,
-                "NewTrait".to_string(),
-                HasGenericParams::generic_param_list(&impl_ast),
-                HasGenericParams::where_clause(&impl_ast),
+                "NewTrait",
+                impl_ast.generic_param_list(),
+                impl_ast.where_clause(),
                 trait_items,
             );
 
             // Change `impl Foo` to `impl NewTrait for Foo`
-            // First find the PATH_TYPE which is what Foo is.
-            let impl_name = impl_ast.self_ty().unwrap();
-            let trait_name = if let Some(genpars) = impl_ast.generic_param_list() {
-                format!("NewTrait{}", genpars.to_generic_args())
+            let arg_list = if let Some(genpars) = impl_ast.generic_param_list() {
+                genpars.to_generic_args().to_string()
             } else {
-                format!("NewTrait")
+                "".to_string()
             };
 
             // // Then replace
             builder.replace(
-                impl_name.clone().syntax().text_range(),
-                format!("{} for {}", trait_name, impl_name.to_string()),
+                impl_name.syntax().text_range(),
+                format!("NewTrait{} for {}", arg_list, impl_name.to_string()),
             );
 
-            builder.replace(
-                impl_ast.assoc_item_list().unwrap().syntax().text_range(),
-                impl_items.to_string(),
-            );
+            builder.replace(assoc_items.syntax().text_range(), impl_items.to_string());
 
             // Insert trait before TraitImpl
-            builder.insert_snippet(
-                SnippetCap::new(true).unwrap(),
+            builder.insert(
                 impl_ast.syntax().text_range().start(),
-                format!("{}\n\n", trait_ast.to_string()),
+                format!(
+                    "{}\n\n{}",
+                    trait_ast.to_string(),
+                    IndentLevel::from_node(impl_ast.syntax())
+                ),
             );
         },
     );
@@ -144,17 +154,17 @@ fn remove_items_visibility(item: &ast::AssocItem) {
     match item {
         ast::AssocItem::Const(c) => {
             if let Some(vis) = c.visibility() {
-                syntax::ted::remove(vis.syntax());
+                ted::remove(vis.syntax());
             }
         }
         ast::AssocItem::Fn(f) => {
             if let Some(vis) = f.visibility() {
-                syntax::ted::remove(vis.syntax());
+                ted::remove(vis.syntax());
             }
         }
         ast::AssocItem::TypeAlias(t) => {
             if let Some(vis) = t.visibility() {
-                syntax::ted::remove(vis.syntax());
+                ted::remove(vis.syntax());
             }
         }
         _ => (),
@@ -168,12 +178,12 @@ fn strip_body(item: &ast::AssocItem) {
                 // In constrast to function bodies, we want to see no ws before a semicolon.
                 // So let's remove them if we see any.
                 if let Some(prev) = body.syntax().prev_sibling_or_token() {
-                    if prev.kind() == syntax::SyntaxKind::WHITESPACE {
-                        syntax::ted::remove(prev);
+                    if prev.kind() == SyntaxKind::WHITESPACE {
+                        ted::remove(prev);
                     }
                 }
 
-                syntax::ted::replace(body.syntax(), ast::make::tokens::semicolon());
+                ted::replace(body.syntax(), ast::make::tokens::semicolon());
             }
         }
         _ => (),
@@ -186,6 +196,21 @@ mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
 
     #[test]
+    fn test_trigger_when_cursor_on_header() {
+        check_assist_not_applicable(
+            generate_trait_from_impl,
+            r#"
+struct Foo(f64);
+
+impl Foo { $0
+    fn add(&mut self, x: f64) {
+        self.0 += x;
+    }
+}"#,
+        );
+    }
+
+    #[test]
     fn test_assoc_item_fn() {
         check_assist(
             generate_trait_from_impl,
@@ -299,7 +324,7 @@ impl<const N: usize> NewTrait<N> for Foo<N> {
     }
 
     #[test]
-    fn test_e0449_avoided() {
+    fn test_trait_items_should_not_have_vis() {
         check_assist(
             generate_trait_from_impl,
             r#"
@@ -334,4 +359,27 @@ impl Emp$0tyImpl{}
 "#,
         )
     }
+
+    #[test]
+    fn test_not_top_level_impl() {
+        check_assist(
+            generate_trait_from_impl,
+            r#"
+mod a {
+    impl S$0 {
+        fn foo() {}
+    }
+}"#,
+            r#"
+mod a {
+    trait NewTrait {
+        fn foo();
+    }
+
+    impl NewTrait for S {
+        fn foo() {}
+    }
+}"#,
+        )
+    }
 }
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index 1675d1af1dd..3facd90a11d 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -865,7 +865,7 @@ pub fn param_list(
 
 pub fn trait_(
     is_unsafe: bool,
-    ident: String,
+    ident: &str,
     gen_params: Option<ast::GenericParamList>,
     where_clause: Option<ast::WhereClause>,
     assoc_items: ast::AssocItemList,