about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-assists/src/handlers/bool_to_enum.rs227
-rw-r--r--crates/ide-assists/src/handlers/promote_local_to_const.rs89
-rw-r--r--crates/ide-assists/src/utils.rs18
3 files changed, 264 insertions, 70 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs
index 0f2d1057c0a..b7b00e7ed06 100644
--- a/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -16,11 +16,14 @@ use syntax::{
         edit_in_place::{AttrsOwnerEdit, Indent},
         make, HasName,
     },
-    ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
+    AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
 };
 use text_edit::TextRange;
 
-use crate::assist_context::{AssistContext, Assists};
+use crate::{
+    assist_context::{AssistContext, Assists},
+    utils,
+};
 
 // Assist: bool_to_enum
 //
@@ -73,7 +76,7 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
 
             let usages = definition.usages(&ctx.sema).all();
             add_enum_def(edit, ctx, &usages, target_node, &target_module);
-            replace_usages(edit, ctx, &usages, definition, &target_module);
+            replace_usages(edit, ctx, usages, definition, &target_module);
         },
     )
 }
@@ -169,8 +172,8 @@ fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr) {
 
 /// Converts an expression of type `bool` to one of the new enum type.
 fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
-    let true_expr = make::expr_path(make::path_from_text("Bool::True")).clone_for_update();
-    let false_expr = make::expr_path(make::path_from_text("Bool::False")).clone_for_update();
+    let true_expr = make::expr_path(make::path_from_text("Bool::True"));
+    let false_expr = make::expr_path(make::path_from_text("Bool::False"));
 
     if let ast::Expr::Literal(literal) = &expr {
         match literal.kind() {
@@ -184,7 +187,6 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
             make::tail_only_block_expr(true_expr),
             Some(ast::ElseBranch::Block(make::tail_only_block_expr(false_expr))),
         )
-        .clone_for_update()
     }
 }
 
