about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRyan Mehri <ryan.mehri1@gmail.com>2023-09-09 11:40:29 -0700
committerRyan Mehri <ryan.mehri1@gmail.com>2023-09-09 11:59:59 -0700
commit2e13aed3bc235d47d92f9ce3b8fd4fa3c5f87939 (patch)
tree6c41e0a7322aa40e10869f687a843a6fc412ffa7
parent136a9dbe36606cb00b546c3562088c462d8a0926 (diff)
downloadrust-2e13aed3bc235d47d92f9ce3b8fd4fa3c5f87939.tar.gz
rust-2e13aed3bc235d47d92f9ce3b8fd4fa3c5f87939.zip
feat: support cross module imports
-rw-r--r--crates/ide-assists/src/handlers/bool_to_enum.rs226
1 files changed, 214 insertions, 12 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs
index 97522648440..f59b0528131 100644
--- a/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -1,9 +1,13 @@
+use hir::ModuleDef;
 use ide_db::{
     assists::{AssistId, AssistKind},
     defs::Definition,
-    search::{FileReference, SearchScope, UsageSearchResult},
+    helpers::mod_path_to_ast,
+    imports::insert_use::{insert_use, ImportScope},
+    search::{FileReference, UsageSearchResult},
     source_change::SourceChangeBuilder,
 };
+use itertools::Itertools;
 use syntax::{
     ast::{
         self,
@@ -48,6 +52,7 @@ use crate::assist_context::{AssistContext, Assists};
 pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
     let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
         find_bool_node(ctx)?;
+    let target_module = ctx.sema.scope(&target_node)?.module();
 
     let target = name.syntax().text_range();
     acc.add(
@@ -64,13 +69,10 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
                 replace_bool_expr(edit, initializer);
             }
 
-            let usages = definition
-                .usages(&ctx.sema)
-                .in_scope(&SearchScope::single_file(ctx.file_id()))
-                .all();
-            replace_usages(edit, &usages);
+            let usages = definition.usages(&ctx.sema).all();
 
-            add_enum_def(edit, ctx, &usages, target_node);
+            add_enum_def(edit, ctx, &usages, target_node, &target_module);
+            replace_usages(edit, ctx, &usages, &target_module);
         },
     )
 }
@@ -186,8 +188,45 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
 }
 
 /// Replaces all usages of the target identifier, both when read and written to.
-fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) {
-    for (_, references) in usages.iter() {
+fn replace_usages(
+    edit: &mut SourceChangeBuilder,
+    ctx: &AssistContext<'_>,
+    usages: &UsageSearchResult,
+    target_module: &hir::Module,
+) {
+    for (file_id, references) in usages.iter() {
+        edit.edit_file(*file_id);
+
+        // add imports across modules where needed
+        references
+            .iter()
+            .filter_map(|FileReference { name, .. }| {
+                ctx.sema.scope(name.syntax()).map(|scope| (name, scope.module()))
+            })
+            .unique_by(|name_and_module| name_and_module.1)
+            .filter(|(_, module)| module != target_module)
+            .filter_map(|(name, module)| {
+                let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema);
+                let mod_path = module.find_use_path_prefixed(
+                    ctx.sema.db,
+                    ModuleDef::Module(*target_module),
+                    ctx.config.insert_use.prefix_kind,
+                    ctx.config.prefer_no_std,
+                );
+                import_scope.zip(mod_path)
+            })
+            .for_each(|(import_scope, mod_path)| {
+                let import_scope = match import_scope {
+                    ImportScope::File(it) => ImportScope::File(edit.make_mut(it)),
+                    ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)),
+                    ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)),
+                };
+                let path =
+                    make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool"));
+                insert_use(&import_scope, path, &ctx.config.insert_use);
+            });
+
+        // replace the usages in expressions
         references
             .into_iter()
             .filter_map(|FileReference { range, name, .. }| match name {
@@ -213,7 +252,7 @@ fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) {
                     let record_field = edit.make_mut(record_field);
                     let enum_expr = bool_expr_to_enum_expr(initializer);
                     record_field.replace_expr(enum_expr);
-                } else if name_ref.syntax().ancestors().find_map(ast::Expr::cast).is_some() {
+                } else if name_ref.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
                     // for any other usage in an expression, replace it with a check that it is the true variant
                     edit.replace(range, format!("{} == Bool::True", name_ref.text()));
                 }
