about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDropDemBits <r3usrlnd@gmail.com>2023-12-11 17:37:45 -0500
committerDropDemBits <r3usrlnd@gmail.com>2024-02-08 19:09:33 -0500
commit0e39257e5be05458596fb5cce9bb806081ea0cf1 (patch)
treece5527ba3a63b29db0bb30eacca4240f05178d88
parent3924a0ef7c86540d4c84919bf1e0054e30c34711 (diff)
downloadrust-0e39257e5be05458596fb5cce9bb806081ea0cf1.tar.gz
rust-0e39257e5be05458596fb5cce9bb806081ea0cf1.zip
Migrate `extract_function` to mutable ast
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs294
-rw-r--r--crates/ide-assists/src/tests.rs13
2 files changed, 194 insertions, 113 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index 1eb28626f75..fce132f7824 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -1,4 +1,4 @@
-use std::iter;
+use std::{iter, ops::RangeInclusive};
 
 use ast::make;
 use either::Either;
@@ -12,27 +12,25 @@ use ide_db::{
     helpers::mod_path_to_ast,
     imports::insert_use::{insert_use, ImportScope},
     search::{FileReference, ReferenceCategory, SearchScope},
+    source_change::SourceChangeBuilder,
     syntax_helpers::node_ext::{
         for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr,
     },
     FxIndexSet, RootDatabase,
 };
-use itertools::Itertools;
-use stdx::format_to;
 use syntax::{
     ast::{
-        self,
-        edit::{AstNodeEdit, IndentLevel},
-        AstNode, HasGenericParams,
+        self, edit::IndentLevel, edit_in_place::Indent, AstNode, AstToken, HasGenericParams,
+        HasName,
     },
-    match_ast, ted, AstToken, SyntaxElement,
+    match_ast, ted, SyntaxElement,
     SyntaxKind::{self, COMMENT},
     SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, WalkEvent, T,
 };
 
 use crate::{
     assist_context::{AssistContext, Assists, TreeMutator},
-    utils::generate_impl_text,
+    utils::generate_impl,
     AssistId,
 };
 
@@ -134,17 +132,65 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
             let new_indent = IndentLevel::from_node(&insert_after);
             let old_indent = fun.body.indent_level();
 
-            builder.replace(target_range, make_call(ctx, &fun, old_indent));
+            let insert_after = builder.make_syntax_mut(insert_after);
+
+            let call_expr = make_call(ctx, &fun, old_indent);
+
+            // Map the element range to replace into the mutable version
+            let elements = match &fun.body {
+                FunctionBody::Expr(expr) => {
+                    // expr itself becomes the replacement target
+                    let expr = &builder.make_mut(expr.clone());
+                    let node = SyntaxElement::Node(expr.syntax().clone());
+
+                    node.clone()..=node
+                }
+                FunctionBody::Span { parent, elements, .. } => {
+                    // Map the element range into the mutable versions
+                    let parent = builder.make_mut(parent.clone());
+
+                    let start = parent
+                        .syntax()
+                        .children_with_tokens()
+                        .nth(elements.start().index())
+                        .expect("should be able to find mutable start element");
+
+                    let end = parent
+                        .syntax()
+                        .children_with_tokens()
+                        .nth(elements.end().index())
+                        .expect("should be able to find mutable end element");
+
+                    start..=end
+                }
+            };
 
             let has_impl_wrapper =
                 insert_after.ancestors().any(|a| a.kind() == SyntaxKind::IMPL && a != insert_after);
 
+            let fn_def = format_function(ctx, module, &fun, old_indent).clone_for_update();
+
+            if let Some(cap) = ctx.config.snippet_cap {
+                if let Some(name) = fn_def.name() {
+                    builder.add_tabstop_before(cap, name);
+                }
+            }
+
             let fn_def = match fun.self_param_adt(ctx) {
                 Some(adt) if anchor == Anchor::Method && !has_impl_wrapper => {
-                    let fn_def = format_function(ctx, module, &fun, old_indent, new_indent + 1);
-                    generate_impl_text(&adt, &fn_def).replace("{\n\n", "{")
+                    fn_def.indent(1.into());
+
+                    let impl_ = generate_impl(&adt);
+                    impl_.indent(new_indent);
+                    impl_.get_or_create_assoc_item_list().add_item(fn_def.into());
+
+                    impl_.syntax().clone()
+                }
+                _ => {
+                    fn_def.indent(new_indent.into());
+
+                    fn_def.syntax().clone()
                 }
-                _ => format_function(ctx, module, &fun, old_indent, new_indent),
             };
 
             // There are external control flows
