about summary refs log tree commit diff
diff options
context:
space:
mode:
authorShoyu Vanilla (Flint) <modulo641@gmail.com>2025-09-26 06:42:10 +0000
committerGitHub <noreply@github.com>2025-09-26 06:42:10 +0000
commitacd320f7b3c94f358ceafb229e773e1809a62034 (patch)
treef5d73126b49a0e97bc7a09a29f16cc0266a102ac
parent35f76dfcd99f66684544ebee4c9c25bdba01dd82 (diff)
parent11c35cd0bcb6e0c285be031f10d14d64bbf2bd9c (diff)
downloadrust-acd320f7b3c94f358ceafb229e773e1809a62034.tar.gz
rust-acd320f7b3c94f358ceafb229e773e1809a62034.zip
Merge pull request #20598 from A4-Tacks/let-chain-sup-conv-to-guarded-ret
Add let-chain support for convert_to_guarded_return
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs297
1 files changed, 252 insertions, 45 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
index 2ea032fb62b..82213ae3217 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
@@ -1,13 +1,12 @@
 use std::iter::once;
 
-use ide_db::{
-    syntax_helpers::node_ext::{is_pattern_cond, single_let},
-    ty_filter::TryEnum,
-};
+use either::Either;
+use hir::{Semantics, TypeInfo};
+use ide_db::{RootDatabase, ty_filter::TryEnum};
 use syntax::{
     AstNode,
-    SyntaxKind::{FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
-    T,
+    SyntaxKind::{CLOSURE_EXPR, FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
+    SyntaxNode, T,
     ast::{
         self,
         edit::{AstNodeEdit, IndentLevel},
@@ -44,12 +43,9 @@ use crate::{
 // }
 // ```
 pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
-    if let Some(let_stmt) = ctx.find_node_at_offset() {
-        let_stmt_to_guarded_return(let_stmt, acc, ctx)
-    } else if let Some(if_expr) = ctx.find_node_at_offset() {
-        if_expr_to_guarded_return(if_expr, acc, ctx)
-    } else {
-        None
+    match ctx.find_node_at_offset::<Either<ast::LetStmt, ast::IfExpr>>()? {
+        Either::Left(let_stmt) => let_stmt_to_guarded_return(let_stmt, acc, ctx),
+        Either::Right(if_expr) => if_expr_to_guarded_return(if_expr, acc, ctx),
     }
 }
 
@@ -73,13 +69,7 @@ fn if_expr_to_guarded_return(
         return None;
     }
 
-    // Check if there is an IfLet that we can handle.
-    let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
-        let let_ = single_let(cond)?;
-        (Some(let_.pat()?), let_.expr()?)
-    } else {
-        (None, cond)
-    };
+    let let_chains = flat_let_chain(cond);
 
     let then_block = if_expr.then_branch()?;
     let then_block = then_block.stmt_list()?;
@@ -106,11 +96,7 @@ fn if_expr_to_guarded_return(
 
     let parent_container = parent_block.syntax().parent()?;
 
-    let early_expression: ast::Expr = match parent_container.kind() {
-        WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
-        FN => make::expr_return(None),
-        _ => return None,
-    };
+    let early_expression: ast::Expr = early_expression(parent_container, &ctx.sema)?;
 
     then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{'])?;
 
@@ -132,32 +118,42 @@ fn if_expr_to_guarded_return(
         target,
         |edit| {
             let if_indent_level = IndentLevel::from_node(if_expr.syntax());
-            let replacement = match if_let_pat {
-                None => {
-                    // If.
-                    let new_expr = {
-                        let then_branch =
-                            make::block_expr(once(make::expr_stmt(early_expression).into()), None);
-                        let cond = invert_boolean_expression_legacy(cond_expr);
-                        make::expr_if(cond, then_branch, None).indent(if_indent_level)
-                    };
-                    new_expr.syntax().clone()
-                }
-                Some(pat) => {
+            let replacement = let_chains.into_iter().map(|expr| {
+                if let ast::Expr::LetExpr(let_expr) = &expr
+                    && let (Some(pat), Some(expr)) = (let_expr.pat(), let_expr.expr())
+                {
                     // If-let.
                     let let_else_stmt = make::let_else_stmt(
                         pat,
                         None,
-                        cond_expr,
-                        ast::make::tail_only_block_expr(early_expression),
+                        expr,
+                        ast::make::tail_only_block_expr(early_expression.clone()),
                     );
                     let let_else_stmt = let_else_stmt.indent(if_indent_level);
                     let_else_stmt.syntax().clone()
+                } else {
+                    // If.
+                    let new_expr = {
+                        let then_branch = make::block_expr(
+                            once(make::expr_stmt(early_expression.clone()).into()),
+                            None,
+                        );
+                        let cond = invert_boolean_expression_legacy(expr);
+                        make::expr_if(cond, then_branch, None).indent(if_indent_level)
+                    };
+                    new_expr.syntax().clone()
                 }
-            };
+            });
 
+            let newline = &format!("\n{if_indent_level}");
             let then_statements = replacement
-                .children_with_tokens()
+                .enumerate()
+                .flat_map(|(i, node)| {
+                    (i != 0)
+                        .then(|| make::tokens::whitespace(newline).into())
+                        .into_iter()
+                        .chain(node.children_with_tokens())
+                })
                 .chain(
                     then_block_items
                         .syntax()
@@ -201,11 +197,7 @@ fn let_stmt_to_guarded_return(
             let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
         let parent_container = parent_block.syntax().parent()?;
 
-        match parent_container.kind() {
-            WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
-            FN => make::expr_return(None),
-            _ => return None,
-        }
+        early_expression(parent_container, &ctx.sema)?
     };
 
     acc.add(
@@ -232,6 +224,54 @@ fn let_stmt_to_guarded_return(
     )
 }
 
+fn early_expression(
+    parent_container: SyntaxNode,
+    sema: &Semantics<'_, RootDatabase>,
+) -> Option<ast::Expr> {
+    let return_none_expr = || {
+        let none_expr = make::expr_path(make::ext::ident_path("None"));
+        make::expr_return(Some(none_expr))
+    };
+    if let Some(fn_) = ast::Fn::cast(parent_container.clone())
+        && let Some(fn_def) = sema.to_def(&fn_)
+        && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db))
+    {
+        return Some(return_none_expr());
+    }
+    if let Some(body) = ast::ClosureExpr::cast(parent_container.clone()).and_then(|it| it.body())
+        && let Some(ret_ty) = sema.type_of_expr(&body).map(TypeInfo::original)
+        && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &ret_ty)
+    {
+        return Some(return_none_expr());
+    }
+
+    Some(match parent_container.kind() {
+        WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
+        FN | CLOSURE_EXPR => make::expr_return(None),
+        _ => return None,
+    })
+}
+
+fn flat_let_chain(mut expr: ast::Expr) -> Vec<ast::Expr> {
+    let mut chains = vec![];
+
+    while let ast::Expr::BinExpr(bin_expr) = &expr
+        && bin_expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And))
+        && let (Some(lhs), Some(rhs)) = (bin_expr.lhs(), bin_expr.rhs())
+    {
+        if let Some(last) = chains.pop_if(|last| !matches!(last, ast::Expr::LetExpr(_))) {
+            chains.push(make::expr_bin_op(rhs, ast::BinaryOp::LogicOp(ast::LogicOp::And), last));
+        } else {
+            chains.push(rhs);
+        }
+        expr = lhs;
+    }
+
+    chains.push(expr);
+    chains.reverse();
+    chains
+}
+
 #[cfg(test)]
 mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
@@ -269,6 +309,71 @@ fn main() {
     }
 
     #[test]
+    fn convert_inside_fn_return_option() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+//- minicore: option
+fn ret_option() -> Option<()> {
+    bar();
+    if$0 true {
+        foo();
+
+        // comment
+        bar();
+    }
+}
+"#,
+            r#"
+fn ret_option() -> Option<()> {
+    bar();
+    if false {
+        return None;
+    }
+    foo();
+
+    // comment
+    bar();
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn convert_inside_closure() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+fn main() {
+    let _f = || {
+        bar();
+        if$0 true {
+            foo();
+
+            // comment
+            bar();
+        }
+    }
+}
+"#,
+            r#"
+fn main() {
+    let _f = || {
+        bar();
+        if false {
+            return;
+        }
+        foo();
+
+        // comment
+        bar();
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
     fn convert_let_inside_fn() {
         check_assist(
             convert_to_guarded_return,
@@ -317,6 +422,82 @@ fn main() {
     }
 
     #[test]
+    fn convert_if_let_result_inside_let() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+fn main() {
+    let _x = loop {
+        if$0 let Ok(x) = Err(92) {
+            foo(x);
+        }
+    };
+}
+"#,
+            r#"
+fn main() {
+    let _x = loop {
+        let Ok(x) = Err(92) else { continue };
+        foo(x);
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn convert_if_let_chain_result() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+fn main() {
+    if$0 let Ok(x) = Err(92)
+        && x < 30
+        && let Some(y) = Some(8)
+    {
+        foo(x, y);
+    }
+}
+"#,
+            r#"
+fn main() {
+    let Ok(x) = Err(92) else { return };
+    if x >= 30 {
+        return;
+    }
+    let Some(y) = Some(8) else { return };
+    foo(x, y);
+}
+"#,
+        );
+
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+fn main() {
+    if$0 let Ok(x) = Err(92)
+        && x < 30
+        && y < 20
+        && let Some(y) = Some(8)
+    {
+        foo(x, y);
+    }
+}
+"#,
+            r#"
+fn main() {
+    let Ok(x) = Err(92) else { return };
+    if !(x < 30 && y < 20) {
+        return;
+    }
+    let Some(y) = Some(8) else { return };
+    foo(x, y);
+}
+"#,
+        );
+    }
+
+    #[test]
     fn convert_let_ok_inside_fn() {
         check_assist(
             convert_to_guarded_return,
@@ -561,6 +742,32 @@ fn main() {
     }
 
     #[test]
+    fn convert_let_stmt_inside_fn_return_option() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+//- minicore: option
+fn foo() -> Option<i32> {
+    None
+}
+
+fn ret_option() -> Option<i32> {
+    let x$0 = foo();
+}
+"#,
+            r#"
+fn foo() -> Option<i32> {
+    None
+}
+
+fn ret_option() -> Option<i32> {
+    let Some(x) = foo() else { return None };
+}
+"#,
+        );
+    }
+
+    #[test]
     fn convert_let_stmt_inside_loop() {
         check_assist(
             convert_to_guarded_return,