about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs82
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs138
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/syntax_editor/mapping.rs6
-rw-r--r--src/tools/rust-analyzer/docs/book/src/assists_generated.md8
4 files changed, 174 insertions, 60 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs
index 8d391c64ce6..151c71c0a76 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs
@@ -8,12 +8,13 @@ use ide_db::{
 };
 use itertools::Itertools;
 use syntax::{
-    ast::{self, edit::AstNodeEdit, make, HasArgList},
-    ted, AstNode, SyntaxNode,
+    ast::{self, edit::AstNodeEdit, syntax_factory::SyntaxFactory, HasArgList},
+    syntax_editor::SyntaxEditor,
+    AstNode, SyntaxNode,
 };
 
 use crate::{
-    utils::{invert_boolean_expression_legacy, unwrap_trivial_block},
+    utils::{invert_boolean_expression, unwrap_trivial_block},
     AssistContext, AssistId, AssistKind, Assists,
 };
 
@@ -76,9 +77,9 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_>
         "Convert `if` expression to `bool::then` call",
         target,
         |builder| {
-            let closure_body = closure_body.clone_for_update();
+            let closure_body = closure_body.clone_subtree();
+            let mut editor = SyntaxEditor::new(closure_body.syntax().clone());
             // Rewrite all `Some(e)` in tail position to `e`
-            let mut replacements = Vec::new();
             for_each_tail_expr(&closure_body, &mut |e| {
                 let e = match e {
                     ast::Expr::BreakExpr(e) => e.expr(),
@@ -88,12 +89,16 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_>
                 if let Some(ast::Expr::CallExpr(call)) = e {
                     if let Some(arg_list) = call.arg_list() {
                         if let Some(arg) = arg_list.args().next() {
-                            replacements.push((call.syntax().clone(), arg.syntax().clone()));
+                            editor.replace(call.syntax(), arg.syntax());
                         }
                     }
                 }
             });
-            replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
+            let edit = editor.finish();
+            let closure_body = ast::Expr::cast(edit.new_root().clone()).unwrap();
+
+            let mut editor = builder.make_editor(expr.syntax());
+            let make = SyntaxFactory::new();
             let closure_body = match closure_body {
                 ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
                 e => e,
@@ -119,11 +124,18 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_>
                     | ast::Expr::WhileExpr(_)
                     | ast::Expr::YieldExpr(_)
             );
-            let cond = if invert_cond { invert_boolean_expression_legacy(cond) } else { cond };
-            let cond = if parenthesize { make::expr_paren(cond) } else { cond };
-            let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body)));
-            let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list);
-            builder.replace(target, mcall.to_string());
+            let cond = if invert_cond {
+                invert_boolean_expression(&make, cond)
+            } else {
+                cond.clone_for_update()
+            };
+            let cond = if parenthesize { make.expr_paren(cond).into() } else { cond };
+            let arg_list = make.arg_list(Some(make.expr_closure(None, closure_body).into()));
+            let mcall = make.expr_method_call(cond, make.name_ref("then"), arg_list);
+            editor.replace(expr.syntax(), mcall.syntax());
+
+            editor.add_mappings(make.finish_with_mappings());
+            builder.add_file_edits(ctx.file_id(), editor);
         },
     )
 }