@@ -192,21 +194,19 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
 fn replace_usages(
     edit: &mut SourceChangeBuilder,
     ctx: &AssistContext<'_>,
-    usages: &UsageSearchResult,
+    usages: UsageSearchResult,
     target_definition: Definition,
     target_module: &hir::Module,
 ) {
-    for (file_id, references) in usages.iter() {
-        edit.edit_file(*file_id);
+    for (file_id, references) in usages {
+        edit.edit_file(file_id);
 
-        let refs_with_imports =
-            augment_references_with_imports(edit, ctx, references, target_module);
+        let refs_with_imports = augment_references_with_imports(ctx, references, target_module);
 
         refs_with_imports.into_iter().rev().for_each(
-            |FileReferenceWithImport { range, old_name, new_name, import_data }| {
+            |FileReferenceWithImport { range, name, import_data }| {
                 // replace the usages in patterns and expressions
-                if let Some(ident_pat) = old_name.syntax().ancestors().find_map(ast::IdentPat::cast)
-                {
+                if let Some(ident_pat) = name.syntax().ancestors().find_map(ast::IdentPat::cast) {
                     cov_mark::hit!(replaces_record_pat_shorthand);
 
                     let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
@@ -214,36 +214,35 @@ fn replace_usages(
                         replace_usages(
                             edit,
                             ctx,
-                            &def.usages(&ctx.sema).all(),
+                            def.usages(&ctx.sema).all(),
                             target_definition,
                             target_module,
                         )
                     }
-                } else if let Some(initializer) = find_assignment_usage(&new_name) {
+                } else if let Some(initializer) = find_assignment_usage(&name) {
                     cov_mark::hit!(replaces_assignment);
 
                     replace_bool_expr(edit, initializer);
-                } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&new_name) {
+                } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&name) {
                     cov_mark::hit!(replaces_negation);
 
                     edit.replace(
                         prefix_expr.syntax().text_range(),
                         format!("{} == Bool::False", inner_expr),
                     );
-                } else if let Some((record_field, initializer)) = old_name
+                } else if let Some((record_field, initializer)) = name
                     .as_name_ref()
                     .and_then(ast::RecordExprField::for_field_name)
                     .and_then(|record_field| ctx.sema.resolve_record_field(&record_field))
                     .and_then(|(got_field, _, _)| {
-                        find_record_expr_usage(&new_name, got_field, target_definition)
+                        find_record_expr_usage(&name, got_field, target_definition)
                     })
                 {
                     cov_mark::hit!(replaces_record_expr);
 
-                    let record_field = edit.make_mut(record_field);
                     let enum_expr = bool_expr_to_enum_expr(initializer);
-                    record_field.replace_expr(enum_expr);
-                } else if let Some(pat) = find_record_pat_field_usage(&old_name) {
+                    utils::replace_record_field_expr(ctx, edit, record_field, enum_expr);
+                } else if let Some(pat) = find_record_pat_field_usage(&name) {
                     match pat {
                         ast::Pat::IdentPat(ident_pat) => {
                             cov_mark::hit!(replaces_record_pat);
@@ -253,7 +252,7 @@ fn replace_usages(
                                 replace_usages(
                                     edit,
                                     ctx,
-                                    &def.usages(&ctx.sema).all(),
+                                    def.usages(&ctx.sema).all(),
                                     target_definition,
                                     target_module,
                                 )
@@ -270,40 +269,44 @@ fn replace_usages(
                         }
                         _ => (),
                     }
-                } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name)
-                {
+                } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&name) {
                     edit.replace(ty_annotation.syntax().text_range(), "Bool");
                     replace_bool_expr(edit, initializer);
-                } else if let Some(receiver) = find_method_call_expr_usage(&new_name) {
+                } else if let Some(receiver) = find_method_call_expr_usage(&name) {
                     edit.replace(
                         receiver.syntax().text_range(),
                         format!("({} == Bool::True)", receiver),
                     );
-                } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
+                } else if name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
                     // for any other usage in an expression, replace it with a check that it is the true variant
-                    if let Some((record_field, expr)) = new_name
-                        .as_name_ref()
-                        .and_then(ast::RecordExprField::for_field_name)
-                        .and_then(|record_field| {
-                            record_field.expr().map(|expr| (record_field, expr))
-                        })
+                    if let Some((record_field, expr)) =
+                        name.as_name_ref().and_then(ast::RecordExprField::for_field_name).and_then(
+                            |record_field| record_field.expr().map(|expr| (record_field, expr)),
+                        )
                     {
-                        record_field.replace_expr(
+                        utils::replace_record_field_expr(
+                            ctx,
+                            edit,
+                            record_field,
                             make::expr_bin_op(
                                 expr,
                                 ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }),
                                 make::expr_path(make::path_from_text("Bool::True")),
-                            )
-                            .clone_for_update(),
+                            ),
                         );
                     } else {
-                        edit.replace(range, format!("{} == Bool::True", new_name.text()));
+                        edit.replace(range, format!("{} == Bool::True", name.text()));
                     }
                 }
 
                 // add imports across modules where needed
                 if let Some((import_scope, path)) = import_data {
-                    insert_use(&import_scope, path, &ctx.config.insert_use);
+                    let scope = match import_scope.clone() {
+                        ImportScope::File(it) => ImportScope::File(edit.make_mut(it)),
+                        ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)),
+                        ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)),
+                    };
+                    insert_use(&scope, path, &ctx.config.insert_use);
                 }
             },
         )
