about summary refs log tree commit diff
diff options
context:
space:
mode:
authordfireBird <me@dfirebird.dev>2024-01-25 00:06:22 +0530
committerdfireBird <me@dfirebird.dev>2024-01-25 00:13:16 +0530
commite0446a0eb5471d28ec74630ecc43c2ae5da94b79 (patch)
treee9b81dea344bbe108d109a0e21ebe19c8f3b9dc0
parentd410d4a2baf9e99b37b03dd42f06238b14374bf7 (diff)
downloadrust-e0446a0eb5471d28ec74630ecc43c2ae5da94b79.tar.gz
rust-e0446a0eb5471d28ec74630ecc43c2ae5da94b79.zip
implement assist for let stmt with TryEnum type to guarded return
-rw-r--r--crates/ide-assists/src/handlers/convert_to_guarded_return.rs126
1 files changed, 124 insertions, 2 deletions
diff --git a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
index 6f30ffa622d..5fc1c1dda62 100644
--- a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
+++ b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
@@ -1,6 +1,9 @@
 use std::iter::once;
 
-use ide_db::syntax_helpers::node_ext::{is_pattern_cond, single_let};
+use ide_db::{
+    syntax_helpers::node_ext::{is_pattern_cond, single_let},
+    ty_filter::TryEnum,
+};
 use syntax::{
     ast::{
         self,
@@ -41,7 +44,20 @@ use crate::{
 // }
 // ```
 pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
-    let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
+    if let Some(let_stmt) = ctx.find_node_at_offset() {
+        let_stmt_to_guarded_return(let_stmt, acc, ctx)
+    } else if let Some(if_expr) = ctx.find_node_at_offset() {
+        if_expr_to_guarded_return(if_expr, acc, ctx)
+    } else {
+        None
+    }
+}
+
+fn if_expr_to_guarded_return(
+    if_expr: ast::IfExpr,
+    acc: &mut Assists,
+    _ctx: &AssistContext<'_>,
+) -> Option<()> {
     if if_expr.else_branch().is_some() {
         return None;
     }
@@ -148,6 +164,56 @@ pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'
     )
 }
 
+fn let_stmt_to_guarded_return(
+    let_stmt: ast::LetStmt,
+    acc: &mut Assists,
+    ctx: &AssistContext<'_>,
+) -> Option<()> {
+    let pat = let_stmt.pat()?;
+    let expr = let_stmt.initializer()?;
+
+    let try_enum =
+        ctx.sema.type_of_expr(&expr).and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted()))?;
+
+    let happy_pattern = try_enum.happy_pattern(pat);
+    let target = let_stmt.syntax().text_range();
+
+    let early_expression: ast::Expr = {
+        let parent_block =
+            let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
+        let parent_container = parent_block.syntax().parent()?;
+
+        match parent_container.kind() {
+            WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
+            FN => make::expr_return(None),
+            _ => return None,
+        }
+    };
+
+    acc.add(
+        AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite),
+        "Convert to guarded return",
+        target,
+        |edit| {
+            let let_stmt = edit.make_mut(let_stmt);
+            let let_indent_level = IndentLevel::from_node(let_stmt.syntax());
+
+            let replacement = {
+                let let_else_stmt = make::let_else_stmt(
+                    happy_pattern,
+                    let_stmt.ty(),
+                    expr,
+                    ast::make::tail_only_block_expr(early_expression),
+                );
+                let let_else_stmt = let_else_stmt.indent(let_indent_level);
+                let_else_stmt.syntax().clone_for_update()
+            };
+
+            ted::replace(let_stmt.syntax(), replacement)
+        },
+    )
+}
+
 #[cfg(test)]
 mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
@@ -451,6 +517,62 @@ fn main() {
     }
 
     #[test]
+    fn convert_let_stmt_inside_fn() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+//- minicore: option
+fn foo() -> Option<i32> {
+    None
+}
+
+fn main() {
+    let x$0 = foo();
+}
+"#,
+            r#"
+fn foo() -> Option<i32> {
+    None
+}
+
+fn main() {
+    let Some(x) = foo() else { return };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn convert_let_stmt_inside_loop() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+//- minicore: option
+fn foo() -> Option<i32> {
+    None
+}
+
+fn main() {
+    loop {
+        let x$0 = foo();
+    }
+}
+"#,
+            r#"
+fn foo() -> Option<i32> {
+    None
+}
+
+fn main() {
+    loop {
+        let Some(x) = foo() else { continue };
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
     fn convert_arbitrary_if_let_patterns() {
         check_assist(
             convert_to_guarded_return,