@@ -177,12 +223,15 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
                 }
             }
 
-            let insert_offset = insert_after.text_range().end();
+            // Replace the call site with the call to the new function
+            fixup_call_site(builder, &fun.body);
+            ted::replace_all(elements, vec![call_expr.into()]);
 
-            match ctx.config.snippet_cap {
-                Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def),
-                None => builder.insert(insert_offset, fn_def),
-            };
+            // Insert the newly extracted function (or impl)
+            ted::insert_all_raw(
+                ted::Position::after(insert_after),
+                vec![make::tokens::whitespace(&format!("\n\n{new_indent}")).into(), fn_def.into()],
+            );
         },
     )
 }
@@ -225,10 +274,10 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option<Fu
     if let Some(stmt) = ast::Stmt::cast(node.clone()) {
         return match stmt {
             ast::Stmt::Item(_) => None,
-            ast::Stmt::ExprStmt(_) | ast::Stmt::LetStmt(_) => Some(FunctionBody::from_range(
+            ast::Stmt::ExprStmt(_) | ast::Stmt::LetStmt(_) => FunctionBody::from_range(
                 node.parent().and_then(ast::StmtList::cast)?,
                 node.text_range(),
-            )),
+            ),
         };
     }
 
@@ -241,7 +290,7 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option<Fu
         }
 
         // Extract the full statements.
-        return Some(FunctionBody::from_range(stmt_list, selection_range));
+        return FunctionBody::from_range(stmt_list, selection_range);
     }
 
     let expr = ast::Expr::cast(node.clone())?;
@@ -371,7 +420,7 @@ impl RetType {
 #[derive(Debug)]
 enum FunctionBody {
     Expr(ast::Expr),
-    Span { parent: ast::StmtList, text_range: TextRange },
+    Span { parent: ast::StmtList, elements: RangeInclusive<SyntaxElement>, text_range: TextRange },
 }
 
 #[derive(Debug)]
@@ -569,26 +618,38 @@ impl FunctionBody {
         }
     }
 