@@ -312,37 +315,31 @@ fn replace_usages(
 
 struct FileReferenceWithImport {
     range: TextRange,
-    old_name: ast::NameLike,
-    new_name: ast::NameLike,
+    name: ast::NameLike,
     import_data: Option<(ImportScope, ast::Path)>,
 }
 
 fn augment_references_with_imports(
-    edit: &mut SourceChangeBuilder,
     ctx: &AssistContext<'_>,
-    references: &[FileReference],
+    references: Vec<FileReference>,
     target_module: &hir::Module,
 ) -> Vec<FileReferenceWithImport> {
     let mut visited_modules = FxHashSet::default();
 
     references
-        .iter()
+        .into_iter()
         .filter_map(|FileReference { range, name, .. }| {
             let name = name.clone().into_name_like()?;
-            ctx.sema.scope(name.syntax()).map(|scope| (*range, name, scope.module()))
+            ctx.sema.scope(name.syntax()).map(|scope| (range, name, scope.module()))
         })
         .map(|(range, name, ref_module)| {
-            let old_name = name.clone();
-            let new_name = edit.make_mut(name.clone());
-
             // if the referenced module is not the same as the target one and has not been seen before, add an import
             let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module
                 && !visited_modules.contains(&ref_module)
             {
                 visited_modules.insert(ref_module);
 
-                let import_scope =
-                    ImportScope::find_insert_use_container(new_name.syntax(), &ctx.sema);
+                let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema);
                 let path = ref_module
                     .find_use_path_prefixed(
                         ctx.sema.db,
@@ -360,7 +357,7 @@ fn augment_references_with_imports(
                 None
             };
 
-            FileReferenceWithImport { range, old_name, new_name, import_data }
+            FileReferenceWithImport { range, name, import_data }
         })
         .collect()
 }
@@ -405,13 +402,12 @@ fn find_record_expr_usage(
     let record_field = ast::RecordExprField::for_field_name(name_ref)?;
     let initializer = record_field.expr()?;
 
-    if let Definition::Field(expected_field) = target_definition {
-        if got_field != expected_field {
-            return None;
+    match target_definition {
+        Definition::Field(expected_field) if got_field == expected_field => {
+            Some((record_field, initializer))
         }
+        _ => None,
     }
-
-    Some((record_field, initializer))
 }
 
 fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
@@ -466,12 +462,9 @@ fn add_enum_def(
     let indent = IndentLevel::from_node(&insert_before);
     enum_def.reindent_to(indent);
 
-    ted::insert_all(
-        ted::Position::before(&edit.make_syntax_mut(insert_before)),
-        vec![
-            enum_def.syntax().clone().into(),
-            make::tokens::whitespace(&format!("\n\n{indent}")).into(),
-        ],
+    edit.insert(
+        insert_before.text_range().start(),
+        format!("{}\n\n{indent}", enum_def.syntax().text()),
     );
 }
 
@@ -801,6 +794,78 @@ fn main() {
     }
 
     #[test]
+    fn local_var_init_struct_usage() {
+        check_assist(
+            bool_to_enum,
+            r#"
+struct Foo {
+    foo: bool,
+}
+
+fn main() {
+    let $0foo = true;
+    let s = Foo { foo };
+}
+"#,
+            r#"
+struct Foo {
+    foo: bool,
+}
+
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+    let foo = Bool::True;
+    let s = Foo { foo: foo == Bool::True };
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn local_var_init_struct_usage_in_macro() {
+        check_assist(
+            bool_to_enum,
+            r#"
+struct Struct {
+    boolean: bool,
+}
+
+macro_rules! identity {
+    ($body:expr) => {
+        $body
+    }
+}
+
+fn new() -> Struct {
+    let $0boolean = true;
+    identity![Struct { boolean }]
+}
+"#,
+            r#"
+struct Struct {
+    boolean: bool,
+}
+
+macro_rules! identity {
+    ($body:expr) => {
+        $body
+    }
+}
+
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn new() -> Struct {
+    let boolean = Bool::True;
+    identity![Struct { boolean: boolean == Bool::True }]
+}
+"#,
+        )
+    }
+
+    #[test]
     fn field_struct_basic() {
         cov_mark::check!(replaces_record_expr);
         check_assist(
@@ -1322,6 +1387,46 @@ fn main() {
     }
 
     #[test]
+    fn field_in_macro() {
+        check_assist(
+            bool_to_enum,
+            r#"
+struct Struct {
+    $0boolean: bool,
+}
+
+fn boolean(x: Struct) {
+    let Struct { boolean } = x;
+}
+
+macro_rules! identity { ($body:expr) => { $body } }
+
+fn new() -> Struct {
+    identity!(Struct { boolean: true })
+}
+"#,
+            r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Struct {
+    boolean: Bool,
+}
+
+fn boolean(x: Struct) {
+    let Struct { boolean } = x;
+}
+
+macro_rules! identity { ($body:expr) => { $body } }
+
+fn new() -> Struct {
+    identity!(Struct { boolean: Bool::True })
+}
+"#,
+        )
+    }
+
+    #[test]
     fn field_non_bool() {
         cov_mark::check!(not_applicable_non_bool_field);
         check_assist_not_applicable(
diff --git a/crates/ide-assists/src/handlers/promote_local_to_const.rs b/crates/ide-assists/src/handlers/promote_local_to_const.rs
index 6ed9bd85fcc..67fea772c79 100644
--- a/crates/ide-assists/src/handlers/promote_local_to_const.rs
+++ b/crates/ide-assists/src/handlers/promote_local_to_const.rs
@@ -11,7 +11,10 @@ use syntax::{
     ted, AstNode, WalkEvent,
 };
 
-use crate::assist_context::{AssistContext, Assists};
+use crate::{
+    assist_context::{AssistContext, Assists},
+    utils,
+};
 
 // Assist: promote_local_to_const
 //
@@ -79,15 +82,13 @@ pub(crate) fn promote_local_to_const(acc: &mut Assists, ctx: &AssistContext<'_>)
                 let name_ref = make::name_ref(&name);
 
                 for usage in usages {
-                    let Some(usage) = usage.name.as_name_ref().cloned() else { continue };
-                    if let Some(record_field) = ast::RecordExprField::for_name_ref(&usage) {
-                        let record_field = edit.make_mut(record_field);
-                        let name_expr =
-                            make::expr_path(make::path_from_text(&name)).clone_for_update();
-                        record_field.replace_expr(name_expr);
+                    let Some(usage_name) = usage.name.as_name_ref().cloned() else { continue };
+                    if let Some(record_field) = ast::RecordExprField::for_name_ref(&usage_name) {
+                        let name_expr = make::expr_path(make::path_from_text(&name));
+                        utils::replace_record_field_expr(ctx, edit, record_field, name_expr);
                     } else {
-                        let usage = edit.make_mut(usage);
-                        ted::replace(usage.syntax(), name_ref.clone_for_update().syntax());
+                        let usage_range = usage.range;
+                        edit.replace(usage_range, name_ref.syntax().text());
                     }
                 }
             }
@@ -213,6 +214,76 @@ fn main() {
     }
 
     #[test]
+    fn usage_in_macro() {
+        check_assist(
+            promote_local_to_const,
+            r"
+macro_rules! identity {
+    ($body:expr) => {
+        $body
+    }
+}
+
+fn baz() -> usize {
+    let $0foo = 2;
+    identity![foo]
+}
+",
+            r"
+macro_rules! identity {
+    ($body:expr) => {
+        $body
+    }
+}
+
+fn baz() -> usize {
+    const $0FOO: usize = 2;
+    identity![FOO]
+}
+",
+        )
+    }
+
+    #[test]
+    fn usage_shorthand_in_macro() {
+        check_assist(
+            promote_local_to_const,
+            r"
+struct Foo {
+    foo: usize,
+}
+
+macro_rules! identity {
+    ($body:expr) => {
+        $body
+    };
+}
+
+fn baz() -> Foo {
+    let $0foo = 2;
+    identity![Foo { foo }]
+}
+",
+            r"
+struct Foo {
+    foo: usize,
+}
+
+macro_rules! identity {
+    ($body:expr) => {
+        $body
+    };
+}
+
+fn baz() -> Foo {
+    const $0FOO: usize = 2;
+    identity![Foo { foo: FOO }]
+}
+",
+        )
+    }
+
+    #[test]
     fn not_applicable_non_const_meth_call() {
         cov_mark::check!(promote_local_non_const);
         check_assist_not_applicable(
diff --git a/crates/ide-assists/src/utils.rs b/crates/ide-assists/src/utils.rs
index f51e99a914e..927a8e3c19a 100644
--- a/crates/ide-assists/src/utils.rs
+++ b/crates/ide-assists/src/utils.rs
@@ -813,3 +813,21 @@ fn test_required_hashes() {
     assert_eq!(3, required_hashes("#ab\"##c"));
     assert_eq!(5, required_hashes("#ab\"##\"####c"));
 }
+
+/// Replaces the record expression, handling field shorthands including inside macros.
+pub(crate) fn replace_record_field_expr(
+    ctx: &AssistContext<'_>,
+    edit: &mut SourceChangeBuilder,
+    record_field: ast::RecordExprField,
+    initializer: ast::Expr,
+) {
+    if let Some(ast::Expr::PathExpr(path_expr)) = record_field.expr() {
+        // replace field shorthand
+        let file_range = ctx.sema.original_range(path_expr.syntax());
+        edit.insert(file_range.range.end(), format!(": {}", initializer.syntax().text()))
+    } else if let Some(expr) = record_field.expr() {
+        // just replace expr
+        let file_range = ctx.sema.original_range(expr.syntax());
+        edit.replace(file_range.range, initializer.syntax().text());
+    }
+}