about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir/src/semantics.rs9
-rw-r--r--crates/hir_expand/src/lib.rs56
-rw-r--r--crates/ide_assists/src/handlers/auto_import.rs56
-rw-r--r--crates/ide_assists/src/handlers/extract_function.rs2
-rw-r--r--crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs2
-rw-r--r--crates/ide_assists/src/handlers/replace_qualified_name_with_use.rs2
-rw-r--r--crates/ide_assists/src/tests.rs5
-rw-r--r--crates/ide_completion/src/completions/attribute/derive.rs2
-rw-r--r--crates/ide_completion/src/completions/flyimport.rs2
-rw-r--r--crates/ide_completion/src/completions/postfix.rs3
-rw-r--r--crates/ide_completion/src/completions/snippet.rs3
-rw-r--r--crates/ide_completion/src/lib.rs2
-rw-r--r--crates/ide_db/src/helpers/insert_use.rs61
-rw-r--r--crates/ide_db/src/helpers/insert_use/tests.rs25
14 files changed, 182 insertions, 48 deletions
diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs
index ca23d85cb24..4f481613dd8 100644
--- a/crates/hir/src/semantics.rs
+++ b/crates/hir/src/semantics.rs
@@ -208,6 +208,10 @@ impl<'db, DB: HirDatabase> Semantics<'db, DB> {
         self.imp.original_range_opt(node)
     }
 
+    pub fn original_ast_node<N: AstNode>(&self, node: N) -> Option<N> {
+        self.imp.original_ast_node(node)
+    }
+
     pub fn diagnostics_display_range(&self, diagnostics: InFile<SyntaxNodePtr>) -> FileRange {
         self.imp.diagnostics_display_range(diagnostics)
     }
@@ -660,6 +664,11 @@ impl<'db> SemanticsImpl<'db> {
         node.as_ref().original_file_range_opt(self.db.upcast())
     }
 
