about summary refs log tree commit diff
diff options
context:
space:
mode:
authorA4-Tacks <wdsjxhno1001@163.com>2025-09-03 20:07:19 +0800
committerA4-Tacks <wdsjxhno1001@163.com>2025-09-25 21:04:48 +0800
commitc1cd1da668650d0e9e0f143f89a17387d96979c2 (patch)
tree859f6821381a7f8916e6073b47bd3d1a69ea824c
parentda9831cc04b85faf584318e79dea00b4aa16ec59 (diff)
downloadrust-c1cd1da668650d0e9e0f143f89a17387d96979c2.tar.gz
rust-c1cd1da668650d0e9e0f143f89a17387d96979c2.zip
Add let-chain support for convert_to_guarded_return
- And add early expression `None` in function `Option` return

Example
---
```rust
fn main() {
    if$0 let Ok(x) = Err(92)
        && x < 30
        && let Some(y) = Some(8)
    {
        foo(x, y);
    }
}
```
->
```rust
fn main() {
    let Ok(x) = Err(92) else { return };
    if x >= 30 {
        return;
    }
    let Some(y) = Some(8) else { return };
    foo(x, y);
}
```
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs217
1 files changed, 179 insertions, 38 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
index 2ea032fb62b..3dc4737ffc5 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
@@ -1,13 +1,11 @@
 use std::iter::once;
 
