about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors[bot] <26634292+bors[bot]@users.noreply.github.com>2021-11-16 20:51:09 +0000
committerGitHub <noreply@github.com>2021-11-16 20:51:09 +0000
commitadd6cccd4c923fbb5c83cc27b06aa84b2cbc9557 (patch)
treed93f93b66c9c5cc6d19708615455cfa8f2b67e3c
parent1c49667c5658e0f8f0fb496786071e2f9b6cd43c (diff)
parent366499c3be15a7bbf7914d6825e4d92fdfbddb1e (diff)
downloadrust-add6cccd4c923fbb5c83cc27b06aa84b2cbc9557.tar.gz
rust-add6cccd4c923fbb5c83cc27b06aa84b2cbc9557.zip
Merge #10781
10781: internal: Do not use reference search in `runnables::related_tests` r=Veykril a=Veykril

bors r+

Co-authored-by: Lukas Wirth <lukastw97@gmail.com>
-rw-r--r--crates/hir/src/semantics.rs1
-rw-r--r--crates/ide/src/highlight_related.rs69
-rw-r--r--crates/ide/src/references.rs98
-rw-r--r--crates/ide/src/runnables.rs106
4 files changed, 122 insertions, 152 deletions
diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs
index 0cf868e64e1..281e6c65dc4 100644
--- a/crates/hir/src/semantics.rs
+++ b/crates/hir/src/semantics.rs
@@ -228,6 +228,7 @@ impl<'db, DB: HirDatabase> Semantics<'db, DB> {
         token.parent().into_iter().flat_map(move |it| self.ancestors_with_macros(it))
     }
 
+    /// Iterates the ancestors of the given node, climbing up macro expansions while doing so.
     pub fn ancestors_with_macros(&self, node: SyntaxNode) -> impl Iterator<Item = SyntaxNode> + '_ {
         self.imp.ancestors_with_macros(node)
     }