@@ -255,8 +294,15 @@ fn add_enum_def(
     ctx: &AssistContext<'_>,
     usages: &UsageSearchResult,
     target_node: SyntaxNode,
+    target_module: &hir::Module,
 ) {
-    let make_enum_pub = usages.iter().any(|(file_id, _)| file_id != &ctx.file_id());
+    let make_enum_pub = usages
+        .iter()
+        .flat_map(|(_, refs)| refs)
+        .filter_map(|FileReference { name, .. }| {
+            ctx.sema.scope(name.syntax()).map(|scope| scope.module())
+        })
+        .any(|module| &module != target_module);
     let enum_def = make_bool_enum(make_enum_pub);
 
     let indent = IndentLevel::from_node(&target_node);
@@ -649,7 +695,7 @@ fn main() {
 "#,
             r#"
 #[derive(PartialEq, Eq)]
-enum $0Bool { True, False }
+enum Bool { True, False }
 
 struct Foo {
     bar: Bool,
@@ -714,6 +760,162 @@ fn main() {
     }
 
     #[test]
+    fn const_in_module() {
+        check_assist(
+            bool_to_enum,
+            r#"
+fn main() {
+    if foo::FOO {
+        println!("foo");
+    }
+}
+
+mod foo {
+    pub const $0FOO: bool = true;
+}
+"#,
+            r#"
+use foo::Bool;
+
+fn main() {
+    if foo::FOO == Bool::True {
+        println!("foo");
+    }
+}
+
+mod foo {
+    #[derive(PartialEq, Eq)]
+    pub enum Bool { True, False }
+
+    pub const FOO: Bool = Bool::True;
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn const_in_module_with_import() {
+        check_assist(
+            bool_to_enum,
+            r#"
+fn main() {
+    use foo::FOO;
+
+    if FOO {
+        println!("foo");
+    }
+}
+
+mod foo {
+    pub const $0FOO: bool = true;
+}
+"#,
+            r#"
+use crate::foo::Bool;
+
+fn main() {
+    use foo::FOO;
+
+    if FOO == Bool::True {
+        println!("foo");
+    }
+}
+
+mod foo {
+    #[derive(PartialEq, Eq)]
+    pub enum Bool { True, False }
+
+    pub const FOO: Bool = Bool::True;
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn const_cross_file() {
+        check_assist(
+            bool_to_enum,
+            r#"
+//- /main.rs
+mod foo;
+
+fn main() {
+    if foo::FOO {
+        println!("foo");
+    }
+}
+
+//- /foo.rs
+pub const $0FOO: bool = true;
+"#,
+            r#"
+//- /main.rs
+use foo::Bool;
+
+mod foo;
+
+fn main() {
+    if foo::FOO == Bool::True {
+        println!("foo");
+    }
+}
+
+//- /foo.rs
+#[derive(PartialEq, Eq)]
+pub enum Bool { True, False }
+
+pub const FOO: Bool = Bool::True;
+"#,
+        )
+    }
+
+    #[test]
+    fn const_cross_file_and_module() {
+        check_assist(
+            bool_to_enum,
+            r#"
+//- /main.rs
+mod foo;
+
+fn main() {
+    use foo::bar;
+
+    if bar::BAR {
+        println!("foo");
+    }
+}
+
+//- /foo.rs
+pub mod bar {
+    pub const $0BAR: bool = false;
+}
+"#,
+            r#"
+//- /main.rs
+use crate::foo::bar::Bool;
+
+mod foo;
+
+fn main() {
+    use foo::bar;
+
+    if bar::BAR == Bool::True {
+        println!("foo");
+    }
+}
+
+//- /foo.rs
+pub mod bar {
+    #[derive(PartialEq, Eq)]
+    pub enum Bool { True, False }
+
+    pub const BAR: Bool = Bool::False;
+}
+"#,
+        )
+    }
+
+    #[test]
     fn const_non_bool() {
         cov_mark::check!(not_applicable_non_bool_const);
         check_assist_not_applicable(