-use ide_db::{
-    syntax_helpers::node_ext::{is_pattern_cond, single_let},
-    ty_filter::TryEnum,
-};
+use hir::Semantics;
+use ide_db::{RootDatabase, ty_filter::TryEnum};
 use syntax::{
     AstNode,
     SyntaxKind::{FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
-    T,
+    SyntaxNode, T,
     ast::{
         self,
         edit::{AstNodeEdit, IndentLevel},
@@ -73,13 +71,7 @@ fn if_expr_to_guarded_return(
         return None;
     }
 
-    // Check if there is an IfLet that we can handle.
-    let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
-        let let_ = single_let(cond)?;
-        (Some(let_.pat()?), let_.expr()?)
-    } else {
-        (None, cond)
-    };
+    let let_chains = flat_let_chain(cond);
 
     let then_block = if_expr.then_branch()?;
     let then_block = then_block.stmt_list()?;
@@ -106,11 +98,7 @@ fn if_expr_to_guarded_return(
 
     let parent_container = parent_block.syntax().parent()?;
 
-    let early_expression: ast::Expr = match parent_container.kind() {
-        WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
-        FN => make::expr_return(None),
-        _ => return None,
-    };
+    let early_expression: ast::Expr = early_expression(parent_container, &ctx.sema)?;
 
     then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{'])?;
 
@@ -132,32 +120,42 @@ fn if_expr_to_guarded_return(
         target,
         |edit| {
             let if_indent_level = IndentLevel::from_node(if_expr.syntax());
-            let replacement = match if_let_pat {
-                None => {
-                    // If.
-                    let new_expr = {
-                        let then_branch =
-                            make::block_expr(once(make::expr_stmt(early_expression).into()), None);
-                        let cond = invert_boolean_expression_legacy(cond_expr);
-                        make::expr_if(cond, then_branch, None).indent(if_indent_level)
-                    };
-                    new_expr.syntax().clone()
-                }
-                Some(pat) => {
+            let replacement = let_chains.into_iter().map(|expr| {
+                if let ast::Expr::LetExpr(let_expr) = &expr
+                    && let (Some(pat), Some(expr)) = (let_expr.pat(), let_expr.expr())
+                {
                     // If-let.
                     let let_else_stmt = make::let_else_stmt(
                         pat,
                         None,
-                        cond_expr,
-                        ast::make::tail_only_block_expr(early_expression),
+                        expr,
+                        ast::make::tail_only_block_expr(early_expression.clone()),
                     );
                     let let_else_stmt = let_else_stmt.indent(if_indent_level);
                     let_else_stmt.syntax().clone()
+                } else {
+                    // If.
+                    let new_expr = {
+                        let then_branch = make::block_expr(
+                            once(make::expr_stmt(early_expression.clone()).into()),
+                            None,
+                        );
+                        let cond = invert_boolean_expression_legacy(expr);
+                        make::expr_if(cond, then_branch, None).indent(if_indent_level)
+                    };
+                    new_expr.syntax().clone()
                 }
-            };
+            });
 
+            let newline = &format!("\n{if_indent_level}");
             let then_statements = replacement
-                .children_with_tokens()
+                .enumerate()
+                .flat_map(|(i, node)| {
+                    (i != 0)
+                        .then(|| make::tokens::whitespace(newline).into())
+                        .into_iter()
+                        .chain(node.children_with_tokens())
+                })
                 .chain(
                     then_block_items
                         .syntax()
@@ -201,11 +199,7 @@ fn let_stmt_to_guarded_return(
             let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
         let parent_container = parent_block.syntax().parent()?;
 
-        match parent_container.kind() {
-            WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
-            FN => make::expr_return(None),
-            _ => return None,
-        }
+        early_expression(parent_container, &ctx.sema)?
     };
 
     acc.add(
@@ -232,6 +226,44 @@ fn let_stmt_to_guarded_return(
     )
 }
 
+fn early_expression(
+    parent_container: SyntaxNode,
+    sema: &Semantics<'_, RootDatabase>,
+) -> Option<ast::Expr> {
+    if let Some(fn_) = ast::Fn::cast(parent_container.clone())
+        && let Some(fn_def) = sema.to_def(&fn_)
+        && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db))
+    {
+        let none_expr = make::expr_path(make::ext::ident_path("None"));
+        return Some(make::expr_return(Some(none_expr)));
+    }
+    Some(match parent_container.kind() {
+        WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
+        FN => make::expr_return(None),
+        _ => return None,
+    })
+}
+
+fn flat_let_chain(mut expr: ast::Expr) -> Vec<ast::Expr> {
+    let mut chains = vec![];
+
+    while let ast::Expr::BinExpr(bin_expr) = &expr
+        && bin_expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And))
+        && let (Some(lhs), Some(rhs)) = (bin_expr.lhs(), bin_expr.rhs())
+    {
+        if let Some(last) = chains.pop_if(|last| !matches!(last, ast::Expr::LetExpr(_))) {
+            chains.push(make::expr_bin_op(rhs, ast::BinaryOp::LogicOp(ast::LogicOp::And), last));
+        } else {
+            chains.push(rhs);
+        }
+        expr = lhs;
+    }
+
+    chains.push(expr);
+    chains.reverse();
+    chains
+}
+
 #[cfg(test)]
 mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
@@ -269,6 +301,37 @@ fn main() {
     }
 
     #[test]
+    fn convert_inside_fn_return_option() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+//- minicore: option
+fn ret_option() -> Option<()> {
+    bar();
+    if$0 true {
+        foo();
+
+        // comment
+        bar();
+    }
+}
+"#,
+            r#"
+fn ret_option() -> Option<()> {
+    bar();
+    if false {
+        return None;
+    }
+    foo();
+
+    // comment
+    bar();
+}
+"#,
+        );
+    }
+
+    #[test]
     fn convert_let_inside_fn() {
         check_assist(
             convert_to_guarded_return,
@@ -317,6 +380,58 @@ fn main() {
     }
 
     #[test]
+    fn convert_if_let_chain_result() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+fn main() {
+    if$0 let Ok(x) = Err(92)
+        && x < 30
+        && let Some(y) = Some(8)
+    {
+        foo(x, y);
+    }
+}
+"#,
+            r#"
+fn main() {
+    let Ok(x) = Err(92) else { return };
+    if x >= 30 {
+        return;
+    }
+    let Some(y) = Some(8) else { return };
+    foo(x, y);
+}
+"#,
+        );
+
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+fn main() {
+    if$0 let Ok(x) = Err(92)
+        && x < 30
+        && y < 20
+        && let Some(y) = Some(8)
+    {
+        foo(x, y);
+    }
+}
+"#,
+            r#"
+fn main() {
+    let Ok(x) = Err(92) else { return };
+    if !(x < 30 && y < 20) {
+        return;
+    }
+    let Some(y) = Some(8) else { return };
+    foo(x, y);
+}
+"#,
+        );
+    }
+
+    #[test]
     fn convert_let_ok_inside_fn() {
         check_assist(
             convert_to_guarded_return,
@@ -561,6 +676,32 @@ fn main() {
     }
 
     #[test]
+    fn convert_let_stmt_inside_fn_return_option() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+//- minicore: option
+fn foo() -> Option<i32> {
+    None
+}
+
+fn ret_option() -> Option<i32> {
+    let x$0 = foo();
+}
+"#,
+            r#"
+fn foo() -> Option<i32> {
+    None
+}
+
+fn ret_option() -> Option<i32> {
+    let Some(x) = foo() else { return None };
+}
+"#,
+        );
+    }
+
+    #[test]
     fn convert_let_stmt_inside_loop() {
         check_assist(
             convert_to_guarded_return,