@@ -173,16 +185,17 @@ pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_>
         "Convert `bool::then` call to `if`",
         target,
         |builder| {
-            let closure_body = match closure_body {
+            let mapless_make = SyntaxFactory::without_mappings();
+            let closure_body = match closure_body.reset_indent() {
                 ast::Expr::BlockExpr(block) => block,
-                e => make::block_expr(None, Some(e)),
+                e => mapless_make.block_expr(None, Some(e)),
             };
 
-            let closure_body = closure_body.clone_for_update();
+            let closure_body = closure_body.clone_subtree();
+            let mut editor = SyntaxEditor::new(closure_body.syntax().clone());
             // Wrap all tails in `Some(...)`
-            let none_path = make::expr_path(make::ext::ident_path("None"));
-            let some_path = make::expr_path(make::ext::ident_path("Some"));
-            let mut replacements = Vec::new();
+            let none_path = mapless_make.expr_path(mapless_make.ident_path("None"));
+            let some_path = mapless_make.expr_path(mapless_make.ident_path("Some"));
             for_each_tail_expr(&ast::Expr::BlockExpr(closure_body.clone()), &mut |e| {
                 let e = match e {
                     ast::Expr::BreakExpr(e) => e.expr(),
@@ -190,28 +203,37 @@ pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_>
                     _ => Some(e.clone()),
                 };
                 if let Some(expr) = e {
-                    replacements.push((
+                    editor.replace(
                         expr.syntax().clone(),
-                        make::expr_call(some_path.clone(), make::arg_list(Some(expr)))
+                        mapless_make
+                            .expr_call(some_path.clone(), mapless_make.arg_list(Some(expr)))
                             .syntax()
-                            .clone_for_update(),
-                    ));
+                            .clone(),
+                    );
                 }
             });
-            replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
+            let edit = editor.finish();
+            let closure_body = ast::BlockExpr::cast(edit.new_root().clone()).unwrap();
+
+            let mut editor = builder.make_editor(mcall.syntax());
+            let make = SyntaxFactory::new();
 
             let cond = match &receiver {
                 ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver),
                 _ => receiver,
             };
-            let if_expr = make::expr_if(
-                cond,
-                closure_body.reset_indent(),
-                Some(ast::ElseBranch::Block(make::block_expr(None, Some(none_path)))),
-            )
-            .indent(mcall.indent_level());
+            let if_expr = make
+                .expr_if(
+                    cond,
+                    closure_body,
+                    Some(ast::ElseBranch::Block(make.block_expr(None, Some(none_path)))),
+                )
+                .indent(mcall.indent_level())
+                .clone_for_update();
+            editor.replace(mcall.syntax().clone(), if_expr.syntax().clone());
 
-            builder.replace(target, if_expr.to_string());
+            editor.add_mappings(make.finish_with_mappings());
+            builder.add_file_edits(ctx.file_id(), editor);
         },
     )
 }
diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
index 19c5c64e218..85393ca5b4c 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
@@ -129,7 +129,7 @@ impl SyntaxFactory {
 
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
-            builder.map_children(input.into_iter(), ast.segments().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.segments().map(|it| it.syntax().clone()));
             builder.finish(&mut mapping);
         }
 
@@ -162,7 +162,7 @@ impl SyntaxFactory {
 
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
-            builder.map_children(input.into_iter(), ast.pats().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.pats().map(|it| it.syntax().clone()));
             builder.finish(&mut mapping);
         }
 
@@ -175,7 +175,7 @@ impl SyntaxFactory {
 
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
-            builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.fields().map(|it| it.syntax().clone()));
             builder.finish(&mut mapping);
         }
 
@@ -193,7 +193,7 @@ impl SyntaxFactory {
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
             builder.map_node(path.syntax().clone(), ast.path().unwrap().syntax().clone());
-            builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.fields().map(|it| it.syntax().clone()));
             builder.finish(&mut mapping);
         }
 
@@ -230,7 +230,7 @@ impl SyntaxFactory {
 
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
-            builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.fields().map(|it| it.syntax().clone()));
             if let Some(rest_pat) = rest_pat {
                 builder
                     .map_node(rest_pat.syntax().clone(), ast.rest_pat().unwrap().syntax().clone());
@@ -315,10 +315,7 @@ impl SyntaxFactory {
                 builder.map_node(last_stmt, ast_tail.syntax().clone());
             }
 
-            builder.map_children(
-                input.into_iter(),
-                stmt_list.statements().map(|it| it.syntax().clone()),
-            );
+            builder.map_children(input, stmt_list.statements().map(|it| it.syntax().clone()));
 
             builder.finish(&mut mapping);
         }
@@ -351,7 +348,7 @@ impl SyntaxFactory {
 
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
-            builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.fields().map(|it| it.syntax().clone()));
             builder.finish(&mut mapping);
         }
 
