about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-assists/src/handlers/add_missing_match_arms.rs133
-rw-r--r--crates/ide-assists/src/handlers/fix_visibility.rs105
-rw-r--r--crates/ide-assists/src/handlers/promote_local_to_const.rs56
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs21
4 files changed, 182 insertions, 133 deletions
diff --git a/crates/ide-assists/src/handlers/add_missing_match_arms.rs b/crates/ide-assists/src/handlers/add_missing_match_arms.rs
index 7384390f28b..ac0b74ee8e7 100644
--- a/crates/ide-assists/src/handlers/add_missing_match_arms.rs
+++ b/crates/ide-assists/src/handlers/add_missing_match_arms.rs
@@ -8,10 +8,7 @@ use itertools::Itertools;
 use syntax::ast::edit_in_place::Removable;
 use syntax::ast::{self, make, AstNode, HasName, MatchArmList, MatchExpr, Pat};
 
-use crate::{
-    utils::{self, render_snippet, Cursor},
-    AssistContext, AssistId, AssistKind, Assists,
-};
+use crate::{utils, AssistContext, AssistId, AssistKind, Assists};
 
 // Assist: add_missing_match_arms
 //
@@ -75,14 +72,18 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
         .collect();
 
     let module = ctx.sema.scope(expr.syntax())?.module();
-    let (mut missing_pats, is_non_exhaustive): (
+    let (mut missing_pats, is_non_exhaustive, has_hidden_variants): (
         Peekable<Box<dyn Iterator<Item = (ast::Pat, bool)>>>,
         bool,
+        bool,
     ) = if let Some(enum_def) = resolve_enum_def(&ctx.sema, &expr) {
         let is_non_exhaustive = enum_def.is_non_exhaustive(ctx.db(), module.krate());
 
         let variants = enum_def.variants(ctx.db());
 
+        let has_hidden_variants =
+            variants.iter().any(|variant| variant.should_be_hidden(ctx.db(), module.krate()));
+
         let missing_pats = variants
             .into_iter()
             .filter_map(|variant| {
@@ -101,7 +102,7 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
         } else {
             Box::new(missing_pats)
         };
-        (missing_pats.peekable(), is_non_exhaustive)
+        (missing_pats.peekable(), is_non_exhaustive, has_hidden_variants)
     } else if let Some(enum_defs) = resolve_tuple_of_enum_def(&ctx.sema, &expr) {
         let is_non_exhaustive =
             enum_defs.iter().any(|enum_def| enum_def.is_non_exhaustive(ctx.db(), module.krate()));
@@ -124,6 +125,12 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
         if n_arms > 256 {
             return None;
         }
+
+        let has_hidden_variants = variants_of_enums
+            .iter()
+            .flatten()
+            .any(|variant| variant.should_be_hidden(ctx.db(), module.krate()));
+
         let missing_pats = variants_of_enums
             .into_iter()
             .multi_cartesian_product()
@@ -139,7 +146,11 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
                 (ast::Pat::from(make::tuple_pat(patterns)), is_hidden)
             })
             .filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat));
-        ((Box::new(missing_pats) as Box<dyn Iterator<Item = _>>).peekable(), is_non_exhaustive)
+        (
+            (Box::new(missing_pats) as Box<dyn Iterator<Item = _>>).peekable(),
+            is_non_exhaustive,
+            has_hidden_variants,
+        )
     } else if let Some((enum_def, len)) = resolve_array_of_enum_def(&ctx.sema, &expr) {
         let is_non_exhaustive = enum_def.is_non_exhaustive(ctx.db(), module.krate());
         let variants = enum_def.variants(ctx.db());
@@ -148,6 +159,9 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
             return None;
         }
 
+        let has_hidden_variants =
+            variants.iter().any(|variant| variant.should_be_hidden(ctx.db(), module.krate()));
+
         let variants_of_enums = vec![variants; len];
 
         let missing_pats = variants_of_enums
