about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/rust-analyzer/crates/ide/src/goto_definition.rs51
-rw-r--r--src/tools/rust-analyzer/crates/ide/src/highlight_related.rs27
2 files changed, 45 insertions, 33 deletions
diff --git a/src/tools/rust-analyzer/crates/ide/src/goto_definition.rs b/src/tools/rust-analyzer/crates/ide/src/goto_definition.rs
index c90a544a718..fd465f31d43 100644
--- a/src/tools/rust-analyzer/crates/ide/src/goto_definition.rs
+++ b/src/tools/rust-analyzer/crates/ide/src/goto_definition.rs
@@ -412,45 +412,30 @@ pub(crate) fn find_branch_root(
     sema: &Semantics<'_, RootDatabase>,
     token: &SyntaxToken,
 ) -> Vec<SyntaxNode> {
-    fn find_root(
-        sema: &Semantics<'_, RootDatabase>,
-        token: &SyntaxToken,
-        pred: impl Fn(SyntaxNode) -> Option<SyntaxNode>,
-    ) -> Vec<SyntaxNode> {
-        let mut result = Vec::new();
-        for token in sema.descend_into_macros(token.clone()) {
-            for node in sema.token_ancestors_with_macros(token) {
-                if ast::MacroCall::can_cast(node.kind()) {
-                    break;
-                }
-
-                if let Some(node) = pred(node) {
-                    result.push(node);
-                    break;
-                }
-            }
-        }
-        result
-    }
+    let find_nodes = |node_filter: fn(SyntaxNode) -> Option<SyntaxNode>| {
+        sema.descend_into_macros(token.clone())
+            .into_iter()
+            .filter_map(|token| node_filter(token.parent()?))
+            .collect_vec()
+    };
 
     match token.kind() {
-        T![match] => {
-            find_root(sema, token, |node| Some(ast::MatchExpr::cast(node)?.syntax().clone()))
-        }
-        T![=>] => find_root(sema, token, |node| Some(ast::MatchArm::cast(node)?.syntax().clone())),
-        T![if] => find_root(sema, token, |node| {
+        T![match] => find_nodes(|node| Some(ast::MatchExpr::cast(node)?.syntax().clone())),
+        T![=>] => find_nodes(|node| Some(ast::MatchArm::cast(node)?.syntax().clone())),
+        T![if] => find_nodes(|node| {
             let if_expr = ast::IfExpr::cast(node)?;
 
-            iter::successors(Some(if_expr.clone()), |if_expr| {
+            let root_if = iter::successors(Some(if_expr.clone()), |if_expr| {
                 let parent_if = if_expr.syntax().parent().and_then(ast::IfExpr::cast)?;
-                if let ast::ElseBranch::IfExpr(nested_if) = parent_if.else_branch()? {
-                    (nested_if.syntax() == if_expr.syntax()).then_some(parent_if)
-                } else {
-                    None
-                }
+                let ast::ElseBranch::IfExpr(else_branch) = parent_if.else_branch()? else {
+                    return None;
+                };
+
+                (else_branch.syntax() == if_expr.syntax()).then_some(parent_if)
             })
-            .last()
-            .map(|if_expr| if_expr.syntax().clone())
+            .last()?;
+
+            Some(root_if.syntax().clone())
         }),
         _ => vec![],
     }
diff --git a/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs b/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs
index 6aa17f90ef0..356bd69aa44 100644
--- a/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs
+++ b/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs
@@ -2357,6 +2357,33 @@ fn main() {
     }
 
     #[test]
+    fn match_in_macro_highlight_2() {
+        check(
+            r#"
+macro_rules! match_ast {
+    (match $node:ident { $($tt:tt)* }) => { $crate::match_ast!(match ($node) { $($tt)* }) };
+
+    (match ($node:expr) {
+        $( $( $path:ident )::+ ($it:pat) => $res:expr, )*
+        _ => $catch_all:expr $(,)?
+    }) => {{
+        $( if let Some($it) = $($path::)+cast($node.clone()) { $res } else )*
+        { $catch_all }
+    }};
+}
+
+fn main() {
+    match_ast! {
+        match$0 Some(1) {
+            Some(x) => x,
+        }
+    }
+}
+            "#,
+        );
+    }
+
+    #[test]
     fn nested_if_else() {
         check(
             r#"