@@ -454,7 +451,7 @@ impl SyntaxFactory {
 
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax.clone());
-            builder.map_children(input.into_iter(), ast.args().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.args().map(|it| it.syntax().clone()));
             builder.finish(&mut mapping);
         }
 
@@ -476,6 +473,31 @@ impl SyntaxFactory {
         ast.into()
     }
 
+    pub fn expr_closure(
+        &self,
+        pats: impl IntoIterator<Item = ast::Param>,
+        expr: ast::Expr,
+    ) -> ast::ClosureExpr {
+        let (args, input) = iterator_input(pats);
+        // FIXME: `make::expr_paren` should return a `ClosureExpr`, not just an `Expr`
+        let ast::Expr::ClosureExpr(ast) = make::expr_closure(args, expr.clone()).clone_for_update()
+        else {
+            unreachable!()
+        };
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax.clone());
+            builder.map_children(
+                input,
+                ast.param_list().unwrap().params().map(|param| param.syntax().clone()),
+            );
+            builder.map_node(expr.syntax().clone(), ast.body().unwrap().syntax().clone());
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
     pub fn expr_return(&self, expr: Option<ast::Expr>) -> ast::ReturnExpr {
         let ast::Expr::ReturnExpr(ast) = make::expr_return(expr.clone()).clone_for_update() else {
             unreachable!()
@@ -604,7 +626,7 @@ impl SyntaxFactory {
 
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
-            builder.map_children(input.into_iter(), ast.arms().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.arms().map(|it| it.syntax().clone()));
             builder.finish(&mut mapping);
         }
 
@@ -727,6 +749,19 @@ impl SyntaxFactory {
         ast
     }
 
+    pub fn param(&self, pat: ast::Pat, ty: ast::Type) -> ast::Param {
+        let ast = make::param(pat.clone(), ty.clone());
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+            builder.map_node(pat.syntax().clone(), ast.pat().unwrap().syntax().clone());
+            builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone());
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
     pub fn generic_arg_list(
         &self,
         generic_args: impl IntoIterator<Item = ast::GenericArg>,
@@ -741,10 +776,7 @@ impl SyntaxFactory {
 
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
-            builder.map_children(
-                input.into_iter(),
-                ast.generic_args().map(|arg| arg.syntax().clone()),
-            );
+            builder.map_children(input, ast.generic_args().map(|arg| arg.syntax().clone()));
             builder.finish(&mut mapping);
         }
 
@@ -761,7 +793,7 @@ impl SyntaxFactory {
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
 
-            builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.fields().map(|it| it.syntax().clone()));
 
             builder.finish(&mut mapping);
         }
@@ -806,7 +838,7 @@ impl SyntaxFactory {
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
 
-            builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.fields().map(|it| it.syntax().clone()));
 
             builder.finish(&mut mapping);
         }
@@ -901,7 +933,7 @@ impl SyntaxFactory {
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
 
-            builder.map_children(input.into_iter(), ast.variants().map(|it| it.syntax().clone()));
+            builder.map_children(input, ast.variants().map(|it| it.syntax().clone()));
 
             builder.finish(&mut mapping);
         }
@@ -953,6 +985,69 @@ impl SyntaxFactory {
         ast
     }
 
+    pub fn fn_(
+        &self,
+        visibility: Option<ast::Visibility>,
+        fn_name: ast::Name,
+        type_params: Option<ast::GenericParamList>,
+        where_clause: Option<ast::WhereClause>,
+        params: ast::ParamList,
+        body: ast::BlockExpr,
+        ret_type: Option<ast::RetType>,
+        is_async: bool,
+        is_const: bool,
+        is_unsafe: bool,
+        is_gen: bool,
+    ) -> ast::Fn {
+        let ast = make::fn_(
+            visibility.clone(),
+            fn_name.clone(),
+            type_params.clone(),
+            where_clause.clone(),
+            params.clone(),
+            body.clone(),
+            ret_type.clone(),
+            is_async,
+            is_const,
+            is_unsafe,
+            is_gen,
+        );
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+
+            if let Some(visibility) = visibility {
+                builder.map_node(
+                    visibility.syntax().clone(),
+                    ast.visibility().unwrap().syntax().clone(),
+                );
+            }
+            builder.map_node(fn_name.syntax().clone(), ast.name().unwrap().syntax().clone());
+            if let Some(type_params) = type_params {
+                builder.map_node(
+                    type_params.syntax().clone(),
+                    ast.generic_param_list().unwrap().syntax().clone(),
+                );
+            }
+            if let Some(where_clause) = where_clause {
+                builder.map_node(
+                    where_clause.syntax().clone(),
+                    ast.where_clause().unwrap().syntax().clone(),
+                );
+            }
+            builder.map_node(params.syntax().clone(), ast.param_list().unwrap().syntax().clone());
+            builder.map_node(body.syntax().clone(), ast.body().unwrap().syntax().clone());
+            if let Some(ret_type) = ret_type {
+                builder
+                    .map_node(ret_type.syntax().clone(), ast.ret_type().unwrap().syntax().clone());
+            }
+
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
     pub fn token_tree(
         &self,
         delimiter: SyntaxKind,
@@ -965,10 +1060,7 @@ impl SyntaxFactory {
 
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
-            builder.map_children(
-                input.into_iter(),
-                ast.token_trees_and_tokens().filter_map(only_nodes),
-            );
+            builder.map_children(input, ast.token_trees_and_tokens().filter_map(only_nodes));
             builder.finish(&mut mapping);
         }
 
diff --git a/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/mapping.rs b/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/mapping.rs
index 16bc55ed2d4..f71925a7955 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/mapping.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/mapping.rs
@@ -239,10 +239,10 @@ impl SyntaxMappingBuilder {
 
     pub fn map_children(
         &mut self,
-        input: impl Iterator<Item = SyntaxNode>,
-        output: impl Iterator<Item = SyntaxNode>,
+        input: impl IntoIterator<Item = SyntaxNode>,
+        output: impl IntoIterator<Item = SyntaxNode>,
     ) {
-        for pairs in input.zip_longest(output) {
+        for pairs in input.into_iter().zip_longest(output) {
             let (input, output) = match pairs {
                 itertools::EitherOrBoth::Both(l, r) => (l, r),
                 itertools::EitherOrBoth::Left(_) => {
diff --git a/src/tools/rust-analyzer/docs/book/src/assists_generated.md b/src/tools/rust-analyzer/docs/book/src/assists_generated.md
index 2d233ca62ad..72cecc2b02d 100644
--- a/src/tools/rust-analyzer/docs/book/src/assists_generated.md
+++ b/src/tools/rust-analyzer/docs/book/src/assists_generated.md
@@ -419,7 +419,7 @@ Converts comments to documentation.
 
 
 ### `convert_bool_then_to_if`
-**Source:**  [convert_bool_then.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_bool_then.rs#L131) 
+**Source:**  [convert_bool_then.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_bool_then.rs#L143) 
 
 Converts a `bool::then` method call to an equivalent if expression.
 
@@ -443,7 +443,7 @@ fn main() {
 
 
 ### `convert_closure_to_fn`
-**Source:**  [convert_closure_to_fn.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_closure_to_fn.rs#L25) 
+**Source:**  [convert_closure_to_fn.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_closure_to_fn.rs#L27) 
 
 This converts a closure to a freestanding function, changing all captures to parameters.
 
@@ -527,7 +527,7 @@ impl TryFrom<usize> for Thing {
 
 
 ### `convert_if_to_bool_then`
-**Source:**  [convert_bool_then.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_bool_then.rs#L20) 
+**Source:**  [convert_bool_then.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_bool_then.rs#L21) 
 
 Converts an if expression into a corresponding `bool::then` call.
 
@@ -2258,7 +2258,7 @@ fn bar() {
 
 
 ### `inline_local_variable`
-**Source:**  [inline_local_variable.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/inline_local_variable.rs#L17) 
+**Source:**  [inline_local_variable.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/inline_local_variable.rs#L21) 
 
 Inlines a local variable.