@@ -164,14 +178,20 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
                 (ast::Pat::from(make::slice_pat(patterns)), is_hidden)
             })
             .filter(|(variant_pat, _)| is_variant_missing(&top_lvl_pats, variant_pat));
-        ((Box::new(missing_pats) as Box<dyn Iterator<Item = _>>).peekable(), is_non_exhaustive)
+        (
+            (Box::new(missing_pats) as Box<dyn Iterator<Item = _>>).peekable(),
+            is_non_exhaustive,
+            has_hidden_variants,
+        )
     } else {
         return None;
     };
 
     let mut needs_catch_all_arm = is_non_exhaustive && !has_catch_all_arm;
 
-    if !needs_catch_all_arm && missing_pats.peek().is_none() {
+    if !needs_catch_all_arm
+        && ((has_hidden_variants && has_catch_all_arm) || missing_pats.peek().is_none())
+    {
         return None;
     }
 
@@ -179,13 +199,21 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
         AssistId("add_missing_match_arms", AssistKind::QuickFix),
         "Fill match arms",
         target_range,
-        |builder| {
+        |edit| {
             let new_match_arm_list = match_arm_list.clone_for_update();
+
+            // having any hidden variants means that we need a catch-all arm
+            needs_catch_all_arm |= has_hidden_variants;
+
             let missing_arms = missing_pats
-                .map(|(pat, hidden)| {
-                    (make::match_arm(iter::once(pat), None, make::ext::expr_todo()), hidden)
+                .filter(|(_, hidden)| {
+                    // filter out hidden patterns because they're handled by the catch-all arm
+                    !hidden
                 })
-                .map(|(it, hidden)| (it.clone_for_update(), hidden));
+                .map(|(pat, _)| {
+                    make::match_arm(iter::once(pat), None, make::ext::expr_todo())
+                        .clone_for_update()
+                });
 
             let catch_all_arm = new_match_arm_list
                 .arms()
@@ -204,15 +232,13 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
                     cov_mark::hit!(add_missing_match_arms_empty_expr);
                 }
             }
+
             let mut first_new_arm = None;
-            for (arm, hidden) in missing_arms {
-                if hidden {
-                    needs_catch_all_arm = !has_catch_all_arm;
-                } else {
-                    first_new_arm.get_or_insert_with(|| arm.clone());
-                    new_match_arm_list.add_arm(arm);
-                }
+            for arm in missing_arms {
+                first_new_arm.get_or_insert_with(|| arm.clone());
+                new_match_arm_list.add_arm(arm);
             }
+
             if needs_catch_all_arm && !has_catch_all_arm {
                 cov_mark::hit!(added_wildcard_pattern);
                 let arm = make::match_arm(
@@ -225,24 +251,39 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>)
                 new_match_arm_list.add_arm(arm);
             }
 
-            let old_range = ctx.sema.original_range(match_arm_list.syntax()).range;
-            match (first_new_arm, ctx.config.snippet_cap) {
-                (Some(first_new_arm), Some(cap)) => {
-                    let extend_lifetime;
-                    let cursor =
-                        match first_new_arm.syntax().descendants().find_map(ast::WildcardPat::cast)
-                        {
-                            Some(it) => {
-                                extend_lifetime = it.syntax().clone();
-                                Cursor::Replace(&extend_lifetime)
-                            }
-                            None => Cursor::Before(first_new_arm.syntax()),
-                        };
-                    let snippet = render_snippet(cap, new_match_arm_list.syntax(), cursor);
-                    builder.replace_snippet(cap, old_range, snippet);
+            if let (Some(first_new_arm), Some(cap)) = (first_new_arm, ctx.config.snippet_cap) {
+                match first_new_arm.syntax().descendants().find_map(ast::WildcardPat::cast) {
+                    Some(it) => edit.add_placeholder_snippet(cap, it),
+                    None => edit.add_tabstop_before(cap, first_new_arm),
                 }
-                _ => builder.replace(old_range, new_match_arm_list.to_string()),
             }
+
+            // FIXME: Hack for mutable syntax trees not having great support for macros
+            // Just replace the element that the original range came from
+            let old_place = {
+                // Find the original element
+                let old_file_range = ctx.sema.original_range(match_arm_list.syntax());
+                let file = ctx.sema.parse(old_file_range.file_id);
+                let old_place = file.syntax().covering_element(old_file_range.range);
+
+                // Make `old_place` mut
+                match old_place {
+                    syntax::SyntaxElement::Node(it) => {
+                        syntax::SyntaxElement::from(edit.make_syntax_mut(it))
+                    }
+                    syntax::SyntaxElement::Token(it) => {
+                        // Don't have a way to make tokens mut, so instead make the parent mut
+                        // and find the token again
+                        let parent = edit.make_syntax_mut(it.parent().unwrap());
+                        let mut_token =
+                            parent.covering_element(it.text_range()).into_token().unwrap();
+
+                        syntax::SyntaxElement::from(mut_token)
+                    }
+                }
+            };
+
+            syntax::ted::replace(old_place, new_match_arm_list.syntax());
         },
     )
 }
@@ -1621,10 +1662,9 @@ pub enum E { #[doc(hidden)] A, }
         );
     }
 
-    // FIXME: I don't think the assist should be applicable in this case
     #[test]
     fn does_not_fill_wildcard_with_wildcard() {
-        check_assist(
+        check_assist_not_applicable(
             add_missing_match_arms,
             r#"
 //- /main.rs crate:main deps:e
@@ -1636,13 +1676,6 @@ fn foo(t: ::e::E) {
 //- /e.rs crate:e
 pub enum E { #[doc(hidden)] A, }
 "#,
-            r#"
-fn foo(t: ::e::E) {
-    match t {
-        _ => todo!(),
-    }
-}
-"#,
         );
     }
 
@@ -1777,7 +1810,7 @@ fn foo(t: ::e::E, b: bool) {
 
     #[test]
     fn does_not_fill_wildcard_with_partial_wildcard_and_wildcard() {
-        check_assist(
+        check_assist_not_applicable(
             add_missing_match_arms,
             r#"
 //- /main.rs crate:main deps:e
@@ -1789,14 +1822,6 @@ fn foo(t: ::e::E, b: bool) {
 }
 //- /e.rs crate:e
 pub enum E { #[doc(hidden)] A, }"#,
-            r#"
-fn foo(t: ::e::E, b: bool) {
-    match t {
-        _ if b => todo!(),
-        _ => todo!(),
-    }
-}
-"#,
         );
     }
 
diff --git a/crates/ide-assists/src/handlers/fix_visibility.rs b/crates/ide-assists/src/handlers/fix_visibility.rs
index 4a2ab18c983..c9f272474e7 100644
--- a/crates/ide-assists/src/handlers/fix_visibility.rs
+++ b/crates/ide-assists/src/handlers/fix_visibility.rs
@@ -1,11 +1,11 @@
 use hir::{db::HirDatabase, HasSource, HasVisibility, ModuleDef, PathResolution, ScopeDef};
 use ide_db::base_db::FileId;
 use syntax::{
-    ast::{self, HasVisibility as _},
-    AstNode, TextRange, TextSize,
+    ast::{self, edit_in_place::HasVisibilityEdit, make, HasVisibility as _},
+    AstNode, TextRange,
 };
 
-use crate::{utils::vis_offset, AssistContext, AssistId, AssistKind, Assists};
+use crate::{AssistContext, AssistId, AssistKind, Assists};
 
 // FIXME: this really should be a fix for diagnostic, rather than an assist.
 
@@ -58,11 +58,13 @@ fn add_vis_to_referenced_module_def(acc: &mut Assists, ctx: &AssistContext<'_>)
         return None;
     };
 
-    let (offset, current_visibility, target, target_file, target_name) =
-        target_data_for_def(ctx.db(), def)?;
+    let (vis_owner, target, target_file, target_name) = target_data_for_def(ctx.db(), def)?;
 
-    let missing_visibility =
-        if current_module.krate() == target_module.krate() { "pub(crate)" } else { "pub" };
+    let missing_visibility = if current_module.krate() == target_module.krate() {
+        make::visibility_pub_crate()
+    } else {
+        make::visibility_pub()
+    };
 
     let assist_label = match target_name {
         None => format!("Change visibility to {missing_visibility}"),
@@ -71,23 +73,14 @@ fn add_vis_to_referenced_module_def(acc: &mut Assists, ctx: &AssistContext<'_>)
         }
     };
 
-    acc.add(AssistId("fix_visibility", AssistKind::QuickFix), assist_label, target, |builder| {
-        builder.edit_file(target_file);
-        match ctx.config.snippet_cap {
-            Some(cap) => match current_visibility {
-                Some(current_visibility) => builder.replace_snippet(
-                    cap,
-                    current_visibility.syntax().text_range(),
-                    format!("$0{missing_visibility}"),
-                ),
-                None => builder.insert_snippet(cap, offset, format!("$0{missing_visibility} ")),
-            },
-            None => match current_visibility {
-                Some(current_visibility) => {
-                    builder.replace(current_visibility.syntax().text_range(), missing_visibility)
-                }
-                None => builder.insert(offset, format!("{missing_visibility} ")),
-            },
+    acc.add(AssistId("fix_visibility", AssistKind::QuickFix), assist_label, target, |edit| {
+        edit.edit_file(target_file);
+
+        let vis_owner = edit.make_mut(vis_owner);
+        vis_owner.set_visibility(missing_visibility.clone_for_update());
+
+        if let Some((cap, vis)) = ctx.config.snippet_cap.zip(vis_owner.visibility()) {
+            edit.add_tabstop_before(cap, vis);
         }
     })
 }
@@ -107,19 +100,22 @@ fn add_vis_to_referenced_record_field(acc: &mut Assists, ctx: &AssistContext<'_>
     let target_module = parent.module(ctx.db());
 
     let in_file_source = record_field_def.source(ctx.db())?;
-    let (offset, current_visibility, target) = match in_file_source.value {
+    let (vis_owner, target) = match in_file_source.value {
         hir::FieldSource::Named(it) => {
-            let s = it.syntax();
-            (vis_offset(s), it.visibility(), s.text_range())
+            let range = it.syntax().text_range();
+            (ast::AnyHasVisibility::new(it), range)
         }
         hir::FieldSource::Pos(it) => {
-            let s = it.syntax();
-            (vis_offset(s), it.visibility(), s.text_range())
+            let range = it.syntax().text_range();
+            (ast::AnyHasVisibility::new(it), range)
         }
     };
 
-    let missing_visibility =
-        if current_module.krate() == target_module.krate() { "pub(crate)" } else { "pub" };
+    let missing_visibility = if current_module.krate() == target_module.krate() {
+        make::visibility_pub_crate()
+    } else {
+        make::visibility_pub()
+    };
     let target_file = in_file_source.file_id.original_file(ctx.db());
 
     let target_name = record_field_def.name(ctx.db());
@@ -129,23 +125,14 @@ fn add_vis_to_referenced_record_field(acc: &mut Assists, ctx: &AssistContext<'_>
         target_name.display(ctx.db())
     );
 
-    acc.add(AssistId("fix_visibility", AssistKind::QuickFix), assist_label, target, |builder| {
-        builder.edit_file(target_file);
-        match ctx.config.snippet_cap {
-            Some(cap) => match current_visibility {
-                Some(current_visibility) => builder.replace_snippet(
-                    cap,
-                    current_visibility.syntax().text_range(),
-                    format!("$0{missing_visibility}"),
-                ),
-                None => builder.insert_snippet(cap, offset, format!("$0{missing_visibility} ")),
-            },
-            None => match current_visibility {
-                Some(current_visibility) => {
-                    builder.replace(current_visibility.syntax().text_range(), missing_visibility)
-                }
-                None => builder.insert(offset, format!("{missing_visibility} ")),
-            },
+    acc.add(AssistId("fix_visibility", AssistKind::QuickFix), assist_label, target, |edit| {
+        edit.edit_file(target_file);
+
+        let vis_owner = edit.make_mut(vis_owner);
+        vis_owner.set_visibility(missing_visibility.clone_for_update());
+
+        if let Some((cap, vis)) = ctx.config.snippet_cap.zip(vis_owner.visibility()) {
+            edit.add_tabstop_before(cap, vis);
         }
     })
 }
@@ -153,11 +140,11 @@ fn add_vis_to_referenced_record_field(acc: &mut Assists, ctx: &AssistContext<'_>
 fn target_data_for_def(
     db: &dyn HirDatabase,
     def: hir::ModuleDef,
-) -> Option<(TextSize, Option<ast::Visibility>, TextRange, FileId, Option<hir::Name>)> {
+) -> Option<(ast::AnyHasVisibility, TextRange, FileId, Option<hir::Name>)> {
     fn offset_target_and_file_id<S, Ast>(
         db: &dyn HirDatabase,
         x: S,
-    ) -> Option<(TextSize, Option<ast::Visibility>, TextRange, FileId)>
+    ) -> Option<(ast::AnyHasVisibility, TextRange, FileId)>
     where
         S: HasSource<Ast = Ast>,
         Ast: AstNode + ast::HasVisibility,
@@ -165,18 +152,12 @@ fn target_data_for_def(
         let source = x.source(db)?;
         let in_file_syntax = source.syntax();
         let file_id = in_file_syntax.file_id;
-        let syntax = in_file_syntax.value;
-        let current_visibility = source.value.visibility();
-        Some((
-            vis_offset(syntax),
-            current_visibility,
-            syntax.text_range(),
-            file_id.original_file(db.upcast()),
-        ))
+        let range = in_file_syntax.value.text_range();
+        Some((ast::AnyHasVisibility::new(source.value), range, file_id.original_file(db.upcast())))
     }
 
     let target_name;
-    let (offset, current_visibility, target, target_file) = match def {
+    let (offset, target, target_file) = match def {
         hir::ModuleDef::Function(f) => {
             target_name = Some(f.name(db));
             offset_target_and_file_id(db, f)?
@@ -213,8 +194,8 @@ fn target_data_for_def(
             target_name = m.name(db);
             let in_file_source = m.declaration_source(db)?;
             let file_id = in_file_source.file_id.original_file(db.upcast());
-            let syntax = in_file_source.value.syntax();
-            (vis_offset(syntax), in_file_source.value.visibility(), syntax.text_range(), file_id)
+            let range = in_file_source.value.syntax().text_range();
+            (ast::AnyHasVisibility::new(in_file_source.value), range, file_id)
         }
         // FIXME
         hir::ModuleDef::Macro(_) => return None,
@@ -222,7 +203,7 @@ fn target_data_for_def(
         hir::ModuleDef::Variant(_) | hir::ModuleDef::BuiltinType(_) => return None,
     };
 
-    Some((offset, current_visibility, target, target_file, target_name))
+    Some((offset, target, target_file, target_name))
 }
 
 #[cfg(test)]
diff --git a/crates/ide-assists/src/handlers/promote_local_to_const.rs b/crates/ide-assists/src/handlers/promote_local_to_const.rs
index 23153b4c566..5cc110cf12b 100644
--- a/crates/ide-assists/src/handlers/promote_local_to_const.rs
+++ b/crates/ide-assists/src/handlers/promote_local_to_const.rs
@@ -8,13 +8,10 @@ use ide_db::{
 use stdx::to_upper_snake_case;
 use syntax::{
     ast::{self, make, HasName},
-    AstNode, WalkEvent,
+    ted, AstNode, WalkEvent,
 };
 
-use crate::{
-    assist_context::{AssistContext, Assists},
-    utils::{render_snippet, Cursor},
-};
+use crate::assist_context::{AssistContext, Assists};
 
 // Assist: promote_local_to_const
 //
@@ -70,29 +67,33 @@ pub(crate) fn promote_local_to_const(acc: &mut Assists, ctx: &AssistContext<'_>)
         cov_mark::hit!(promote_local_non_const);
         return None;
     }
-    let target = let_stmt.syntax().text_range();
+
     acc.add(
         AssistId("promote_local_to_const", AssistKind::Refactor),
         "Promote local to constant",
-        target,
-        |builder| {
+        let_stmt.syntax().text_range(),
+        |edit| {
             let name = to_upper_snake_case(&name.to_string());
             let usages = Definition::Local(local).usages(&ctx.sema).all();
             if let Some(usages) = usages.references.get(&ctx.file_id()) {
+                let name = make::name_ref(&name);
+
                 for usage in usages {
-                    builder.replace(usage.range, &name);
+                    let Some(usage) = usage.name.as_name_ref().cloned() else { continue };
+                    let usage = edit.make_mut(usage);
+                    ted::replace(usage.syntax(), name.clone_for_update().syntax());
                 }
             }
 
-            let item = make::item_const(None, make::name(&name), make::ty(&ty), initializer);
-            match ctx.config.snippet_cap.zip(item.name()) {
-                Some((cap, name)) => builder.replace_snippet(
-                    cap,
-                    target,
-                    render_snippet(cap, item.syntax(), Cursor::Before(name.syntax())),
-                ),
-                None => builder.replace(target, item.to_string()),
+            let item = make::item_const(None, make::name(&name), make::ty(&ty), initializer)
+                .clone_for_update();
+            let let_stmt = edit.make_mut(let_stmt);
+
+            if let Some((cap, name)) = ctx.config.snippet_cap.zip(item.name()) {
+                edit.add_tabstop_before(cap, name);
             }
+
+            ted::replace(let_stmt.syntax(), item.syntax());
         },
     )
 }
@@ -158,6 +159,27 @@ fn foo() {
     }
 
     #[test]
+    fn multiple_uses() {
+        check_assist(
+            promote_local_to_const,
+            r"
+fn foo() {
+    let x$0 = 0;
+    let y = x;
+    let z = (x, x, x, x);
+}
+",
+            r"
+fn foo() {
+    const $0X: i32 = 0;
+    let y = X;
+    let z = (X, X, X, X);
+}
+",
+        );
+    }
+
+    #[test]
     fn not_applicable_non_const_meth_call() {
         cov_mark::check!(promote_local_non_const);
         check_assist_not_applicable(
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index 1f5ed206b4a..606804aea25 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -748,6 +748,27 @@ fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> {
     Some(())
 }
 
+pub trait HasVisibilityEdit: ast::HasVisibility {
+    fn set_visibility(&self, visbility: ast::Visibility) {
+        match self.visibility() {
+            Some(current_visibility) => {
+                ted::replace(current_visibility.syntax(), visbility.syntax())
+            }
+            None => {
+                let vis_before = self
+                    .syntax()
+                    .children_with_tokens()
+                    .find(|it| !matches!(it.kind(), WHITESPACE | COMMENT | ATTR))
+                    .unwrap_or_else(|| self.syntax().first_child_or_token().unwrap());
+
+                ted::insert(ted::Position::before(vis_before), visbility.syntax());
+            }
+        }
+    }
+}
+
+impl<T: ast::HasVisibility> HasVisibilityEdit for T {}
+
 pub trait Indent: AstNode + Clone + Sized {
     fn indent_level(&self) -> IndentLevel {
         IndentLevel::from_node(self.syntax())