about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs272
1 files changed, 268 insertions, 4 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index 52a55ead3af..40d0327ef71 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -109,8 +109,6 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
             let params =
                 body.extracted_function_params(ctx, &container_info, locals_used.iter().copied());
 
-            let extracted_from_trait_impl = body.extracted_from_trait_impl();
-
             let name = make_function_name(&semantics_scope);
 
             let fun = Function {
@@ -129,8 +127,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
 
             builder.replace(target_range, make_call(ctx, &fun, old_indent));
 
+            let has_impl_wrapper = insert_after
+                .ancestors()
+                .find(|a| a.kind() == SyntaxKind::IMPL && a != &insert_after)
+                .is_some();
+
             let fn_def = match fun.self_param_adt(ctx) {
-                Some(adt) if extracted_from_trait_impl => {
+                Some(adt) if anchor == Anchor::Method && !has_impl_wrapper => {
                     let fn_def = format_function(ctx, module, &fun, old_indent, new_indent + 1);
                     generate_impl_text(&adt, &fn_def).replace("{\n\n", "{")
                 }
@@ -271,7 +274,7 @@ enum FunType {
 }
 
 /// Where to put extracted function definition
-#[derive(Debug)]
+#[derive(Debug, Eq, PartialEq, Clone, Copy)]
 enum Anchor {
     /// Extract free function and put right after current top-level function
     Freestanding,
@@ -1244,6 +1247,15 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNod
     while let Some(next_ancestor) = ancestors.next() {
         match next_ancestor.kind() {
             SyntaxKind::SOURCE_FILE => break,
+            SyntaxKind::IMPL => {
+                if body.extracted_from_trait_impl() && matches!(anchor, Anchor::Method) {
+                    let impl_node = find_non_trait_impl(&next_ancestor);
+                    let target_node = impl_node.as_ref().and_then(last_impl_member);
+                    if target_node.is_some() {
+                        return target_node;
+                    }
+                }
+            }
             SyntaxKind::ITEM_LIST if !matches!(anchor, Anchor::Freestanding) => continue,
             SyntaxKind::ITEM_LIST => {
                 if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::MODULE) {
@@ -1264,6 +1276,28 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNod
     last_ancestor
 }
 
+fn find_non_trait_impl(trait_impl: &SyntaxNode) -> Option<SyntaxNode> {
+    let impl_type = Some(impl_type_name(trait_impl)?);
+
+    let mut sibblings = trait_impl.parent()?.children();
+    sibblings.find(|s| impl_type_name(s) == impl_type && !is_trait_impl(s))
+}
+
+fn last_impl_member(impl_node: &SyntaxNode) -> Option<SyntaxNode> {
+    impl_node.children().find(|c| c.kind() == SyntaxKind::ASSOC_ITEM_LIST)?.last_child()
+}
+
+fn is_trait_impl(node: &SyntaxNode) -> bool {
+    match ast::Impl::cast(node.clone()) {
+        Some(c) => c.trait_().is_some(),
+        None => false,
+    }
+}
+
+fn impl_type_name(impl_node: &SyntaxNode) -> Option<String> {
+    Some(ast::Impl::cast(impl_node.clone())?.self_ty()?.to_string())
+}
+
 fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> String {
     let ret_ty = fun.return_type(ctx);
 
@@ -5059,6 +5093,236 @@ impl Struct {
     }
 
     #[test]
+    fn extract_method_from_trait_with_existing_non_empty_impl_block() {
+        check_assist(
+            extract_function,
+            r#"
+struct Struct(i32);
+trait Trait {
+    fn bar(&self) -> i32;
+}
+
+impl Struct {
+    fn foo() {}
+}
+
+impl Trait for Struct {
+    fn bar(&self) -> i32 {
+        $0self.0 + 2$0
+    }
+}
+"#,
+            r#"
+struct Struct(i32);
+trait Trait {
+    fn bar(&self) -> i32;
+}
+
+impl Struct {
+    fn foo() {}
+
+    fn $0fun_name(&self) -> i32 {
+        self.0 + 2
+    }
+}
+
+impl Trait for Struct {
+    fn bar(&self) -> i32 {
+        self.fun_name()
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn extract_function_from_trait_with_existing_non_empty_impl_block() {
+        check_assist(
+            extract_function,
+            r#"
+struct Struct(i32);
+trait Trait {
+    fn bar(&self) -> i32;
+}
+
+impl Struct {
+    fn foo() {}
+}
+
+impl Trait for Struct {
+    fn bar(&self) -> i32 {
+        let three_squared = $03 * 3$0;
+        self.0 + three_squared
+    }
+}
+"#,
+            r#"
+struct Struct(i32);
+trait Trait {
+    fn bar(&self) -> i32;
+}
+
+impl Struct {
+    fn foo() {}
+}
+
+impl Trait for Struct {
+    fn bar(&self) -> i32 {
+        let three_squared = fun_name();
+        self.0 + three_squared
+    }
+}
+
+fn $0fun_name() -> i32 {
+    3 * 3
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn extract_method_from_trait_with_multiple_existing_impl_blocks() {
+        check_assist(
+            extract_function,
+            r#"
+struct Struct(i32);
+struct StructBefore(i32);
+struct StructAfter(i32);
+trait Trait {
+    fn bar(&self) -> i32;
+}
+
+impl StructBefore {
+    fn foo(){}
+}
+
+impl Struct {
+    fn foo(){}
+}
+
+impl StructAfter {
+    fn foo(){}
+}
+
+impl Trait for Struct {
+    fn bar(&self) -> i32 {
+        $0self.0 + 2$0
+    }
+}
+"#,
+            r#"
+struct Struct(i32);
+struct StructBefore(i32);
+struct StructAfter(i32);
+trait Trait {
+    fn bar(&self) -> i32;
+}
+
+impl StructBefore {
+    fn foo(){}
+}
+
+impl Struct {
+    fn foo(){}
+
+    fn $0fun_name(&self) -> i32 {
+        self.0 + 2
+    }
+}
+
+impl StructAfter {
+    fn foo(){}
+}
+
+impl Trait for Struct {
+    fn bar(&self) -> i32 {
+        self.fun_name()
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn extract_method_from_trait_with_multiple_existing_trait_impl_blocks() {
+        check_assist(
+            extract_function,
+            r#"
+struct Struct(i32);
+trait Trait {
+    fn bar(&self) -> i32;
+}
+trait TraitBefore {
+    fn before(&self) -> i32;
+}
+trait TraitAfter {
+    fn after(&self) -> i32;
+}
+
+impl TraitBefore for Struct {
+    fn before(&self) -> i32 {
+        42
+    }
+}
+
+impl Struct {
+    fn foo(){}
+}
+
+impl TraitAfter for Struct {
+    fn after(&self) -> i32 {
+        42
+    }
+}
+
+impl Trait for Struct {
+    fn bar(&self) -> i32 {
+        $0self.0 + 2$0
+    }
+}
+"#,
+            r#"
+struct Struct(i32);
+trait Trait {
+    fn bar(&self) -> i32;
+}
+trait TraitBefore {
+    fn before(&self) -> i32;
+}
+trait TraitAfter {
+    fn after(&self) -> i32;
+}
+
+impl TraitBefore for Struct {
+    fn before(&self) -> i32 {
+        42
+    }
+}
+
+impl Struct {
+    fn foo(){}
+
+    fn $0fun_name(&self) -> i32 {
+        self.0 + 2
+    }
+}
+
+impl TraitAfter for Struct {
+    fn after(&self) -> i32 {
+        42
+    }
+}
+
+impl Trait for Struct {
+    fn bar(&self) -> i32 {
+        self.fun_name()
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
     fn closure_arguments() {
         check_assist(
             extract_function,