diff --git a/crates/ide/src/highlight_related.rs b/crates/ide/src/highlight_related.rs
index 5ff6fd10bc8..7d4d52ff50e 100644
--- a/crates/ide/src/highlight_related.rs
+++ b/crates/ide/src/highlight_related.rs
@@ -1,7 +1,7 @@
 use hir::Semantics;
 use ide_db::{
-    base_db::FilePosition,
-    defs::{Definition, NameClass, NameRefClass},
+    base_db::{FileId, FilePosition},
+    defs::Definition,
     helpers::{for_each_break_expr, for_each_tail_expr, node_ext::walk_expr, pick_best_token},
     search::{FileReference, ReferenceCategory, SearchScope},
     RootDatabase,
@@ -11,7 +11,7 @@ use syntax::{
     ast::{self, HasLoopBody},
     match_ast, AstNode,
     SyntaxKind::IDENT,
-    SyntaxNode, SyntaxToken, TextRange, TextSize, T,
+    SyntaxNode, SyntaxToken, TextRange, T,
 };
 
 use crate::{display::TryToNav, references, NavigationTarget};
@@ -45,12 +45,12 @@ pub struct HighlightRelatedConfig {
 pub(crate) fn highlight_related(
     sema: &Semantics<RootDatabase>,
     config: HighlightRelatedConfig,
-    position: FilePosition,
+    FilePosition { offset, file_id }: FilePosition,
 ) -> Option<Vec<HighlightedRange>> {
     let _p = profile::span("highlight_related");
-    let syntax = sema.parse(position.file_id).syntax().clone();
+    let syntax = sema.parse(file_id).syntax().clone();
 
-    let token = pick_best_token(syntax.token_at_offset(position.offset), |kind| match kind {
+    let token = pick_best_token(syntax.token_at_offset(offset), |kind| match kind {
         T![?] => 4, // prefer `?` when the cursor is sandwiched like in `await$0?`
         T![->] => 3,
         kind if kind.is_keyword() => 2,
@@ -68,17 +68,18 @@ pub(crate) fn highlight_related(
             highlight_break_points(token)
         }
         T![break] | T![loop] | T![while] if config.break_points => highlight_break_points(token),
-        _ if config.references => highlight_references(sema, &syntax, position),
+        _ if config.references => highlight_references(sema, &syntax, token, file_id),
         _ => None,
     }
 }
 
 fn highlight_references(
     sema: &Semantics<RootDatabase>,
-    syntax: &SyntaxNode,
-    FilePosition { offset, file_id }: FilePosition,
+    node: &SyntaxNode,
+    token: SyntaxToken,
+    file_id: FileId,
 ) -> Option<Vec<HighlightedRange>> {
-    let defs = find_defs(sema, syntax, offset);
+    let defs = find_defs(sema, token.clone());
     let usages = defs
         .iter()
         .filter_map(|&d| {
@@ -105,11 +106,8 @@ fn highlight_references(
         .filter(|decl| decl.file_id == file_id)
         .and_then(|decl| {
             let range = decl.focus_range?;
-            let category = if references::decl_mutability(&def, syntax, range) {
-                Some(ReferenceCategory::Write)
-            } else {
-                None
-            };
+            let category =
+                references::decl_mutability(&def, node, range).then(|| ReferenceCategory::Write);
             Some(HighlightedRange { range, category })
         })
     });
@@ -293,43 +291,10 @@ fn cover_range(r0: Option<TextRange>, r1: Option<TextRange>) -> Option<TextRange
     }
 }
 
-fn find_defs(
-    sema: &Semantics<RootDatabase>,
-    syntax: &SyntaxNode,
-    offset: TextSize,
-) -> FxHashSet<Definition> {
-    sema.find_nodes_at_offset_with_descend(syntax, offset)
-        .flat_map(|name_like| {
-            Some(match name_like {
-                ast::NameLike::NameRef(name_ref) => {
-                    match NameRefClass::classify(sema, &name_ref)? {
-                        NameRefClass::Definition(def) => vec![def],
-                        NameRefClass::FieldShorthand { local_ref, field_ref } => {
-                            vec![Definition::Local(local_ref), Definition::Field(field_ref)]
-                        }
-                    }
-                }
-                ast::NameLike::Name(name) => match NameClass::classify(sema, &name)? {
-                    NameClass::Definition(it) | NameClass::ConstReference(it) => vec![it],
-                    NameClass::PatFieldShorthand { local_def, field_ref } => {
-                        vec![Definition::Local(local_def), Definition::Field(field_ref)]
-                    }
-                },
-                ast::NameLike::Lifetime(lifetime) => {
-                    NameRefClass::classify_lifetime(sema, &lifetime)
-                        .and_then(|class| match class {
-                            NameRefClass::Definition(it) => Some(it),
-                            _ => None,
-                        })
-                        .or_else(|| {
-                            NameClass::classify_lifetime(sema, &lifetime)
-                                .and_then(NameClass::defined)
-                        })
-                        .map(|it| vec![it])?
-                }
-            })
-        })
-        .flatten()
+fn find_defs(sema: &Semantics<RootDatabase>, token: SyntaxToken) -> FxHashSet<Definition> {
+    sema.descend_into_macros(token)
+        .into_iter()
+        .flat_map(|token| Definition::from_token(sema, &token))
         .collect()
 }
 
diff --git a/crates/ide/src/references.rs b/crates/ide/src/references.rs
index d1219044fe5..0e3b73d2354 100644
--- a/crates/ide/src/references.rs
+++ b/crates/ide/src/references.rs
@@ -9,9 +9,6 @@
 //! at the index that the match starts at and its tree parent is
 //! resolved to the search element definition, we get a reference.
 
-use std::iter;
-
-use either::Either;
 use hir::{PathResolution, Semantics};
 use ide_db::{
     base_db::FileId,
@@ -58,60 +55,58 @@ pub(crate) fn find_all_refs(
 ) -> Option<Vec<ReferenceSearchResult>> {
     let _p = profile::span("find_all_refs");
     let syntax = sema.parse(position.file_id).syntax().clone();
+    let make_searcher = |literal_search: bool| {
+        move |def: Definition| {
+            let mut usages =
+                def.usages(sema).set_scope(search_scope.clone()).include_self_refs().all();
+            let declaration = match def {
+                Definition::Module(module) => {
+                    Some(NavigationTarget::from_module_to_decl(sema.db, module))
+                }
+                def => def.try_to_nav(sema.db),
+            }
+            .map(|nav| {
+                let decl_range = nav.focus_or_full_range();
+                Declaration {
+                    is_mut: decl_mutability(&def, sema.parse(nav.file_id).syntax(), decl_range),
+                    nav,
+                }
+            });
+            if literal_search {
+                retain_adt_literal_usages(&mut usages, def, sema);
+            }
+
+            let references = usages
+                .into_iter()
+                .map(|(file_id, refs)| {
+                    (
+                        file_id,
+                        refs.into_iter()
+                            .map(|file_ref| (file_ref.range, file_ref.category))
+                            .collect(),
+                    )
+                })
+                .collect();
+
+            ReferenceSearchResult { declaration, references }
+        }
+    };
 
-    let mut is_literal_search = false;
-    let defs = match name_for_constructor_search(&syntax, position) {
+    match name_for_constructor_search(&syntax, position) {
         Some(name) => {
-            is_literal_search = true;
             let def = match NameClass::classify(sema, &name)? {
                 NameClass::Definition(it) | NameClass::ConstReference(it) => it,
                 NameClass::PatFieldShorthand { local_def: _, field_ref } => {
                     Definition::Field(field_ref)
                 }
             };
-            Either::Left(iter::once(def))
+            Some(vec![make_searcher(true)(def)])
         }
-        None => Either::Right(find_defs(sema, &syntax, position.offset)),
-    };
-
-    Some(
-        defs.into_iter()
-            .map(|def| {
-                let mut usages =
-                    def.usages(sema).set_scope(search_scope.clone()).include_self_refs().all();
-                let declaration = match def {
-                    Definition::Module(module) => {
-                        Some(NavigationTarget::from_module_to_decl(sema.db, module))
-                    }
-                    def => def.try_to_nav(sema.db),
-                }
-                .map(|nav| {
-                    let decl_range = nav.focus_or_full_range();
-                    Declaration {
-                        is_mut: decl_mutability(&def, sema.parse(nav.file_id).syntax(), decl_range),
-                        nav,
-                    }
-                });
-                if is_literal_search {
-                    retain_adt_literal_usages(&mut usages, def, sema);
-                }
-
-                let references = usages
-                    .into_iter()
-                    .map(|(file_id, refs)| {
-                        (
-                            file_id,
-                            refs.into_iter()
-                                .map(|file_ref| (file_ref.range, file_ref.category))
-                                .collect(),
-                        )
-                    })
-                    .collect();
-
-                ReferenceSearchResult { declaration, references }
-            })
-            .collect(),
-    )
+        None => {
+            let search = make_searcher(false);
+            Some(find_defs(sema, &syntax, position.offset).into_iter().map(search).collect())
+        }
+    }
 }
 
 pub(crate) fn find_defs<'a>(
@@ -119,8 +114,8 @@ pub(crate) fn find_defs<'a>(
     syntax: &SyntaxNode,
     offset: TextSize,
 ) -> impl Iterator<Item = Definition> + 'a {
-    sema.find_nodes_at_offset_with_descend(syntax, offset).filter_map(move |node| {
-        Some(match node {
+    sema.find_nodes_at_offset_with_descend(syntax, offset).filter_map(move |name_like| {
+        let def = match name_like {
             ast::NameLike::NameRef(name_ref) => match NameRefClass::classify(sema, &name_ref)? {
                 NameRefClass::Definition(def) => def,
                 NameRefClass::FieldShorthand { local_ref, field_ref: _ } => {
@@ -141,7 +136,8 @@ pub(crate) fn find_defs<'a>(
                 .or_else(|| {
                     NameClass::classify_lifetime(sema, &lifetime).and_then(NameClass::defined)
                 })?,
-        })
+        };
+        Some(def)
     })
 }
 
diff --git a/crates/ide/src/runnables.rs b/crates/ide/src/runnables.rs
index 21130e06075..b2111bc4ee0 100644
--- a/crates/ide/src/runnables.rs
+++ b/crates/ide/src/runnables.rs
@@ -2,7 +2,7 @@ use std::fmt;
 
 use ast::HasName;
 use cfg::CfgExpr;
-use hir::{AsAssocItem, HasAttrs, HasSource, HirDisplay, InFile, Semantics};
+use hir::{AsAssocItem, HasAttrs, HasSource, HirDisplay, Semantics};
 use ide_assists::utils::test_related_attribute;
 use ide_db::{
     base_db::{FilePosition, FileRange},
@@ -14,7 +14,10 @@ use ide_db::{
 use itertools::Itertools;
 use rustc_hash::{FxHashMap, FxHashSet};
 use stdx::{always, format_to};
-use syntax::ast::{self, AstNode, HasAttrs as _};
+use syntax::{
+    ast::{self, AstNode, HasAttrs as _},
+    SmolStr, SyntaxNode,
+};
 
 use crate::{
     display::{ToNav, TryToNav},
@@ -31,7 +34,7 @@ pub struct Runnable {
 
 #[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub enum TestId {
-    Name(String),
+    Name(SmolStr),
     Path(String),
 }
 
@@ -206,68 +209,71 @@ pub(crate) fn related_tests(
 ) -> Vec<Runnable> {
     let sema = Semantics::new(db);
     let mut res: FxHashSet<Runnable> = FxHashSet::default();
+    let syntax = sema.parse(position.file_id).syntax().clone();
 
-    find_related_tests(&sema, position, search_scope, &mut res);
+    find_related_tests(&sema, &syntax, position, search_scope, &mut res);
 
-    res.into_iter().collect_vec()
+    res.into_iter().collect()
 }
 
 fn find_related_tests(
     sema: &Semantics<RootDatabase>,
+    syntax: &SyntaxNode,
     position: FilePosition,
     search_scope: Option<SearchScope>,
     tests: &mut FxHashSet<Runnable>,
 ) {
-    if let Some(refs) = references::find_all_refs(sema, position, search_scope) {
-        for (file_id, refs) in refs.into_iter().flat_map(|refs| refs.references) {
-            let file = sema.parse(file_id);
-            let file = file.syntax();
-
-            // create flattened vec of tokens
-            let tokens = refs.iter().flat_map(|(range, _)| {
-                match file.token_at_offset(range.start()).next() {
-                    Some(token) => sema.descend_into_macros(token),
-                    None => Default::default(),
-                }
-            });
-
-            // find first suitable ancestor
-            let functions = tokens
-                .filter_map(|token| token.ancestors().find_map(ast::Fn::cast))
-                .map(|f| hir::InFile::new(sema.hir_file_for(f.syntax()), f));
-
-            for fn_def in functions {
-                let InFile { value: fn_def, .. } = &fn_def;
-                if let Some(runnable) = as_test_runnable(sema, fn_def) {
+    let defs = references::find_defs(sema, syntax, position.offset);
+    for def in defs {
+        let defs = def
+            .usages(sema)
+            .set_scope(search_scope.clone())
+            .all()
+            .references
+            .into_values()
+            .flatten();
+        for ref_ in defs {
+            let name_ref = match ref_.name {
+                ast::NameLike::NameRef(name_ref) => name_ref,
+                _ => continue,
+            };
+            if let Some(fn_def) =
+                sema.ancestors_with_macros(name_ref.syntax().clone()).find_map(ast::Fn::cast)
+            {
+                if let Some(runnable) = as_test_runnable(sema, &fn_def) {
                     // direct test
                     tests.insert(runnable);
-                } else if let Some(module) = parent_test_module(sema, fn_def) {
+                } else if let Some(module) = parent_test_module(sema, &fn_def) {
                     // indirect test
-                    find_related_tests_in_module(sema, fn_def, &module, tests);
+                    find_related_tests_in_module(sema, syntax, &fn_def, &module, tests);
                 }
             }
         }
     }
 }
+
 fn find_related_tests_in_module(
     sema: &Semantics<RootDatabase>,
+    syntax: &SyntaxNode,
     fn_def: &ast::Fn,
     parent_module: &hir::Module,
     tests: &mut FxHashSet<Runnable>,
 ) {
-    if let Some(fn_name) = fn_def.name() {
-        let mod_source = parent_module.definition_source(sema.db);
-        let range = match mod_source.value {
-            hir::ModuleSource::Module(m) => m.syntax().text_range(),
-            hir::ModuleSource::BlockExpr(b) => b.syntax().text_range(),
-            hir::ModuleSource::SourceFile(f) => f.syntax().text_range(),
-        };
+    let fn_name = match fn_def.name() {
+        Some(it) => it,
+        _ => return,
+    };
+    let mod_source = parent_module.definition_source(sema.db);
+    let range = match &mod_source.value {
+        hir::ModuleSource::Module(m) => m.syntax().text_range(),
+        hir::ModuleSource::BlockExpr(b) => b.syntax().text_range(),
+        hir::ModuleSource::SourceFile(f) => f.syntax().text_range(),
+    };
 
-        let file_id = mod_source.file_id.original_file(sema.db);
-        let mod_scope = SearchScope::file_range(FileRange { file_id, range });
-        let fn_pos = FilePosition { file_id, offset: fn_name.syntax().text_range().start() };
-        find_related_tests(sema, fn_pos, Some(mod_scope), tests)
-    }
+    let file_id = mod_source.file_id.original_file(sema.db);
+    let mod_scope = SearchScope::file_range(FileRange { file_id, range });
+    let fn_pos = FilePosition { file_id, offset: fn_name.syntax().text_range().start() };
+    find_related_tests(sema, syntax, fn_pos, Some(mod_scope), tests)
 }
 
 fn as_test_runnable(sema: &Semantics<RootDatabase>, fn_def: &ast::Fn) -> Option<Runnable> {
@@ -294,24 +300,26 @@ fn parent_test_module(sema: &Semantics<RootDatabase>, fn_def: &ast::Fn) -> Optio
 
 pub(crate) fn runnable_fn(sema: &Semantics<RootDatabase>, def: hir::Function) -> Option<Runnable> {
     let func = def.source(sema.db)?;
-    let name_string = def.name(sema.db).to_string();
+    let name = def.name(sema.db).to_smol_str();
 
     let root = def.module(sema.db).krate().root_module(sema.db);
 
-    let kind = if name_string == "main" && def.module(sema.db) == root {
+    let kind = if name == "main" && def.module(sema.db) == root {
         RunnableKind::Bin
     } else {
-        let canonical_path = {
-            let def: hir::ModuleDef = def.into();
-            def.canonical_path(sema.db)
+        let test_id = || {
+            let canonical_path = {
+                let def: hir::ModuleDef = def.into();
+                def.canonical_path(sema.db)
+            };
+            canonical_path.map(TestId::Path).unwrap_or(TestId::Name(name))
         };
-        let test_id = canonical_path.map(TestId::Path).unwrap_or(TestId::Name(name_string));
 
         if test_related_attribute(&func.value).is_some() {
             let attr = TestAttr::from_fn(&func.value);
-            RunnableKind::Test { test_id, attr }
+            RunnableKind::Test { test_id: test_id(), attr }
         } else if func.value.has_atom_attr("bench") {
-            RunnableKind::Bench { test_id }
+            RunnableKind::Bench { test_id: test_id() }
         } else {
             return None;
         }
@@ -430,7 +438,7 @@ fn module_def_doctest(db: &RootDatabase, def: Definition) -> Option<Runnable> {
         Some(path)
     })();
 
-    let test_id = path.map_or_else(|| TestId::Name(def_name.to_string()), TestId::Path);
+    let test_id = path.map_or_else(|| TestId::Name(def_name.to_smol_str()), TestId::Path);
 
     let mut nav = match def {
         Definition::Module(def) => NavigationTarget::from_module_to_decl(db, def),