about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDorian Scheidt <dorian.scheidt@gmail.com>2022-07-01 15:44:10 -0500
committerDorian Scheidt <dorian.scheidt@gmail.com>2022-07-02 15:00:02 -0500
commite3940003a20b6df3e2def1a5e38b20d35e8a1db8 (patch)
tree5fecb5f9be111e386c7162a9a846e8d9816ea17b
parentcc0bb71e258fa1a180435e265ece1fcc3ffb59f1 (diff)
downloadrust-e3940003a20b6df3e2def1a5e38b20d35e8a1db8.tar.gz
rust-e3940003a20b6df3e2def1a5e38b20d35e8a1db8.zip
fix: Extract function from trait impl
This change fixes #10036, "Extract to function assist implements nonexistent
trait methods".

When we detect that the extraction is coming from within a trait impl, and that
a `self` param will be necessary, we adjust which `SyntaxNode` to `insert_after`,
and create a new empty `impl` block for the newly extracted function.
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs85
1 files changed, 75 insertions, 10 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index 68dcb6c0d23..c9f01ba64a5 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -2,7 +2,7 @@ use std::iter;
 
 use ast::make;
 use either::Either;
-use hir::{HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo};
+use hir::{HasSource, HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo};
 use ide_db::{
     defs::{Definition, NameRefClass},
     famous_defs::FamousDefs,
@@ -27,6 +27,7 @@ use syntax::{
 
 use crate::{
     assist_context::{AssistContext, Assists, TreeMutator},
+    utils::generate_impl_text,
     AssistId,
 };
 
@@ -106,6 +107,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
             let params =
                 body.extracted_function_params(ctx, &container_info, locals_used.iter().copied());
 
+            let extracted_from_trait_impl = body.extracted_from_trait_impl();
+
             let name = make_function_name(&semantics_scope);
 
             let fun = Function {
@@ -124,8 +127,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
 
             builder.replace(target_range, make_call(ctx, &fun, old_indent));
 
-            let fn_def = format_function(ctx, module, &fun, old_indent, new_indent);
-            let insert_offset = insert_after.text_range().end();
+            let fn_def = match fun.self_param_adt(ctx) {
+                Some(adt) if extracted_from_trait_impl => {
+                    let fn_def = format_function(ctx, module, &fun, old_indent, new_indent + 1);
+                    generate_impl_text(&adt, &fn_def).replace("{\n\n", "{")
+                }
+                _ => format_function(ctx, module, &fun, old_indent, new_indent),
+            };
 
             if fn_def.contains("ControlFlow") {
                 let scope = match scope {
@@ -150,6 +158,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
                 }
             }
 
+            let insert_offset = insert_after.text_range().end();
+
             match ctx.config.snippet_cap {
                 Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def),
                 None => builder.insert(insert_offset, fn_def),
@@ -381,6 +391,14 @@ impl Function {
             },
         }
     }
+
+    fn self_param_adt(&self, ctx: &AssistContext) -> Option<ast::Adt> {
+        let self_param = self.self_param.as_ref()?;
+        let def = ctx.sema.to_def(self_param)?;
+        let adt = def.ty(ctx.db()).strip_references().as_adt()?;
+        let InFile { file_id: _, value } = adt.source(ctx.db())?;
+        Some(value)
+    }
 }
 
 impl ParamKind {
@@ -485,6 +503,20 @@ impl FunctionBody {
         }
     }
 
+    fn node(&self) -> &SyntaxNode {
+        match self {
+            FunctionBody::Expr(e) => e.syntax(),
+            FunctionBody::Span { parent, .. } => parent.syntax(),
+        }
+    }
+
+    fn extracted_from_trait_impl(&self) -> bool {
+        match self.node().ancestors().find_map(ast::Impl::cast) {
+            Some(c) => return c.trait_().is_some(),
+            None => false,
+        }
+    }
+
     fn from_expr(expr: ast::Expr) -> Option<Self> {
         match expr {
             ast::Expr::BreakExpr(it) => it.expr().map(Self::Expr),
@@ -1111,10 +1143,7 @@ fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
 ///
 /// Function should be put right after returned node
 fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNode> {
-    let node = match body {
-        FunctionBody::Expr(e) => e.syntax(),
-        FunctionBody::Span { parent, .. } => parent.syntax(),
-    };
+    let node = body.node();
     let mut ancestors = node.ancestors().peekable();
     let mut last_ancestor = None;
     while let Some(next_ancestor) = ancestors.next() {
@@ -1126,9 +1155,8 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNod
                     break;
                 }
             }
-            SyntaxKind::ASSOC_ITEM_LIST if !matches!(anchor, Anchor::Method) => {
-                continue;
-            }
+            SyntaxKind::ASSOC_ITEM_LIST if !matches!(anchor, Anchor::Method) => continue,
+            SyntaxKind::ASSOC_ITEM_LIST if body.extracted_from_trait_impl() => continue,
             SyntaxKind::ASSOC_ITEM_LIST => {
                 if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::IMPL) {
                     break;
@@ -4780,4 +4808,41 @@ fn $0fun_name2() {
 "#,
         );
     }
+
+    #[test]
+    fn extract_method_from_trait_impl() {
+        check_assist(
+            extract_function,
+            r#"
+struct Struct(i32);
+trait Trait {
+    fn bar(&self) -> i32;
+}
+
+impl Trait for Struct {
+    fn bar(&self) -> i32 {
+        $0self.0 + 2$0
+    }
+}
+"#,
+            r#"
+struct Struct(i32);
+trait Trait {
+    fn bar(&self) -> i32;
+}
+
+impl Trait for Struct {
+    fn bar(&self) -> i32 {
+        self.fun_name()
+    }
+}
+
+impl Struct {
+    fn $0fun_name(&self) -> i32 {
+        self.0 + 2
+    }
+}
+"#,
+        );
+    }
 }