about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/hir/src/semantics.rs1
-rw-r--r--crates/ide/src/references.rs2
-rw-r--r--crates/ide/src/runnables.rs87
3 files changed, 48 insertions, 42 deletions
diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs
index 0cf868e64e1..281e6c65dc4 100644
--- a/crates/hir/src/semantics.rs
+++ b/crates/hir/src/semantics.rs
@@ -228,6 +228,7 @@ impl<'db, DB: HirDatabase> Semantics<'db, DB> {
         token.parent().into_iter().flat_map(move |it| self.ancestors_with_macros(it))
     }
 
+    /// Iterates the ancestors of the given node, climbing up macro expansions while doing so.
     pub fn ancestors_with_macros(&self, node: SyntaxNode) -> impl Iterator<Item = SyntaxNode> + '_ {
         self.imp.ancestors_with_macros(node)
     }
diff --git a/crates/ide/src/references.rs b/crates/ide/src/references.rs
index 8c3e380b76d..0e3b73d2354 100644
--- a/crates/ide/src/references.rs
+++ b/crates/ide/src/references.rs
@@ -109,7 +109,7 @@ pub(crate) fn find_all_refs(
     }
 }
 
-fn find_defs<'a>(
+pub(crate) fn find_defs<'a>(
     sema: &'a Semantics<RootDatabase>,
     syntax: &SyntaxNode,
     offset: TextSize,
diff --git a/crates/ide/src/runnables.rs b/crates/ide/src/runnables.rs
index f03bc1427cb..b2111bc4ee0 100644
--- a/crates/ide/src/runnables.rs
+++ b/crates/ide/src/runnables.rs
@@ -2,7 +2,7 @@ use std::fmt;
 
 use ast::HasName;
 use cfg::CfgExpr;
