about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir_expand/src/db.rs18
-rw-r--r--crates/hir_expand/src/lib.rs22
-rw-r--r--crates/ide/src/runnables.rs20
3 files changed, 42 insertions, 18 deletions
diff --git a/crates/hir_expand/src/db.rs b/crates/hir_expand/src/db.rs
index d71cc22ce4e..e806208e16d 100644
--- a/crates/hir_expand/src/db.rs
+++ b/crates/hir_expand/src/db.rs
@@ -54,18 +54,12 @@ impl TokenExpander {
             TokenExpander::MacroDef { mac, .. } => mac.expand(tt),
             TokenExpander::Builtin(it) => it.expand(db, id, tt),
             // FIXME switch these to ExpandResult as well
-            TokenExpander::BuiltinAttr(it) => {
-                let macro_arg = match db.macro_arg(id) {
-                    Some(it) => it,
-                    None => {
-                        return mbe::ExpandResult::only_err(
-                            mbe::ExpandError::Other("No item argument for attribute".to_string())
-                                .into(),
-                        );
-                    }
-                };
-                it.expand(db, id, tt, &macro_arg.0).into()
-            }
+            TokenExpander::BuiltinAttr(it) => match db.macro_arg(id) {
+                Some(macro_arg) => it.expand(db, id, tt, &macro_arg.0).into(),
+                None => mbe::ExpandResult::only_err(
+                    mbe::ExpandError::Other("No item argument for attribute".to_string()).into(),
+                ),
+            },
             TokenExpander::BuiltinDerive(it) => it.expand(db, id, tt).into(),
             TokenExpander::ProcMacro(_) => {
                 // We store the result in salsa db to prevent non-deterministic behavior in
diff --git a/crates/hir_expand/src/lib.rs b/crates/hir_expand/src/lib.rs
index 1452ab08d94..c9165a0cf60 100644
--- a/crates/hir_expand/src/lib.rs
+++ b/crates/hir_expand/src/lib.rs
@@ -562,6 +562,28 @@ impl<N: AstNode> InFile<N> {
     pub fn syntax(&self) -> InFile<&SyntaxNode> {
         self.with_value(self.value.syntax())
     }
+
+    pub fn nodes_with_attributes<'db>(
+        self,
+        db: &'db dyn db::AstDatabase,
+    ) -> impl Iterator<Item = InFile<N>> + 'db
+    where
+        N: 'db,
+    {
+        std::iter::successors(Some(self), move |node| {
+            let InFile { file_id, value } = node.file_id.call_node(db)?;
+            N::cast(value).map(|n| InFile::new(file_id, n))
+        })
+    }
+
+    pub fn node_with_attributes(self, db: &dyn db::AstDatabase) -> InFile<N> {
+        std::iter::successors(Some(self), move |node| {
+            let InFile { file_id, value } = node.file_id.call_node(db)?;
+            N::cast(value).map(|n| InFile::new(file_id, n))
+        })
+        .last()
+        .unwrap()
+    }
 }
 
 /// Given a `MacroCallId`, return what `FragmentKind` it belongs to.
diff --git a/crates/ide/src/runnables.rs b/crates/ide/src/runnables.rs
index 42f6ec5d9cf..ed220c9080c 100644
--- a/crates/ide/src/runnables.rs
+++ b/crates/ide/src/runnables.rs
@@ -232,22 +232,27 @@ fn find_related_tests(
             let functions = refs.iter().filter_map(|(range, _)| {
                 let token = file.token_at_offset(range.start()).next()?;
                 let token = sema.descend_into_macros(token);
-                token.ancestors().find_map(ast::Fn::cast)
+                // FIXME: This is the wrong file_id
+                token
+                    .ancestors()
+                    .find_map(ast::Fn::cast)
+                    .map(|f| hir::InFile::new(file_id.into(), f))
             });
 
             for fn_def in functions {
-                if let Some(runnable) = as_test_runnable(sema, &fn_def) {
+                // #[test/bench] expands to just the item causing us to lose the attribute, so recover them by going out of the attribute
+                let fn_def = fn_def.node_with_attributes(sema.db);
+                if let Some(runnable) = as_test_runnable(sema, &fn_def.value) {
                     // direct test
                     tests.insert(runnable);
-                } else if let Some(module) = parent_test_module(sema, &fn_def) {
+                } else if let Some(module) = parent_test_module(sema, &fn_def.value) {
                     // indirect test
-                    find_related_tests_in_module(sema, &fn_def, &module, tests);
+                    find_related_tests_in_module(sema, &fn_def.value, &module, tests);
                 }
             }
         }
     }
 }
-
 fn find_related_tests_in_module(
     sema: &Semantics<RootDatabase>,
     fn_def: &ast::Fn,
@@ -292,7 +297,8 @@ fn parent_test_module(sema: &Semantics<RootDatabase>, fn_def: &ast::Fn) -> Optio
 }
 
 pub(crate) fn runnable_fn(sema: &Semantics<RootDatabase>, def: hir::Function) -> Option<Runnable> {
-    let func = def.source(sema.db)?;
+    // #[test/bench] expands to just the item causing us to lose the attribute, so recover them by going out of the attribute
+    let func = def.source(sema.db)?.node_with_attributes(sema.db);
     let name_string = def.name(sema.db).to_string();
 
     let root = def.module(sema.db).krate().root_module(sema.db);
@@ -499,6 +505,8 @@ fn has_test_function_or_multiple_test_submodules(
         match item {
             hir::ModuleDef::Function(f) => {
                 if let Some(it) = f.source(sema.db) {
+                    // #[test/bench] expands to just the item causing us to lose the attribute, so recover them by going out of the attribute
+                    let it = it.node_with_attributes(sema.db);
                     if test_related_attribute(&it.value).is_some() {
                         return true;
                     }