about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs54
1 files changed, 49 insertions, 5 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
index aee9ce7878b..82213ae3217 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
@@ -1,11 +1,11 @@
 use std::iter::once;
 
-use hir::Semantics;
 use either::Either;
+use hir::{Semantics, TypeInfo};
 use ide_db::{RootDatabase, ty_filter::TryEnum};
 use syntax::{
     AstNode,
-    SyntaxKind::{FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
+    SyntaxKind::{CLOSURE_EXPR, FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
     SyntaxNode, T,
     ast::{
         self,
@@ -228,16 +228,26 @@ fn early_expression(
     parent_container: SyntaxNode,
     sema: &Semantics<'_, RootDatabase>,
 ) -> Option<ast::Expr> {
+    let return_none_expr = || {
+        let none_expr = make::expr_path(make::ext::ident_path("None"));
+        make::expr_return(Some(none_expr))
+    };
     if let Some(fn_) = ast::Fn::cast(parent_container.clone())
         && let Some(fn_def) = sema.to_def(&fn_)
         && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db))
     {
-        let none_expr = make::expr_path(make::ext::ident_path("None"));
-        return Some(make::expr_return(Some(none_expr)));
+        return Some(return_none_expr());
+    }
+    if let Some(body) = ast::ClosureExpr::cast(parent_container.clone()).and_then(|it| it.body())
+        && let Some(ret_ty) = sema.type_of_expr(&body).map(TypeInfo::original)
+        && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &ret_ty)
+    {
+        return Some(return_none_expr());
     }
+
     Some(match parent_container.kind() {
         WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
-        FN => make::expr_return(None),
+        FN | CLOSURE_EXPR => make::expr_return(None),
         _ => return None,
     })
 }
@@ -330,6 +340,40 @@ fn ret_option() -> Option<()> {
     }
 
     #[test]
+    fn convert_inside_closure() {
+        check_assist(
+            convert_to_guarded_return,
+            r#"
+fn main() {
+    let _f = || {
+        bar();
+        if$0 true {
+            foo();
+
+            // comment
+            bar();
+        }
+    }
+}
+"#,
+            r#"
+fn main() {
+    let _f = || {
+        bar();
+        if false {
+            return;
+        }
+        foo();
+
+        // comment
+        bar();
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
     fn convert_let_inside_fn() {
         check_assist(
             convert_to_guarded_return,