about summary refs log tree commit diff
diff options
context:
space:
mode:
authorroife <roifewu@gmail.com>2024-07-05 00:19:43 +0800
committerroife <roifewu@gmail.com>2024-07-20 01:42:51 +0800
commit63bacc111c900b4d2cacffcba553bbd5c26cdb00 (patch)
tree319484d133a257ede000477fa78a7e3f581c4488
parent9ede85344f5885afd3e8f40f350fb3d5b20d9f76 (diff)
downloadrust-63bacc111c900b4d2cacffcba553bbd5c26cdb00.tar.gz
rust-63bacc111c900b4d2cacffcba553bbd5c26cdb00.zip
fix: incorrect highlighting of try blocks with control flow kws
-rw-r--r--src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/node_ext.rs42
-rw-r--r--src/tools/rust-analyzer/crates/ide/src/highlight_related.rs338
2 files changed, 242 insertions, 138 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/node_ext.rs b/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/node_ext.rs
index c301e100341..37238cc61d3 100644
--- a/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/node_ext.rs
+++ b/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/node_ext.rs
@@ -36,10 +36,35 @@ pub fn walk_expr(expr: &ast::Expr, cb: &mut dyn FnMut(ast::Expr)) {
     })
 }
 
+pub fn is_closure_or_blk_with_modif(expr: &ast::Expr) -> bool {
+    match expr {
+        ast::Expr::BlockExpr(block_expr) => {
+            matches!(
+                block_expr.modifier(),
+                Some(
+                    ast::BlockModifier::Async(_)
+                        | ast::BlockModifier::Try(_)
+                        | ast::BlockModifier::Const(_)
+                )
+            )
+        }
+        ast::Expr::ClosureExpr(_) => true,
+        _ => false,
+    }
+}
+
 /// Preorder walk all the expression's child expressions preserving events.
 /// If the callback returns true on an [`WalkEvent::Enter`], the subtree of the expression will be skipped.
 /// Note that the subtree may already be skipped due to the context analysis this function does.
 pub fn preorder_expr(start: &ast::Expr, cb: &mut dyn FnMut(WalkEvent<ast::Expr>) -> bool) {
+    preorder_expr_with_ctx_checker(start, &is_closure_or_blk_with_modif, cb);
+}
+
+pub fn preorder_expr_with_ctx_checker(
+    start: &ast::Expr,
+    check_ctx: &dyn Fn(&ast::Expr) -> bool,
+    cb: &mut dyn FnMut(WalkEvent<ast::Expr>) -> bool,
+) {
     let mut preorder = start.syntax().preorder();
     while let Some(event) = preorder.next() {
         let node = match event {
@@ -71,20 +96,7 @@ pub fn preorder_expr(start: &ast::Expr, cb: &mut dyn FnMut(WalkEvent<ast::Expr>)
                 if ast::GenericArg::can_cast(node.kind()) {
                     preorder.skip_subtree();
                 } else if let Some(expr) = ast::Expr::cast(node) {
-                    let is_different_context = match &expr {
-                        ast::Expr::BlockExpr(block_expr) => {
-                            matches!(
-                                block_expr.modifier(),
-                                Some(
-                                    ast::BlockModifier::Async(_)
-                                        | ast::BlockModifier::Try(_)
-                                        | ast::BlockModifier::Const(_)
-                                )
-                            )
-                        }
-                        ast::Expr::ClosureExpr(_) => true,
-                        _ => false,
-                    } && expr.syntax() != start.syntax();
+                    let is_different_context = check_ctx(&expr) && expr.syntax() != start.syntax();
                     let skip = cb(WalkEvent::Enter(expr));
                     if skip || is_different_context {
                         preorder.skip_subtree();
@@ -394,7 +406,7 @@ fn for_each_break_expr(
     }
 }
 
-fn eq_label_lt(lt1: &Option<ast::Lifetime>, lt2: &Option<ast::Lifetime>) -> bool {
+pub fn eq_label_lt(lt1: &Option<ast::Lifetime>, lt2: &Option<ast::Lifetime>) -> bool {
     lt1.as_ref().zip(lt2.as_ref()).map_or(false, |(lt, lbl)| lt.text() == lbl.text())
 }
 
diff --git a/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs b/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs
index 98651900050..a5c547c7d9a 100644
--- a/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs
+++ b/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs
@@ -1,12 +1,13 @@
 use std::iter;
 
-use hir::{DescendPreference, FilePosition, FileRange, Semantics};
+use hir::{db, DescendPreference, FilePosition, FileRange, HirFileId, InFile, Semantics};
 use ide_db::{
     defs::{Definition, IdentClass},
     helpers::pick_best_token,
     search::{FileReference, ReferenceCategory, SearchScope},
     syntax_helpers::node_ext::{
-        for_each_break_and_continue_expr, for_each_tail_expr, full_path_of_name_ref, walk_expr,
+        eq_label_lt, for_each_tail_expr, full_path_of_name_ref, is_closure_or_blk_with_modif,
+        preorder_expr_with_ctx_checker,
     },
     FxHashSet, RootDatabase,
 };
@@ -15,7 +16,7 @@ use syntax::{
     ast::{self, HasLoopBody},
     match_ast, AstNode,
     SyntaxKind::{self, IDENT, INT_NUMBER},
-    SyntaxToken, TextRange, T,
+    SyntaxToken, TextRange, WalkEvent, T,
 };
 
 use crate::{navigation_target::ToNav, NavigationTarget, TryToNav};
@@ -75,12 +76,12 @@ pub(crate) fn highlight_related(
             highlight_exit_points(sema, token)
         }
         T![fn] | T![return] | T![->] if config.exit_points => highlight_exit_points(sema, token),
-        T![await] | T![async] if config.yield_points => highlight_yield_points(token),
+        T![await] | T![async] if config.yield_points => highlight_yield_points(sema, token),
         T![for] if config.break_points && token.parent().and_then(ast::ForExpr::cast).is_some() => {
-            highlight_break_points(token)
+            highlight_break_points(sema, token)
         }
         T![break] | T![loop] | T![while] | T![continue] if config.break_points => {
-            highlight_break_points(token)
+            highlight_break_points(sema, token)
         }
         T![|] if config.closure_captures => highlight_closure_captures(sema, token, file_id),
         T![move] if config.closure_captures => highlight_closure_captures(sema, token, file_id),
@@ -276,50 +277,53 @@ fn highlight_references(
     }
 }
 
-fn highlight_exit_points(
+pub(crate) fn highlight_exit_points(
     sema: &Semantics<'_, RootDatabase>,
     token: SyntaxToken,
 ) -> Option<Vec<HighlightedRange>> {
     fn hl(
         sema: &Semantics<'_, RootDatabase>,
-        def_ranges: [Option<TextRange>; 2],
-        body: Option<ast::Expr>,
+        def_range: Option<TextRange>,
+        body: ast::Expr,
     ) -> Option<Vec<HighlightedRange>> {
         let mut highlights = Vec::new();
-        highlights.extend(
-            def_ranges
-                .into_iter()
-                .flatten()
-                .map(|range| HighlightedRange { category: ReferenceCategory::empty(), range }),
-        );
-        let body = body?;
-        walk_expr(&body, &mut |expr| match expr {
-            ast::Expr::ReturnExpr(expr) => {
-                if let Some(token) = expr.return_token() {
-                    highlights.push(HighlightedRange {
-                        category: ReferenceCategory::empty(),
-                        range: token.text_range(),
-                    });
-                }
-            }
-            ast::Expr::TryExpr(try_) => {
-                if let Some(token) = try_.question_mark_token() {
-                    highlights.push(HighlightedRange {
-                        category: ReferenceCategory::empty(),
-                        range: token.text_range(),
-                    });
+        if let Some(range) = def_range {
+            highlights.push(HighlightedRange { category: ReferenceCategory::empty(), range });
+        }
+
+        WalkExpandedExprCtx::new(sema).walk(&body, &mut |_, expr| {
+            let file_id = sema.hir_file_for(expr.syntax());
+
+            let text_range = match &expr {
+                ast::Expr::TryExpr(try_) => {
+                    try_.question_mark_token().map(|token| token.text_range())
                 }
-            }
-            ast::Expr::MethodCallExpr(_) | ast::Expr::CallExpr(_) | ast::Expr::MacroExpr(_) => {
-                if sema.type_of_expr(&expr).map_or(false, |ty| ty.original.is_never()) {
-                    highlights.push(HighlightedRange {
-                        category: ReferenceCategory::empty(),
-                        range: expr.syntax().text_range(),
-                    });
+                ast::Expr::MethodCallExpr(_) | ast::Expr::CallExpr(_) | ast::Expr::MacroExpr(_)
+                    if sema.type_of_expr(&expr).map_or(false, |ty| ty.original.is_never()) =>
+                {
+                    Some(expr.syntax().text_range())
                 }
-            }
-            _ => (),
+                _ => None,
+            };
+
         });
+
+        // We should handle `return` separately because when it is used in `try` block
+        // it will exit the outside function instead of the block it self.
+        WalkExpandedExprCtx::new(sema)
+            .with_check_ctx(&WalkExpandedExprCtx::is_async_const_block_or_closure)
+            .walk(&body, &mut |_, expr| {
+                let file_id = sema.hir_file_for(expr.syntax());
+
+                let text_range = match &expr {
+                    ast::Expr::ReturnExpr(expr) => {
+                        expr.return_token().map(|token| token.text_range())
+                    }
+                    _ => None,
+                };
+
+            });
+
         let tail = match body {
             ast::Expr::BlockExpr(b) => b.tail_expr(),
             e => Some(e),
@@ -338,26 +342,22 @@ fn highlight_exit_points(
         }
         Some(highlights)
     }
+
     for anc in token.parent_ancestors() {
         return match_ast! {
             match anc {
-                ast::Fn(fn_) => hl(sema, [fn_.fn_token().map(|it| it.text_range()), None], fn_.body().map(ast::Expr::BlockExpr)),
+                ast::Fn(fn_) => hl(sema, fn_.fn_token().map(|it| it.text_range()), ast::Expr::BlockExpr(fn_.body()?)),
                 ast::ClosureExpr(closure) => hl(
                     sema,
-                    closure.param_list().map_or([None; 2], |p| [p.l_paren_token().map(|it| it.text_range()), p.r_paren_token().map(|it| it.text_range())]),
-                    closure.body()
+                    closure.param_list().and_then(|p| p.pipe_token()).map(|tok| tok.text_range()),
+                    closure.body()?
                 ),
-                ast::BlockExpr(block_expr) => if matches!(block_expr.modifier(), Some(ast::BlockModifier::Async(_) | ast::BlockModifier::Try(_)| ast::BlockModifier::Const(_))) {
-                    hl(
-                        sema,
-                        [block_expr.modifier().and_then(|modifier| match modifier {
-                            ast::BlockModifier::Async(t) | ast::BlockModifier::Try(t) | ast::BlockModifier::Const(t) => Some(t.text_range()),
-                            _ => None,
-                        }), None],
-                        Some(block_expr.into())
-                    )
-                } else {
-                    continue;
+                ast::BlockExpr(blk) => match blk.modifier() {
+                    Some(ast::BlockModifier::Async(t)) => hl(sema, Some(t.text_range()), blk.into()),
+                    Some(ast::BlockModifier::Try(t)) if token.kind() != T![return] => {
+                        hl(sema, Some(t.text_range()), blk.into())
+                    },
+                    _ => continue,
                 },
                 _ => continue,
             }
@@ -366,44 +366,56 @@ fn highlight_exit_points(
     None
 }
 
-fn highlight_break_points(token: SyntaxToken) -> Option<Vec<HighlightedRange>> {
+pub(crate) fn highlight_break_points(
+    sema: &Semantics<'_, RootDatabase>,
+    token: SyntaxToken,
+) -> Option<Vec<HighlightedRange>> {
     fn hl(
+        sema: &Semantics<'_, RootDatabase>,
         cursor_token_kind: SyntaxKind,
-        token: Option<SyntaxToken>,
+        loop_token: Option<SyntaxToken>,
         label: Option<ast::Label>,
-        body: Option<ast::StmtList>,
+        expr: ast::Expr,
     ) -> Option<Vec<HighlightedRange>> {
         let mut highlights = Vec::new();
-        let range = cover_range(
-            token.map(|tok| tok.text_range()),
-            label.as_ref().map(|it| it.syntax().text_range()),
-        );
-        highlights.extend(
-            range.map(|range| HighlightedRange { category: ReferenceCategory::empty(), range }),
-        );
-        for_each_break_and_continue_expr(label, body, &mut |expr| {
-            let range: Option<TextRange> = match (cursor_token_kind, expr) {
-                (T![for] | T![while] | T![loop] | T![break], ast::Expr::BreakExpr(break_)) => {
-                    cover_range(
-                        break_.break_token().map(|it| it.text_range()),
-                        break_.lifetime().map(|it| it.syntax().text_range()),
-                    )
+
+        let (label_range, label_lt) = label
+            .map_or((None, None), |label| (Some(label.syntax().text_range()), label.lifetime()));
+
+        if let Some(range) = cover_range(loop_token.map(|tok| tok.text_range()), label_range) {
+            highlights.push(HighlightedRange { category: ReferenceCategory::empty(), range })
+        }
+
+        WalkExpandedExprCtx::new(sema)
+            .with_check_ctx(&WalkExpandedExprCtx::is_async_const_block_or_closure)
+            .walk(&expr, &mut |depth, expr| {
+                let file_id = sema.hir_file_for(expr.syntax());
+
+                // Only highlight the `break`s for `break` and `continue`s for `continue`
+                let (token, token_lt) = match expr {
+                    ast::Expr::BreakExpr(b) if cursor_token_kind != T![continue] => {
+                        (b.break_token(), b.lifetime())
+                    }
+                    ast::Expr::ContinueExpr(c) if cursor_token_kind != T![break] => {
+                        (c.continue_token(), c.lifetime())
+                    }
+                    _ => return,
+                };
+
+                if !(depth == 1 && token_lt.is_none() || eq_label_lt(&label_lt, &token_lt)) {
+                    return;
                 }
-                (
-                    T![for] | T![while] | T![loop] | T![continue],
-                    ast::Expr::ContinueExpr(continue_),
-                ) => cover_range(
-                    continue_.continue_token().map(|it| it.text_range()),
-                    continue_.lifetime().map(|it| it.syntax().text_range()),
-                ),
-                _ => None,
-            };
-            highlights.extend(
-                range.map(|range| HighlightedRange { category: ReferenceCategory::empty(), range }),
-            );
-        });
+
+                let text_range = cover_range(
+                    token.map(|it| it.text_range()),
+                    token_lt.map(|it| it.syntax().text_range()),
+                );
+
+            });
+
         Some(highlights)
     }
+
     let parent = token.parent()?;
     let lbl = match_ast! {
         match parent {
@@ -416,36 +428,27 @@ fn highlight_break_points(token: SyntaxToken) -> Option<Vec<HighlightedRange>> {
             _ => return None,
         }
     };
-    let lbl = lbl.as_ref();
-    let label_matches = |def_lbl: Option<ast::Label>| match lbl {
+
+    let label_matches = |def_lbl: Option<ast::Label>| match lbl.as_ref() {
         Some(lbl) => {
             Some(lbl.text()) == def_lbl.and_then(|it| it.lifetime()).as_ref().map(|it| it.text())
         }
         None => true,
     };
-    let token_kind = token.kind();
+
     for anc in token.parent_ancestors().flat_map(ast::Expr::cast) {
-        return match anc {
-            ast::Expr::LoopExpr(l) if label_matches(l.label()) => hl(
-                token_kind,
-                l.loop_token(),
-                l.label(),
-                l.loop_body().and_then(|it| it.stmt_list()),
-            ),
-            ast::Expr::ForExpr(f) if label_matches(f.label()) => hl(
-                token_kind,
-                f.for_token(),
-                f.label(),
-                f.loop_body().and_then(|it| it.stmt_list()),
-            ),
-            ast::Expr::WhileExpr(w) if label_matches(w.label()) => hl(
-                token_kind,
-                w.while_token(),
-                w.label(),
-                w.loop_body().and_then(|it| it.stmt_list()),
-            ),
+        return match &anc {
+            ast::Expr::LoopExpr(l) if label_matches(l.label()) => {
+                hl(sema, token.kind(), l.loop_token(), l.label(), anc)
+            }
+            ast::Expr::ForExpr(f) if label_matches(f.label()) => {
+                hl(sema, token.kind(), f.for_token(), f.label(), anc)
+            }
+            ast::Expr::WhileExpr(w) if label_matches(w.label()) => {
+                hl(sema, token.kind(), w.while_token(), w.label(), anc)
+            }
             ast::Expr::BlockExpr(e) if e.label().is_some() && label_matches(e.label()) => {
-                hl(token_kind, None, e.label(), e.stmt_list())
+                hl(sema, token.kind(), None, e.label(), anc)
             }
             _ => continue,
         };
@@ -453,8 +456,12 @@ fn highlight_break_points(token: SyntaxToken) -> Option<Vec<HighlightedRange>> {
     None
 }
 
-fn highlight_yield_points(token: SyntaxToken) -> Option<Vec<HighlightedRange>> {
+pub(crate) fn highlight_yield_points(
+    sema: &Semantics<'_, RootDatabase>,
+    token: SyntaxToken,
+) -> Option<Vec<HighlightedRange>> {
     fn hl(
+        sema: &Semantics<'_, RootDatabase>,
         async_token: Option<SyntaxToken>,
         body: Option<ast::Expr>,
     ) -> Option<Vec<HighlightedRange>> {
@@ -462,31 +469,35 @@ fn highlight_yield_points(token: SyntaxToken) -> Option<Vec<HighlightedRange>> {
             category: ReferenceCategory::empty(),
             range: async_token?.text_range(),
         }];
-        if let Some(body) = body {
-            walk_expr(&body, &mut |expr| {
-                if let ast::Expr::AwaitExpr(expr) = expr {
-                    if let Some(token) = expr.await_token() {
-                        highlights.push(HighlightedRange {
-                            category: ReferenceCategory::empty(),
-                            range: token.text_range(),
-                        });
-                    }
-                }
-            });
-        }
+        let Some(body) = body else {
+            return Some(highlights);
+        };
+
+        WalkExpandedExprCtx::new(sema).walk(&body, &mut |_, expr| {
+            let file_id = sema.hir_file_for(expr.syntax());
+
+            let token_range = match expr {
+                ast::Expr::AwaitExpr(expr) => expr.await_token(),
+                ast::Expr::ReturnExpr(expr) => expr.return_token(),
+                _ => None,
+            }
+            .map(|it| it.text_range());
+
+        });
+
         Some(highlights)
     }
     for anc in token.parent_ancestors() {
         return match_ast! {
             match anc {
-                ast::Fn(fn_) => hl(fn_.async_token(), fn_.body().map(ast::Expr::BlockExpr)),
+                ast::Fn(fn_) => hl(sema, fn_.async_token(), fn_.body().map(ast::Expr::BlockExpr)),
                 ast::BlockExpr(block_expr) => {
                     if block_expr.async_token().is_none() {
                         continue;
                     }
-                    hl(block_expr.async_token(), Some(block_expr.into()))
+                    hl(sema, block_expr.async_token(), Some(block_expr.into()))
                 },
-                ast::ClosureExpr(closure) => hl(closure.async_token(), closure.body()),
+                ast::ClosureExpr(closure) => hl(sema, closure.async_token(), closure.body()),
                 _ => continue,
             }
         };
@@ -511,6 +522,87 @@ fn find_defs(sema: &Semantics<'_, RootDatabase>, token: SyntaxToken) -> FxHashSe
         .collect()
 }
 
+/// Preorder walk all the expression's child expressions.
+/// For macro calls, the callback will be called on the expanded expressions after
+/// visiting the macro call itself.
+struct WalkExpandedExprCtx<'a> {
+    sema: &'a Semantics<'a, RootDatabase>,
+    depth: usize,
+    check_ctx: &'static dyn Fn(&ast::Expr) -> bool,
+}
+
+impl<'a> WalkExpandedExprCtx<'a> {
+    fn new(sema: &'a Semantics<'a, RootDatabase>) -> Self {
+        Self { sema, depth: 0, check_ctx: &is_closure_or_blk_with_modif }
+    }
+
+    fn with_check_ctx(&self, check_ctx: &'static dyn Fn(&ast::Expr) -> bool) -> Self {
+        Self { check_ctx, ..*self }
+    }
+
+    fn walk(&mut self, expr: &ast::Expr, cb: &mut dyn FnMut(usize, ast::Expr)) {
+        preorder_expr_with_ctx_checker(expr, self.check_ctx, &mut |ev: WalkEvent<ast::Expr>| {
+            match ev {
+                syntax::WalkEvent::Enter(expr) => {
+                    cb(self.depth, expr.clone());
+
+                    if Self::should_change_depth(&expr) {
+                        self.depth += 1;
+                    }
+
+                    if let ast::Expr::MacroExpr(expr) = expr {
+                        if let Some(expanded) = expr
+                            .macro_call()
+                            .and_then(|call| self.sema.expand(&call))
+                            .and_then(ast::MacroStmts::cast)
+                        {
+                            self.handle_expanded(expanded, cb);
+                        }
+                    }
+                }
+                syntax::WalkEvent::Leave(expr) if Self::should_change_depth(&expr) => {
+                    self.depth -= 1;
+                }
+                _ => {}
+            }
+            false
+        })
+    }
+
+    fn handle_expanded(&mut self, expanded: ast::MacroStmts, cb: &mut dyn FnMut(usize, ast::Expr)) {
+        if let Some(expr) = expanded.expr() {
+            self.walk(&expr, cb);
+        }
+
+        for stmt in expanded.statements() {
+            if let ast::Stmt::ExprStmt(stmt) = stmt {
+                if let Some(expr) = stmt.expr() {
+                    self.walk(&expr, cb);
+                }
+            }
+        }
+    }
+
+    fn should_change_depth(expr: &ast::Expr) -> bool {
+        match expr {
+            ast::Expr::LoopExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::ForExpr(_) => true,
+            ast::Expr::BlockExpr(blk) if blk.label().is_some() => true,
+            _ => false,
+        }
+    }
+
+    fn is_async_const_block_or_closure(expr: &ast::Expr) -> bool {
+        match expr {
+            ast::Expr::BlockExpr(b) => matches!(
+                b.modifier(),
+                Some(ast::BlockModifier::Async(_) | ast::BlockModifier::Const(_))
+            ),
+            ast::Expr::ClosureExpr(_) => true,
+            _ => false,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use itertools::Itertools;