about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_module.rs175
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/ast/make.rs17
2 files changed, 126 insertions, 66 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_module.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_module.rs
index da91d0ac280..e783b869345 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_module.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_module.rs
@@ -1,6 +1,5 @@
-use std::iter;
+use std::ops::RangeInclusive;
 
-use either::Either;
 use hir::{HasSource, ModuleSource};
 use ide_db::{
     FileId, FxHashMap, FxHashSet,
@@ -82,7 +81,15 @@ pub(crate) fn extract_module(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
         curr_parent_module = ast::Module::cast(mod_syn_opt);
     }
 
-    let mut module = extract_target(&node, ctx.selection_trimmed())?;
+    let selection_range = ctx.selection_trimmed();
+    let (mut module, module_text_range) = if let Some(item) = ast::Item::cast(node.clone()) {
+        let module = extract_single_target(&item);
+        (module, node.text_range())
+    } else {
+        let (module, range) = extract_child_target(&node, selection_range)?;
+        let module_text_range = range.start().text_range().cover(range.end().text_range());
+        (module, module_text_range)
+    };
     if module.body_items.is_empty() {
         return None;
     }
@@ -92,7 +99,7 @@ pub(crate) fn extract_module(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
     acc.add(
         AssistId::refactor_extract("extract_module"),
         "Extract Module",
-        module.text_range,
+        module_text_range,
         |builder| {
             //This takes place in three steps:
             //
@@ -110,17 +117,17 @@ pub(crate) fn extract_module(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
             //for change_visibility and usages for first point mentioned above in the process
 
             let (usages_to_be_processed, record_fields, use_stmts_to_be_inserted) =
-                module.get_usages_and_record_fields(ctx);
+                module.get_usages_and_record_fields(ctx, module_text_range);
 
             builder.edit_file(ctx.vfs_file_id());
             use_stmts_to_be_inserted.into_iter().for_each(|(_, use_stmt)| {
                 builder.insert(ctx.selection_trimmed().end(), format!("\n{use_stmt}"));
             });
 
-            let import_paths_to_be_removed = module.resolve_imports(curr_parent_module, ctx);
+            let import_items = module.resolve_imports(curr_parent_module, ctx);
             module.change_visibility(record_fields);
 
-            let module_def = generate_module_def(&impl_parent, &mut module, old_item_indent);
+            let module_def = generate_module_def(&impl_parent, module, old_item_indent).to_string();
 
             let mut usages_to_be_processed_for_cur_file = vec![];
             for (file_id, usages) in usages_to_be_processed {
@@ -157,15 +164,12 @@ pub(crate) fn extract_module(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
 
                 builder.insert(impl_.syntax().text_range().end(), format!("\n\n{module_def}"));
             } else {
-                for import_path_text_range in import_paths_to_be_removed {
-                    if module.text_range.intersect(import_path_text_range).is_some() {
-                        module.text_range = module.text_range.cover(import_path_text_range);
-                    } else {
-                        builder.delete(import_path_text_range);
+                for import_item in import_items {
+                    if !module_text_range.contains_range(import_item) {
+                        builder.delete(import_item);
                     }
                 }
-
-                builder.replace(module.text_range, module_def)
+                builder.replace(module_text_range, module_def)
             }
         },
     )
@@ -173,38 +177,50 @@ pub(crate) fn extract_module(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
 
 fn generate_module_def(
     parent_impl: &Option<ast::Impl>,
-    module: &mut Module,
+    module: Module,
     old_indent: IndentLevel,
-) -> String {
-    let (items_to_be_processed, new_item_indent) = if parent_impl.is_some() {
-        (Either::Left(module.body_items.iter()), old_indent + 2)
+) -> ast::Module {
+    let Module { name, body_items, use_items } = module;
+    let items = if let Some(self_ty) = parent_impl.as_ref().and_then(|imp| imp.self_ty()) {
+        let assoc_items = body_items
+            .into_iter()
+            .map(|item| item.syntax().clone())
+            .filter_map(ast::AssocItem::cast)
+            .map(|it| it.indent(IndentLevel(1)))
+            .collect_vec();
+        let assoc_item_list = make::assoc_item_list(Some(assoc_items));
+        let impl_ = make::impl_(None, None, None, self_ty.clone(), None, Some(assoc_item_list));
+        // Add the import for enum/struct corresponding to given impl block
+        let use_impl = make_use_stmt_of_node_with_super(self_ty.syntax());
+        let mut module_body_items = use_items;
+        module_body_items.insert(0, use_impl);
+        module_body_items.push(ast::Item::Impl(impl_));
+        module_body_items
     } else {
-        (Either::Right(module.use_items.iter().chain(module.body_items.iter())), old_indent + 1)
+        [use_items, body_items].concat()
     };
 
-    let mut body = items_to_be_processed
-        .map(|item| item.indent(IndentLevel(1)))
-        .map(|item| format!("{new_item_indent}{item}"))
-        .join("\n\n");
+    let items = items.into_iter().map(|it| it.reset_indent().indent(IndentLevel(1))).collect_vec();
+    let module_body = make::item_list(Some(items));
 
-    if let Some(self_ty) = parent_impl.as_ref().and_then(|imp| imp.self_ty()) {
-        let impl_indent = old_indent + 1;
-        body = format!("{impl_indent}impl {self_ty} {{\n{body}\n{impl_indent}}}");
+    let module_name = make::name(name);
+    make::mod_(module_name, Some(module_body)).indent(old_indent)
+}
 
-        // Add the import for enum/struct corresponding to given impl block
-        module.make_use_stmt_of_node_with_super(self_ty.syntax());
-        for item in module.use_items.iter() {
-            body = format!("{impl_indent}{item}\n\n{body}");
-        }
-    }
+fn make_use_stmt_of_node_with_super(node_syntax: &SyntaxNode) -> ast::Item {
+    let super_path = make::ext::ident_path("super");
+    let node_path = make::ext::ident_path(&node_syntax.to_string());
+    let use_ = make::use_(
+        None,
+        None,
+        make::use_tree(make::join_paths(vec![super_path, node_path]), None, None, false),
+    );
 
-    let module_name = module.name;
-    format!("mod {module_name} {{\n{body}\n{old_indent}}}")
+    ast::Item::from(use_)
 }
 
 #[derive(Debug)]
 struct Module {
-    text_range: TextRange,
     name: &'static str,
     /// All items except use items.
     body_items: Vec<ast::Item>,
@@ -214,22 +230,37 @@ struct Module {
     use_items: Vec<ast::Item>,
 }
 
-fn extract_target(node: &SyntaxNode, selection_range: TextRange) -> Option<Module> {
+fn extract_single_target(node: &ast::Item) -> Module {
+    let (body_items, use_items) = if matches!(node, ast::Item::Use(_)) {
+        (Vec::new(), vec![node.clone()])
+    } else {
+        (vec![node.clone()], Vec::new())
+    };
+    let name = "modname";
+    Module { name, body_items, use_items }
+}
+
+fn extract_child_target(
+    node: &SyntaxNode,
+    selection_range: TextRange,
+) -> Option<(Module, RangeInclusive<SyntaxNode>)> {
     let selected_nodes = node
         .children()
         .filter(|node| selection_range.contains_range(node.text_range()))
-        .chain(iter::once(node.clone()));
-    let (use_items, body_items) = selected_nodes
         .filter_map(ast::Item::cast)
-        .partition(|item| matches!(item, ast::Item::Use(..)));
-
-    Some(Module { text_range: selection_range, name: "modname", body_items, use_items })
+        .collect_vec();
+    let start = selected_nodes.first()?.syntax().clone();
+    let end = selected_nodes.last()?.syntax().clone();
+    let (use_items, body_items): (Vec<ast::Item>, Vec<ast::Item>) =
+        selected_nodes.into_iter().partition(|item| matches!(item, ast::Item::Use(..)));
+    Some((Module { name: "modname", body_items, use_items }, start..=end))
 }
 
 impl Module {
     fn get_usages_and_record_fields(
         &self,
         ctx: &AssistContext<'_>,
+        replace_range: TextRange,
     ) -> (FxHashMap<FileId, Vec<(TextRange, String)>>, Vec<SyntaxNode>, FxHashMap<TextSize, ast::Use>)
     {
         let mut adt_fields = Vec::new();
@@ -247,7 +278,7 @@ impl Module {
                     ast::Adt(it) => {
                         if let Some( nod ) = ctx.sema.to_def(&it) {
                             let node_def = Definition::Adt(nod);
-                            self.expand_and_group_usages_file_wise(ctx, node_def, &mut refs, &mut use_stmts_to_be_inserted);
+                            self.expand_and_group_usages_file_wise(ctx, replace_range,node_def, &mut refs, &mut use_stmts_to_be_inserted);
 
                             //Enum Fields are not allowed to explicitly specify pub, it is implied
                             match it {
@@ -281,30 +312,30 @@ impl Module {
                     ast::TypeAlias(it) => {
                         if let Some( nod ) = ctx.sema.to_def(&it) {
                             let node_def = Definition::TypeAlias(nod);
-                            self.expand_and_group_usages_file_wise(ctx, node_def, &mut refs, &mut use_stmts_to_be_inserted);
+                            self.expand_and_group_usages_file_wise(ctx,replace_range, node_def, &mut refs, &mut use_stmts_to_be_inserted);
                         }
                     },
                     ast::Const(it) => {
                         if let Some( nod ) = ctx.sema.to_def(&it) {
                             let node_def = Definition::Const(nod);
-                            self.expand_and_group_usages_file_wise(ctx, node_def, &mut refs, &mut use_stmts_to_be_inserted);
+                            self.expand_and_group_usages_file_wise(ctx,replace_range, node_def, &mut refs, &mut use_stmts_to_be_inserted);
                         }
                     },
                     ast::Static(it) => {
                         if let Some( nod ) = ctx.sema.to_def(&it) {
                             let node_def = Definition::Static(nod);
-                            self.expand_and_group_usages_file_wise(ctx, node_def, &mut refs, &mut use_stmts_to_be_inserted);
+                            self.expand_and_group_usages_file_wise(ctx,replace_range, node_def, &mut refs, &mut use_stmts_to_be_inserted);
                         }
                     },
                     ast::Fn(it) => {
                         if let Some( nod ) = ctx.sema.to_def(&it) {
                             let node_def = Definition::Function(nod);
-                            self.expand_and_group_usages_file_wise(ctx, node_def, &mut refs, &mut use_stmts_to_be_inserted);
+                            self.expand_and_group_usages_file_wise(ctx,replace_range, node_def, &mut refs, &mut use_stmts_to_be_inserted);
                         }
                     },
                     ast::Macro(it) => {
                         if let Some(nod) = ctx.sema.to_def(&it) {
-                            self.expand_and_group_usages_file_wise(ctx, Definition::Macro(nod), &mut refs, &mut use_stmts_to_be_inserted);
+                            self.expand_and_group_usages_file_wise(ctx,replace_range, Definition::Macro(nod), &mut refs, &mut use_stmts_to_be_inserted);
                         }
                     },
                     _ => (),
@@ -318,6 +349,7 @@ impl Module {
     fn expand_and_group_usages_file_wise(
         &self,
         ctx: &AssistContext<'_>,
+        replace_range: TextRange,
         node_def: Definition,
         refs_in_files: &mut FxHashMap<FileId, Vec<(TextRange, String)>>,
         use_stmts_to_be_inserted: &mut FxHashMap<TextSize, ast::Use>,
@@ -327,7 +359,7 @@ impl Module {
             syntax::NodeOrToken::Node(node) => node,
             syntax::NodeOrToken::Token(tok) => tok.parent().unwrap(), // won't panic
         };
-        let out_of_sel = |node: &SyntaxNode| !self.text_range.contains_range(node.text_range());
+        let out_of_sel = |node: &SyntaxNode| !replace_range.contains_range(node.text_range());
         let mut use_stmts_set = FxHashSet::default();
 
         for (file_id, refs) in node_def.usages(&ctx.sema).all() {
@@ -527,7 +559,8 @@ impl Module {
                     // mod -> ust_stmt transversal
                     // true  | false -> super import insertion
                     // true  | true -> super import insertion
-                    self.make_use_stmt_of_node_with_super(use_node);
+                    let super_use_node = make_use_stmt_of_node_with_super(use_node);
+                    self.use_items.insert(0, super_use_node);
                 }
                 None => {}
             }
@@ -556,7 +589,8 @@ impl Module {
 
                 use_tree_paths = Some(use_tree_str);
             } else if def_in_mod && def_out_sel {
-                self.make_use_stmt_of_node_with_super(use_node);
+                let super_use_node = make_use_stmt_of_node_with_super(use_node);
+                self.use_items.insert(0, super_use_node);
             }
         }
 
@@ -596,20 +630,6 @@ impl Module {
         import_path_to_be_removed
     }
 
-    fn make_use_stmt_of_node_with_super(&mut self, node_syntax: &SyntaxNode) -> ast::Item {
-        let super_path = make::ext::ident_path("super");
-        let node_path = make::ext::ident_path(&node_syntax.to_string());
-        let use_ = make::use_(
-            None,
-            None,
-            make::use_tree(make::join_paths(vec![super_path, node_path]), None, None, false),
-        );
-
-        let item = ast::Item::from(use_);
-        self.use_items.insert(0, item.clone());
-        item
-    }
-
     fn process_use_stmt_for_import_resolve(
         &self,
         use_stmt: Option<ast::Use>,
@@ -1424,10 +1444,10 @@ $0fn foo(x: B) {}$0
             struct B {}
 
 mod modname {
-    use super::B;
-
     use super::A;
 
+    use super::B;
+
     impl A {
         pub(crate) fn foo(x: B) {}
     }
@@ -1739,4 +1759,27 @@ fn main() {
 "#,
         );
     }
+
+    #[test]
+    fn test_miss_select_item() {
+        check_assist(
+            extract_module,
+            r#"
+mod foo {
+    mod $0bar {
+        fn foo(){}$0
+    }
+}
+"#,
+            r#"
+mod foo {
+    mod modname {
+        pub(crate) mod bar {
+            fn foo(){}
+        }
+    }
+}
+"#,
+        )
+    }
 }
diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs
index c5ca6097601..9897fd09415 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs
@@ -231,6 +231,23 @@ pub fn ty_fn_ptr<I: Iterator<Item = Param>>(
     }
 }
 
+pub fn item_list(body: Option<Vec<ast::Item>>) -> ast::ItemList {
+    let is_break_braces = body.is_some();
+    let body_newline = if is_break_braces { "\n" } else { "" };
+    let body_indent = if is_break_braces { "    " } else { "" };
+
+    let body = match body {
+        Some(bd) => bd.iter().map(|elem| elem.to_string()).join("\n\n    "),
+        None => String::new(),
+    };
+    ast_from_text(&format!("mod C {{{body_newline}{body_indent}{body}{body_newline}}}"))
+}
+
+pub fn mod_(name: ast::Name, body: Option<ast::ItemList>) -> ast::Module {
+    let body = body.map_or(";".to_owned(), |body| format!(" {body}"));
+    ast_from_text(&format!("mod {name}{body}"))
+}
+
 pub fn assoc_item_list(body: Option<Vec<ast::AssocItem>>) -> ast::AssocItemList {
     let is_break_braces = body.is_some();
     let body_newline = if is_break_braces { "\n".to_owned() } else { String::new() };