-    fn from_range(parent: ast::StmtList, selected: TextRange) -> FunctionBody {
+    fn from_range(parent: ast::StmtList, selected: TextRange) -> Option<FunctionBody> {
         let full_body = parent.syntax().children_with_tokens();
 
-        let mut text_range = full_body
+        // Get all of the elements intersecting with the selection
+        let mut stmts_in_selection = full_body
             .filter(|it| ast::Stmt::can_cast(it.kind()) || it.kind() == COMMENT)
-            .map(|element| element.text_range())
-            .filter(|&range| selected.intersect(range).filter(|it| !it.is_empty()).is_some())
-            .reduce(|acc, stmt| acc.cover(stmt));
-
-        if let Some(tail_range) = parent
-            .tail_expr()
-            .map(|it| it.syntax().text_range())
-            .filter(|&it| selected.intersect(it).is_some())
+            .filter(|it| selected.intersect(it.text_range()).filter(|it| !it.is_empty()).is_some());
+
+        let first_element = stmts_in_selection.next();
+
+        // If the tail expr is part of the selection too, make that the last element
+        // Otherwise use the last stmt
+        let last_element = if let Some(tail_expr) =
+            parent.tail_expr().filter(|it| selected.intersect(it.syntax().text_range()).is_some())
         {
-            text_range = Some(match text_range {
-                Some(text_range) => text_range.cover(tail_range),
-                None => tail_range,
-            });
-        }
-        Self::Span { parent, text_range: text_range.unwrap_or(selected) }
+            Some(tail_expr.syntax().clone().into())
+        } else {
+            stmts_in_selection.last()
+        };
+
+        let elements = match (first_element, last_element) {
+            (None, _) => {
+                cov_mark::hit!(extract_function_empty_selection_is_not_applicable);
+                return None;
+            }
+            (Some(first), None) => first.clone()..=first,
+            (Some(first), Some(last)) => first..=last,
+        };
+
+        let text_range = elements.start().text_range().cover(elements.end().text_range());
+
+        Some(Self::Span { parent, elements, text_range })
     }
 
     fn indent_level(&self) -> IndentLevel {
@@ -601,7 +662,7 @@ impl FunctionBody {
     fn tail_expr(&self) -> Option<ast::Expr> {
         match &self {
             FunctionBody::Expr(expr) => Some(expr.clone()),
-            FunctionBody::Span { parent, text_range } => {
+            FunctionBody::Span { parent, text_range, .. } => {
                 let tail_expr = parent.tail_expr()?;
                 text_range.contains_range(tail_expr.syntax().text_range()).then_some(tail_expr)
             }
@@ -611,7 +672,7 @@ impl FunctionBody {
     fn walk_expr(&self, cb: &mut dyn FnMut(ast::Expr)) {
         match self {
             FunctionBody::Expr(expr) => walk_expr(expr, cb),
-            FunctionBody::Span { parent, text_range } => {
+            FunctionBody::Span { parent, text_range, .. } => {
                 parent
                     .statements()
                     .filter(|stmt| text_range.contains_range(stmt.syntax().text_range()))
@@ -634,7 +695,7 @@ impl FunctionBody {
     fn preorder_expr(&self, cb: &mut dyn FnMut(WalkEvent<ast::Expr>) -> bool) {
         match self {
             FunctionBody::Expr(expr) => preorder_expr(expr, cb),
-            FunctionBody::Span { parent, text_range } => {
+            FunctionBody::Span { parent, text_range, .. } => {
                 parent
                     .statements()
                     .filter(|stmt| text_range.contains_range(stmt.syntax().text_range()))
@@ -657,7 +718,7 @@ impl FunctionBody {
     fn walk_pat(&self, cb: &mut dyn FnMut(ast::Pat)) {
         match self {
             FunctionBody::Expr(expr) => walk_patterns_in_expr(expr, cb),
-            FunctionBody::Span { parent, text_range } => {
+            FunctionBody::Span { parent, text_range, .. } => {
                 parent
                     .statements()
                     .filter(|stmt| text_range.contains_range(stmt.syntax().text_range()))
@@ -1151,7 +1212,7 @@ impl HasTokenAtOffset for FunctionBody {
     fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset<SyntaxToken> {
         match self {
             FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset),
-            FunctionBody::Span { parent, text_range } => {
+            FunctionBody::Span { parent, text_range, .. } => {
                 match parent.syntax().token_at_offset(offset) {
                     TokenAtOffset::None => TokenAtOffset::None,
                     TokenAtOffset::Single(t) => {
@@ -1316,7 +1377,19 @@ fn impl_type_name(impl_node: &ast::Impl) -> Option<String> {
     Some(impl_node.self_ty()?.to_string())
 }
 
-fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> String {
+/// Fixes up the call site before the target expressions are replaced with the call expression
+fn fixup_call_site(builder: &mut SourceChangeBuilder, body: &FunctionBody) {
+    let parent_match_arm = body.parent().and_then(ast::MatchArm::cast);
+
+    if let Some(parent_match_arm) = parent_match_arm {
+        if parent_match_arm.comma_token().is_none() {
+            let parent_match_arm = builder.make_mut(parent_match_arm);
+            ted::append_child_raw(parent_match_arm.syntax(), make::token(T![,]));
+        }
+    }
+}
+
+fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> SyntaxNode {
     let ret_ty = fun.return_type(ctx);
 
     let args = make::arg_list(fun.params.iter().map(|param| param.to_arg(ctx)));
@@ -1334,44 +1407,49 @@ fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> St
     if fun.control_flow.is_async {
         call_expr = make::expr_await(call_expr);
     }
-    let expr = handler.make_call_expr(call_expr).indent(indent);
 
-    let mut_modifier = |var: &OutlivedLocal| if var.mut_usage_outside_body { "mut " } else { "" };
+    let expr = handler.make_call_expr(call_expr).clone_for_update();
+    expr.indent(indent);
 
-    let mut buf = String::new();
-    match fun.outliving_locals.as_slice() {
-        [] => {}
+    let outliving_bindings = match fun.outliving_locals.as_slice() {
+        [] => None,
         [var] => {
-            let modifier = mut_modifier(var);
             let name = var.local.name(ctx.db());
-            format_to!(buf, "let {modifier}{} = ", name.display(ctx.db()))
+            let name = make::name(&name.display(ctx.db()).to_string());
+            Some(ast::Pat::IdentPat(make::ident_pat(
+                false,
+                var.mut_usage_outside_body,
+                name.into(),
+            )))
         }
         vars => {
-            buf.push_str("let (");
-            let bindings = vars.iter().format_with(", ", |local, f| {
-                let modifier = mut_modifier(local);
-                let name = local.local.name(ctx.db());
-                f(&format_args!("{modifier}{}", name.display(ctx.db())))?;
-                Ok(())
+            let binding_pats = vars.iter().map(|var| {
+                let name = var.local.name(ctx.db());
+                let name = make::name(&name.display(ctx.db()).to_string());
+                make::ident_pat(false, var.mut_usage_outside_body, name.into()).into()
             });
-            format_to!(buf, "{bindings}");
-            buf.push_str(") = ");
+            Some(ast::Pat::TuplePat(make::tuple_pat(binding_pats)))
         }
-    }
+    };
 
-    format_to!(buf, "{expr}");
     let parent_match_arm = fun.body.parent().and_then(ast::MatchArm::cast);
-    let insert_comma = parent_match_arm.as_ref().is_some_and(|it| it.comma_token().is_none());
 
-    if insert_comma {
-        buf.push(',');
-    } else if parent_match_arm.is_none()
+    if let Some(bindings) = outliving_bindings {
+        // with bindings that outlive it
+        make::let_stmt(bindings, None, Some(expr)).syntax().clone_for_update()
+    } else if parent_match_arm.as_ref().is_some() {
+        // as a tail expr for a match arm
+        expr.syntax().clone()
+    } else if parent_match_arm.as_ref().is_none()
         && fun.ret_ty.is_unit()
         && (!fun.outliving_locals.is_empty() || !expr.is_block_like())
     {
-        buf.push(';');
+        // as an expr stmt
+        make::expr_stmt(expr).syntax().clone_for_update()
+    } else {
+        // as a tail expr, or a block
+        expr.syntax().clone()
     }
-    buf
 }
 
 enum FlowHandler {
@@ -1500,42 +1578,25 @@ fn format_function(
     module: hir::Module,
     fun: &Function,
     old_indent: IndentLevel,
-    new_indent: IndentLevel,
-) -> String {
-    let mut fn_def = String::new();
-
-    let fun_name = &fun.name;
+) -> ast::Fn {
+    let fun_name = make::name(&fun.name.text());
     let params = fun.make_param_list(ctx, module);
     let ret_ty = fun.make_ret_ty(ctx, module);
-    let body = make_body(ctx, old_indent, new_indent, fun);
-    let const_kw = if fun.mods.is_const { "const " } else { "" };
-    let async_kw = if fun.control_flow.is_async { "async " } else { "" };
-    let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" };
+    let body = make_body(ctx, old_indent, fun);
     let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun);
 
-    format_to!(fn_def, "\n\n{new_indent}{const_kw}{async_kw}{unsafe_kw}");
-    match ctx.config.snippet_cap {
-        Some(_) => format_to!(fn_def, "fn $0{fun_name}"),
-        None => format_to!(fn_def, "fn {fun_name}"),
-    }
-
-    if let Some(generic_params) = generic_params {
-        format_to!(fn_def, "{generic_params}");
-    }
-
-    format_to!(fn_def, "{params}");
-
-    if let Some(ret_ty) = ret_ty {
-        format_to!(fn_def, " {ret_ty}");
-    }
-
-    if let Some(where_clause) = where_clause {
-        format_to!(fn_def, " {where_clause}");
-    }
-
-    format_to!(fn_def, " {body}");
-
-    fn_def
+    make::fn_(
+        None,
+        fun_name,
+        generic_params,
+        where_clause,
+        params,
+        body,
+        ret_ty,
+        fun.control_flow.is_async,
+        fun.mods.is_const,
+        fun.control_flow.is_unsafe,
+    )
 }
 
 fn make_generic_params_and_where_clause(
@@ -1716,12 +1777,7 @@ impl FunType {
     }
 }
 
-fn make_body(
-    ctx: &AssistContext<'_>,
-    old_indent: IndentLevel,
-    new_indent: IndentLevel,
-    fun: &Function,
-) -> ast::BlockExpr {
+fn make_body(ctx: &AssistContext<'_>, old_indent: IndentLevel, fun: &Function) -> ast::BlockExpr {
     let ret_ty = fun.return_type(ctx);
     let handler = FlowHandler::from_ret_ty(fun, &ret_ty);
 
@@ -1732,7 +1788,7 @@ fn make_body(
             match expr {
                 ast::Expr::BlockExpr(block) => {
                     // If the extracted expression is itself a block, there is no need to wrap it inside another block.
-                    let block = block.dedent(old_indent);
+                    block.dedent(old_indent);
                     let elements = block.stmt_list().map_or_else(
                         || Either::Left(iter::empty()),
                         |stmt_list| {
@@ -1752,13 +1808,13 @@ fn make_body(
                     make::hacky_block_expr(elements, block.tail_expr())
                 }
                 _ => {
-                    let expr = expr.dedent(old_indent).indent(IndentLevel(1));
+                    expr.reindent_to(1.into());
 
                     make::block_expr(Vec::new(), Some(expr))
                 }
             }
         }
-        FunctionBody::Span { parent, text_range } => {
+        FunctionBody::Span { parent, text_range, .. } => {
             let mut elements: Vec<_> = parent
                 .syntax()
                 .children_with_tokens()
@@ -1801,8 +1857,8 @@ fn make_body(
                 .map(|node_or_token| match &node_or_token {
                     syntax::NodeOrToken::Node(node) => match ast::Stmt::cast(node.clone()) {
                         Some(stmt) => {
-                            let indented = stmt.dedent(old_indent).indent(body_indent);
-                            let ast_node = indented.syntax().clone_subtree();
+                            stmt.reindent_to(body_indent);
+                            let ast_node = stmt.syntax().clone_subtree();
                             syntax::NodeOrToken::Node(ast_node)
                         }
                         _ => node_or_token,
@@ -1810,7 +1866,9 @@ fn make_body(
                     _ => node_or_token,
                 })
                 .collect::<Vec<SyntaxElement>>();
-            let tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent));
+            if let Some(tail_expr) = &mut tail_expr {
+                tail_expr.reindent_to(body_indent);
+            }
 
             make::hacky_block_expr(elements, tail_expr)
         }
@@ -1853,7 +1911,7 @@ fn make_body(
         }),
     };
 
-    block.indent(new_indent)
+    block
 }
 
 fn map_tail_expr(block: ast::BlockExpr, f: impl FnOnce(ast::Expr) -> ast::Expr) -> ast::BlockExpr {
@@ -2552,6 +2610,20 @@ fn $0fun_name(n: u32) -> u32 {
     }
 
     #[test]
+    fn empty_selection_is_not_applicable() {
+        cov_mark::check!(extract_function_empty_selection_is_not_applicable);
+        check_assist_not_applicable(
+            extract_function,
+            r#"
+fn main() {
+    $0
+
+    $0
+}"#,
+        );
+    }
+
+    #[test]
     fn part_of_expr_stmt() {
         check_assist(
             extract_function,
diff --git a/crates/ide-assists/src/tests.rs b/crates/ide-assists/src/tests.rs
index 573d69b5c6d..466264d8e4d 100644
--- a/crates/ide-assists/src/tests.rs
+++ b/crates/ide-assists/src/tests.rs
@@ -687,12 +687,21 @@ pub fn test_some_range(a: int) -> bool {
                                             delete: 59..60,
                                         },
                                         Indel {
-                                            insert: "\n\nfn $0fun_name() -> i32 {\n    5\n}",
+                                            insert: "\n\nfn fun_name() -> i32 {\n    5\n}",
                                             delete: 110..110,
                                         },
                                     ],
                                 },
-                                None,
+                                Some(
+                                    SnippetEdit(
+                                        [
+                                            (
+                                                0,
+                                                124..124,
+                                            ),
+                                        ],
+                                    ),
+                                ),
                             ),
                         },
                         file_system_edits: [],