+    fn original_ast_node<N: AstNode>(&self, node: N) -> Option<N> {
+        let file = self.find_file(node.syntax().clone());
+        file.with_value(node).original_ast_node(self.db.upcast()).map(|it| it.value)
+    }
+
     fn diagnostics_display_range(&self, src: InFile<SyntaxNodePtr>) -> FileRange {
         let root = self.db.parse_or_expand(src.file_id).unwrap();
         let node = src.value.to_node(&root);
diff --git a/crates/hir_expand/src/lib.rs b/crates/hir_expand/src/lib.rs
index f411ef9b3ca..4742cb089eb 100644
--- a/crates/hir_expand/src/lib.rs
+++ b/crates/hir_expand/src/lib.rs
@@ -24,9 +24,9 @@ use std::{hash::Hash, iter, sync::Arc};
 
 use base_db::{impl_intern_key, salsa, CrateId, FileId, FileRange};
 use syntax::{
-    algo::skip_trivia_token,
+    algo::{self, skip_trivia_token},
     ast::{self, AstNode, HasAttrs},
-    Direction, SyntaxNode, SyntaxToken, TextRange,
+    Direction, SyntaxNode, SyntaxToken,
 };
 
 use crate::{
@@ -600,13 +600,15 @@ impl<'a> InFile<&'a SyntaxNode> {
 
     /// Attempts to map the syntax node back up its macro calls.
     pub fn original_file_range_opt(self, db: &dyn db::AstDatabase) -> Option<FileRange> {
-        match original_range_opt(db, self) {
-            Some(range) => {
-                let original_file = range.file_id.original_file(db);
-                if range.file_id != original_file.into() {
+        match ascend_node_border_tokens(db, self) {
+            Some(InFile { file_id, value: (first, last) }) => {
+                let original_file = file_id.original_file(db);
+                let range = first.text_range().cover(last.text_range());
+                if file_id != original_file.into() {
                     tracing::error!("Failed mapping up more for {:?}", range);
+                    return None;
                 }
-                Some(FileRange { file_id: original_file, range: range.value })
+                Some(FileRange { file_id: original_file, range })
             }
             _ if !self.file_id.is_macro() => Some(FileRange {
                 file_id: self.file_id.original_file(db),
@@ -617,28 +619,29 @@ impl<'a> InFile<&'a SyntaxNode> {
     }
 }
 
-fn original_range_opt(
+fn ascend_node_border_tokens(
     db: &dyn db::AstDatabase,
-    node: InFile<&SyntaxNode>,
-) -> Option<InFile<TextRange>> {
-    let expansion = node.file_id.expansion_info(db)?;
+    InFile { file_id, value: node }: InFile<&SyntaxNode>,
+) -> Option<InFile<(SyntaxToken, SyntaxToken)>> {
+    let expansion = file_id.expansion_info(db)?;
 
     // the input node has only one token ?
-    let single = skip_trivia_token(node.value.first_token()?, Direction::Next)?
-        == skip_trivia_token(node.value.last_token()?, Direction::Prev)?;
+    let first = skip_trivia_token(node.first_token()?, Direction::Next)?;
+    let last = skip_trivia_token(node.last_token()?, Direction::Prev)?;
+    let is_single_token = first == last;
 
-    node.value.descendants().find_map(|it| {
+    node.descendants().find_map(|it| {
         let first = skip_trivia_token(it.first_token()?, Direction::Next)?;
-        let first = ascend_call_token(db, &expansion, node.with_value(first))?;
+        let first = ascend_call_token(db, &expansion, InFile::new(file_id, first))?;
 
         let last = skip_trivia_token(it.last_token()?, Direction::Prev)?;
-        let last = ascend_call_token(db, &expansion, node.with_value(last))?;
+        let last = ascend_call_token(db, &expansion, InFile::new(file_id, last))?;
 
-        if (!single && first == last) || (first.file_id != last.file_id) {
+        if (!is_single_token && first == last) || (first.file_id != last.file_id) {
             return None;
         }
 
-        Some(first.with_value(first.value.text_range().cover(last.value.text_range())))
+        Some(InFile::new(first.file_id, (first.value, last.value)))
     })
 }
 
@@ -674,6 +677,23 @@ impl<N: AstNode> InFile<N> {
         self.value.syntax().descendants().filter_map(T::cast).map(move |n| self.with_value(n))
     }
 
+    pub fn original_ast_node(self, db: &dyn db::AstDatabase) -> Option<InFile<N>> {
+        match ascend_node_border_tokens(db, self.syntax()) {
+            Some(InFile { file_id, value: (first, last) }) => {
+                let original_file = file_id.original_file(db);
+                if file_id != original_file.into() {
+                    let range = first.text_range().cover(last.text_range());
+                    tracing::error!("Failed mapping up more for {:?}", range);
+                    return None;
+                }
+                let anc = algo::least_common_ancestor(&first.parent()?, &last.parent()?)?;
+                Some(InFile::new(file_id, anc.ancestors().find_map(N::cast)?))
+            }
+            _ if !self.file_id.is_macro() => Some(self),
+            _ => None,
+        }
+    }
+
     pub fn syntax(&self) -> InFile<&SyntaxNode> {
         self.with_value(self.value.syntax())
     }
diff --git a/crates/ide_assists/src/handlers/auto_import.rs b/crates/ide_assists/src/handlers/auto_import.rs
index a5858869c50..cac736ff850 100644
--- a/crates/ide_assists/src/handlers/auto_import.rs
+++ b/crates/ide_assists/src/handlers/auto_import.rs
@@ -95,7 +95,7 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
         NodeOrToken::Token(token) => token.text_range(),
     };
     let group_label = group_label(import_assets.import_candidate());
-    let scope = ImportScope::find_insert_use_container_with_macros(
+    let scope = ImportScope::find_insert_use_container(
         &match syntax_under_caret {
             NodeOrToken::Node(it) => it,
             NodeOrToken::Token(it) => it.parent()?,
@@ -165,6 +165,60 @@ mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
 
     #[test]
+    fn not_applicable_if_scope_inside_macro() {
+        check_assist_not_applicable(
+            auto_import,
+            r"
+mod bar {
+    pub struct Baz;
+}
+macro_rules! foo {
+    ($it:ident) => {
+        mod __ {
+            fn __(x: $it) {}
+        }
+    };
+}
+foo! {
+    Baz$0
+}
+",
+        );
+    }
+
+    #[test]
+    fn applicable_in_attributes() {
+        check_assist(
+            auto_import,
+            r"
+//- proc_macros: identity
+#[proc_macros::identity]
+mod foo {
+    mod bar {
+        const _: Baz$0 = ();
+    }
+}
+mod baz {
+    pub struct Baz;
+}
+",
+            r"
+#[proc_macros::identity]
+mod foo {
+    mod bar {
+        use crate::baz::Baz;
+
+        const _: Baz = ();
+    }
+}
+mod baz {
+    pub struct Baz;
+}
+",
+        );
+    }
+
+    #[test]
     fn applicable_when_found_an_import_partial() {
         check_assist(
             auto_import,
diff --git a/crates/ide_assists/src/handlers/extract_function.rs b/crates/ide_assists/src/handlers/extract_function.rs
index 7ffb5728cc5..3a334efe0ab 100644
--- a/crates/ide_assists/src/handlers/extract_function.rs
+++ b/crates/ide_assists/src/handlers/extract_function.rs
@@ -91,7 +91,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
 
     let target_range = body.text_range();
 
-    let scope = ImportScope::find_insert_use_container_with_macros(&node, &ctx.sema)?;
+    let scope = ImportScope::find_insert_use_container(&node, &ctx.sema)?;
 
     acc.add(
         AssistId("extract_function", crate::AssistKind::RefactorExtract),
diff --git a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
index 3b812cdf173..8e28f0443d6 100644
--- a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
+++ b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
@@ -314,7 +314,7 @@ fn process_references(
                 if let Some(mut mod_path) = mod_path {
                     mod_path.pop_segment();
                     mod_path.push_segment(variant_hir_name.clone());
-                    let scope = ImportScope::find_insert_use_container(&scope_node)?;
+                    let scope = ImportScope::find_insert_use_container(&scope_node, &ctx.sema)?;
                     visited_modules.insert(module);
                     return Some((segment, scope_node, Some((scope, mod_path))));
                 }
diff --git a/crates/ide_assists/src/handlers/replace_qualified_name_with_use.rs b/crates/ide_assists/src/handlers/replace_qualified_name_with_use.rs
index 47600be7638..8df8c4b726f 100644
--- a/crates/ide_assists/src/handlers/replace_qualified_name_with_use.rs
+++ b/crates/ide_assists/src/handlers/replace_qualified_name_with_use.rs
@@ -70,7 +70,7 @@ pub(crate) fn replace_qualified_name_with_use(
         })
         .flatten();
 
-    let scope = ImportScope::find_insert_use_container_with_macros(path.syntax(), &ctx.sema)?;
+    let scope = ImportScope::find_insert_use_container(path.syntax(), &ctx.sema)?;
     let target = path.syntax().text_range();
     acc.add(
         AssistId("replace_qualified_name_with_use", AssistKind::RefactorRewrite),
diff --git a/crates/ide_assists/src/tests.rs b/crates/ide_assists/src/tests.rs
index 7f3c93d7403..5a799795333 100644
--- a/crates/ide_assists/src/tests.rs
+++ b/crates/ide_assists/src/tests.rs
@@ -2,7 +2,7 @@ mod sourcegen;
 mod generated;
 
 use expect_test::expect;
-use hir::Semantics;
+use hir::{db::DefDatabase, Semantics};
 use ide_db::{
     base_db::{fixture::WithFixture, FileId, FileRange, SourceDatabaseExt},
     helpers::{
@@ -117,7 +117,8 @@ enum ExpectedResult<'a> {
 
 #[track_caller]
 fn check(handler: Handler, before: &str, expected: ExpectedResult, assist_label: Option<&str>) {
-    let (db, file_with_caret_id, range_or_offset) = RootDatabase::with_range_or_offset(before);
+    let (mut db, file_with_caret_id, range_or_offset) = RootDatabase::with_range_or_offset(before);
+    db.set_enable_proc_attr_macros(true);
     let text_without_caret = db.file_text(file_with_caret_id).to_string();
 
     let frange = FileRange { file_id: file_with_caret_id, range: range_or_offset.into() };
diff --git a/crates/ide_completion/src/completions/attribute/derive.rs b/crates/ide_completion/src/completions/attribute/derive.rs
index 5c08d12ef24..33b84819325 100644
--- a/crates/ide_completion/src/completions/attribute/derive.rs
+++ b/crates/ide_completion/src/completions/attribute/derive.rs
@@ -98,7 +98,7 @@ fn flyimport_attribute(ctx: &CompletionContext, acc: &mut Completions) -> Option
         &ctx.sema,
         parent.clone(),
     )?;
-    let import_scope = ImportScope::find_insert_use_container_with_macros(&parent, &ctx.sema)?;
+    let import_scope = ImportScope::find_insert_use_container(&parent, &ctx.sema)?;
     acc.add_all(
         import_assets
             .search_for_imports(&ctx.sema, ctx.config.insert_use.prefix_kind)
diff --git a/crates/ide_completion/src/completions/flyimport.rs b/crates/ide_completion/src/completions/flyimport.rs
index a5c134714ba..956411f1a3b 100644
--- a/crates/ide_completion/src/completions/flyimport.rs
+++ b/crates/ide_completion/src/completions/flyimport.rs
@@ -129,7 +129,7 @@ pub(crate) fn import_on_the_fly(acc: &mut Completions, ctx: &CompletionContext)
 
     let user_input_lowercased = potential_import_name.to_lowercase();
     let import_assets = import_assets(ctx, potential_import_name)?;
-    let import_scope = ImportScope::find_insert_use_container_with_macros(
+    let import_scope = ImportScope::find_insert_use_container(
         &position_for_import(ctx, Some(import_assets.import_candidate()))?,
         &ctx.sema,
     )?;
diff --git a/crates/ide_completion/src/completions/postfix.rs b/crates/ide_completion/src/completions/postfix.rs
index 4ee257ab43a..c239401e487 100644
--- a/crates/ide_completion/src/completions/postfix.rs
+++ b/crates/ide_completion/src/completions/postfix.rs
@@ -244,8 +244,7 @@ fn add_custom_postfix_completions(
     postfix_snippet: impl Fn(&str, &str, &str) -> Builder,
     receiver_text: &str,
 ) -> Option<()> {
-    let import_scope =
-        ImportScope::find_insert_use_container_with_macros(&ctx.token.parent()?, &ctx.sema)?;
+    let import_scope = ImportScope::find_insert_use_container(&ctx.token.parent()?, &ctx.sema)?;
     ctx.config.postfix_snippets().filter(|(_, snip)| snip.scope == SnippetScope::Expr).for_each(
         |(trigger, snippet)| {
             let imports = match snippet.imports(ctx, &import_scope) {
diff --git a/crates/ide_completion/src/completions/snippet.rs b/crates/ide_completion/src/completions/snippet.rs
index 1840e780edf..12bccfae11d 100644
--- a/crates/ide_completion/src/completions/snippet.rs
+++ b/crates/ide_completion/src/completions/snippet.rs
@@ -102,8 +102,7 @@ fn add_custom_completions(
     cap: SnippetCap,
     scope: SnippetScope,
 ) -> Option<()> {
-    let import_scope =
-        ImportScope::find_insert_use_container_with_macros(&ctx.token.parent()?, &ctx.sema)?;
+    let import_scope = ImportScope::find_insert_use_container(&ctx.token.parent()?, &ctx.sema)?;
     ctx.config.prefix_snippets().filter(|(_, snip)| snip.scope == scope).for_each(
         |(trigger, snip)| {
             let imports = match snip.imports(ctx, &import_scope) {
diff --git a/crates/ide_completion/src/lib.rs b/crates/ide_completion/src/lib.rs
index 463744f22af..d9618642c45 100644
--- a/crates/ide_completion/src/lib.rs
+++ b/crates/ide_completion/src/lib.rs
@@ -183,7 +183,7 @@ pub fn resolve_completion_edits(
     let _p = profile::span("resolve_completion_edits");
     let ctx = CompletionContext::new(db, position, config)?;
     let position_for_import = &position_for_import(&ctx, None)?;
-    let scope = ImportScope::find_insert_use_container_with_macros(position_for_import, &ctx.sema)?;
+    let scope = ImportScope::find_insert_use_container(position_for_import, &ctx.sema)?;
 
     let current_module = ctx.sema.scope(position_for_import).module()?;
     let current_crate = current_module.krate();
diff --git a/crates/ide_db/src/helpers/insert_use.rs b/crates/ide_db/src/helpers/insert_use.rs
index 9889dd772ea..1f3e3c55bf3 100644
--- a/crates/ide_db/src/helpers/insert_use.rs
+++ b/crates/ide_db/src/helpers/insert_use.rs
@@ -8,7 +8,7 @@ use hir::Semantics;
 use syntax::{
     algo,
     ast::{self, make, AstNode, HasAttrs, HasModuleItem, HasVisibility, PathSegmentKind},
-    match_ast, ted, AstToken, Direction, NodeOrToken, SyntaxNode, SyntaxToken,
+    ted, AstToken, Direction, NodeOrToken, SyntaxNode, SyntaxToken,
 };
 
 use crate::{
@@ -50,7 +50,10 @@ pub enum ImportScope {
 }
 
 impl ImportScope {
+    // FIXME: Remove this?
+    #[cfg(test)]
     fn from(syntax: SyntaxNode) -> Option<Self> {
+        use syntax::match_ast;
         fn contains_cfg_attr(attrs: &dyn HasAttrs) -> bool {
             attrs
                 .attrs()
@@ -76,16 +79,60 @@ impl ImportScope {
     }
 
     /// Determines the containing syntax node in which to insert a `use` statement affecting `position`.
-    pub fn find_insert_use_container_with_macros(
+    /// Returns the original source node inside attributes.
+    pub fn find_insert_use_container(
         position: &SyntaxNode,
         sema: &Semantics<'_, RootDatabase>,
     ) -> Option<Self> {
-        sema.ancestors_with_macros(position.clone()).find_map(Self::from)
-    }
+        fn contains_cfg_attr(attrs: &dyn HasAttrs) -> bool {
+            attrs
+                .attrs()
+                .any(|attr| attr.as_simple_call().map_or(false, |(ident, _)| ident == "cfg"))
+        }
 
-    /// Determines the containing syntax node in which to insert a `use` statement affecting `position`.
-    pub fn find_insert_use_container(position: &SyntaxNode) -> Option<Self> {
-        std::iter::successors(Some(position.clone()), SyntaxNode::parent).find_map(Self::from)
+        // Walk up the ancestor tree searching for a suitable node to do insertions on
+        // with special handling on cfg-gated items, in which case we want to insert imports locally
+        // or FIXME: annotate inserted imports with the same cfg
+        for syntax in sema.ancestors_with_macros(position.clone()) {
+            if let Some(file) = ast::SourceFile::cast(syntax.clone()) {
+                return Some(ImportScope::File(file));
+            } else if let Some(item) = ast::Item::cast(syntax) {
+                return match item {
+                    ast::Item::Const(konst) if contains_cfg_attr(&konst) => {
+                        // FIXME: Instead of bailing out with None, we should note down that
+                        // this import needs an attribute added
+                        match sema.original_ast_node(konst)?.body()? {
+                            ast::Expr::BlockExpr(block) => block,
+                            _ => return None,
+                        }
+                        .stmt_list()
+                        .map(ImportScope::Block)
+                    }
+                    ast::Item::Fn(func) if contains_cfg_attr(&func) => {
+                        // FIXME: Instead of bailing out with None, we should note down that
+                        // this import needs an attribute added
+                        sema.original_ast_node(func)?.body()?.stmt_list().map(ImportScope::Block)
+                    }
+                    ast::Item::Static(statik) if contains_cfg_attr(&statik) => {
+                        // FIXME: Instead of bailing out with None, we should note down that
+                        // this import needs an attribute added
+                        match sema.original_ast_node(statik)?.body()? {
+                            ast::Expr::BlockExpr(block) => block,
+                            _ => return None,
+                        }
+                        .stmt_list()
+                        .map(ImportScope::Block)
+                    }
+                    ast::Item::Module(module) => {
+                        // early return is important here, if we can't find the original module
+                        // in the input there is no way for us to insert an import anywhere.
+                        sema.original_ast_node(module)?.item_list().map(ImportScope::Module)
+                    }
+                    _ => continue,
+                };
+            }
+        }
+        None
     }
 
     pub fn as_syntax_node(&self) -> &SyntaxNode {
diff --git a/crates/ide_db/src/helpers/insert_use/tests.rs b/crates/ide_db/src/helpers/insert_use/tests.rs
index f3b9c7130f4..34a6900e267 100644
--- a/crates/ide_db/src/helpers/insert_use/tests.rs
+++ b/crates/ide_db/src/helpers/insert_use/tests.rs
@@ -1,5 +1,7 @@
+use base_db::fixture::WithFixture;
 use hir::PrefixKind;
-use test_utils::{assert_eq_text, extract_range_or_offset, CURSOR_MARKER};
+use stdx::trim_indent;
+use test_utils::{assert_eq_text, CURSOR_MARKER};
 
 use super::*;
 
@@ -865,17 +867,20 @@ fn check_with_config(
     ra_fixture_after: &str,
     config: &InsertUseConfig,
 ) {
-    let (text, pos) = if ra_fixture_before.contains(CURSOR_MARKER) {
-        let (range_or_offset, text) = extract_range_or_offset(ra_fixture_before);
-        (text, Some(range_or_offset))
+    let (db, file_id, pos) = if ra_fixture_before.contains(CURSOR_MARKER) {
+        let (db, file_id, range_or_offset) = RootDatabase::with_range_or_offset(ra_fixture_before);
+        (db, file_id, Some(range_or_offset))
     } else {
-        (ra_fixture_before.to_owned(), None)
+        let (db, file_id) = RootDatabase::with_single_file(ra_fixture_before);
+        (db, file_id, None)
     };
-    let syntax = ast::SourceFile::parse(&text).tree().syntax().clone_for_update();
+    let sema = &Semantics::new(&db);
+    let source_file = sema.parse(file_id);
+    let syntax = source_file.syntax().clone_for_update();
     let file = pos
         .and_then(|pos| syntax.token_at_offset(pos.expect_offset()).next()?.parent())
-        .and_then(|it| super::ImportScope::find_insert_use_container(&it))
-        .or_else(|| super::ImportScope::from(syntax))
+        .and_then(|it| ImportScope::find_insert_use_container(&it, sema))
+        .or_else(|| ImportScope::from(syntax))
         .unwrap();
     let path = ast::SourceFile::parse(&format!("use {};", path))
         .tree()
@@ -886,7 +891,7 @@ fn check_with_config(
 
     insert_use(&file, path, config);
     let result = file.as_syntax_node().ancestors().last().unwrap().to_string();
-    assert_eq_text!(ra_fixture_after, &result);
+    assert_eq_text!(&trim_indent(ra_fixture_after), &result);
 }
 
 fn check(
@@ -942,6 +947,6 @@ fn check_merge_only_fail(ra_fixture0: &str, ra_fixture1: &str, mb: MergeBehavior
 
 fn check_guess(ra_fixture: &str, expected: ImportGranularityGuess) {
     let syntax = ast::SourceFile::parse(ra_fixture).tree().syntax().clone();
-    let file = super::ImportScope::from(syntax).unwrap();
+    let file = ImportScope::from(syntax).unwrap();
     assert_eq!(super::guess_granularity_from_scope(&file), expected);
 }