-use hir::{AsAssocItem, HasAttrs, HasSource, HirDisplay, InFile, Semantics};
+use hir::{AsAssocItem, HasAttrs, HasSource, HirDisplay, Semantics};
 use ide_assists::utils::test_related_attribute;
 use ide_db::{
     base_db::{FilePosition, FileRange},
@@ -14,7 +14,10 @@ use ide_db::{
 use itertools::Itertools;
 use rustc_hash::{FxHashMap, FxHashSet};
 use stdx::{always, format_to};
-use syntax::ast::{self, AstNode, HasAttrs as _};
+use syntax::{
+    ast::{self, AstNode, HasAttrs as _},
+    SmolStr, SyntaxNode,
+};
 
 use crate::{
     display::{ToNav, TryToNav},
@@ -31,7 +34,7 @@ pub struct Runnable {
 
 #[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub enum TestId {
-    Name(String),
+    Name(SmolStr),
     Path(String),
 }
 
@@ -206,45 +209,44 @@ pub(crate) fn related_tests(
 ) -> Vec<Runnable> {
     let sema = Semantics::new(db);
     let mut res: FxHashSet<Runnable> = FxHashSet::default();
+    let syntax = sema.parse(position.file_id).syntax().clone();
 
-    find_related_tests(&sema, position, search_scope, &mut res);
+    find_related_tests(&sema, &syntax, position, search_scope, &mut res);
 
     res.into_iter().collect()
 }
 
 fn find_related_tests(
     sema: &Semantics<RootDatabase>,
+    syntax: &SyntaxNode,
     position: FilePosition,
     search_scope: Option<SearchScope>,
     tests: &mut FxHashSet<Runnable>,
 ) {
-    let refs = match references::find_all_refs(sema, position, search_scope) {
-        Some(it) => it,
-        _ => return,
-    };
-    for (file_id, refs) in refs.into_iter().flat_map(|refs| refs.references) {
-        let file = sema.parse(file_id);
-        let file = file.syntax();
-
-        // create flattened vec of tokens
-        let tokens =
-            refs.iter().flat_map(|(range, _)| match file.token_at_offset(range.start()).next() {
-                Some(token) => sema.descend_into_macros(token),
-                None => Default::default(),
-            });
-
-        // find first suitable ancestor
-        let functions = tokens
-            .filter_map(|token| token.ancestors().find_map(ast::Fn::cast))
-            .map(|f| hir::InFile::new(sema.hir_file_for(f.syntax()), f));
-
-        for InFile { value: fn_def, .. } in functions {
-            if let Some(runnable) = as_test_runnable(sema, &fn_def) {
-                // direct test
-                tests.insert(runnable);
-            } else if let Some(module) = parent_test_module(sema, &fn_def) {
-                // indirect test
-                find_related_tests_in_module(sema, &fn_def, &module, tests);
+    let defs = references::find_defs(sema, syntax, position.offset);
+    for def in defs {
+        let defs = def
+            .usages(sema)
+            .set_scope(search_scope.clone())
+            .all()
+            .references
+            .into_values()
+            .flatten();
+        for ref_ in defs {
+            let name_ref = match ref_.name {
+                ast::NameLike::NameRef(name_ref) => name_ref,
+                _ => continue,
+            };
+            if let Some(fn_def) =
+                sema.ancestors_with_macros(name_ref.syntax().clone()).find_map(ast::Fn::cast)
+            {
+                if let Some(runnable) = as_test_runnable(sema, &fn_def) {
+                    // direct test
+                    tests.insert(runnable);
+                } else if let Some(module) = parent_test_module(sema, &fn_def) {
+                    // indirect test
+                    find_related_tests_in_module(sema, syntax, &fn_def, &module, tests);
+                }
             }
         }
     }
@@ -252,6 +254,7 @@ fn find_related_tests(
 
 fn find_related_tests_in_module(
     sema: &Semantics<RootDatabase>,
+    syntax: &SyntaxNode,
     fn_def: &ast::Fn,
     parent_module: &hir::Module,
     tests: &mut FxHashSet<Runnable>,
@@ -270,7 +273,7 @@ fn find_related_tests_in_module(
     let file_id = mod_source.file_id.original_file(sema.db);
     let mod_scope = SearchScope::file_range(FileRange { file_id, range });
     let fn_pos = FilePosition { file_id, offset: fn_name.syntax().text_range().start() };
-    find_related_tests(sema, fn_pos, Some(mod_scope), tests)
+    find_related_tests(sema, syntax, fn_pos, Some(mod_scope), tests)
 }
 
 fn as_test_runnable(sema: &Semantics<RootDatabase>, fn_def: &ast::Fn) -> Option<Runnable> {
@@ -297,24 +300,26 @@ fn parent_test_module(sema: &Semantics<RootDatabase>, fn_def: &ast::Fn) -> Optio
 
 pub(crate) fn runnable_fn(sema: &Semantics<RootDatabase>, def: hir::Function) -> Option<Runnable> {
     let func = def.source(sema.db)?;
-    let name_string = def.name(sema.db).to_string();
+    let name = def.name(sema.db).to_smol_str();
 
     let root = def.module(sema.db).krate().root_module(sema.db);
 
-    let kind = if name_string == "main" && def.module(sema.db) == root {
+    let kind = if name == "main" && def.module(sema.db) == root {
         RunnableKind::Bin
     } else {
-        let canonical_path = {
-            let def: hir::ModuleDef = def.into();
-            def.canonical_path(sema.db)
+        let test_id = || {
+            let canonical_path = {
+                let def: hir::ModuleDef = def.into();
+                def.canonical_path(sema.db)
+            };
+            canonical_path.map(TestId::Path).unwrap_or(TestId::Name(name))
         };
-        let test_id = canonical_path.map(TestId::Path).unwrap_or(TestId::Name(name_string));
 
         if test_related_attribute(&func.value).is_some() {
             let attr = TestAttr::from_fn(&func.value);
-            RunnableKind::Test { test_id, attr }
+            RunnableKind::Test { test_id: test_id(), attr }
         } else if func.value.has_atom_attr("bench") {
-            RunnableKind::Bench { test_id }
+            RunnableKind::Bench { test_id: test_id() }
         } else {
             return None;
         }
@@ -433,7 +438,7 @@ fn module_def_doctest(db: &RootDatabase, def: Definition) -> Option<Runnable> {
         Some(path)
     })();
 
-    let test_id = path.map_or_else(|| TestId::Name(def_name.to_string()), TestId::Path);
+    let test_id = path.map_or_else(|| TestId::Name(def_name.to_smol_str()), TestId::Path);
 
     let mut nav = match def {
         Definition::Module(def) => NavigationTarget::from_module_to_decl(db, def),