about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors[bot] <26634292+bors[bot]@users.noreply.github.com>2021-10-17 13:32:35 +0000
committerGitHub <noreply@github.com>2021-10-17 13:32:35 +0000
commit401daa5f77fd9cfb79d16fe3a54bc576d60b4c82 (patch)
treeb755984633bca194a1ecd8b5f00e12471dbf17f7
parentd9080addf95679f5b7a9e315f9501163d83ce1be (diff)
parentccf05debfe0b3c3f098ce88daed6aa6651fcc0e1 (diff)
downloadrust-401daa5f77fd9cfb79d16fe3a54bc576d60b4c82.tar.gz
rust-401daa5f77fd9cfb79d16fe3a54bc576d60b4c82.zip
Merge #10417
10417: feat(assist): add new assist to unwrap the result return type r=bnjjj a=bnjjj

do the opposite of assist "wrap the return type in Result"

Co-authored-by: Benjamin Coenen <5719034+bnjjj@users.noreply.github.com>
Co-authored-by: Coenen Benjamin <benjamin.coenen@hotmail.com>
-rw-r--r--crates/ide_assists/src/handlers/unwrap_result_return_type.rs943
-rw-r--r--crates/ide_assists/src/lib.rs2
-rw-r--r--crates/ide_assists/src/tests/generated.rs14
3 files changed, 959 insertions, 0 deletions
diff --git a/crates/ide_assists/src/handlers/unwrap_result_return_type.rs b/crates/ide_assists/src/handlers/unwrap_result_return_type.rs
new file mode 100644
index 00000000000..6d813aa6722
--- /dev/null
+++ b/crates/ide_assists/src/handlers/unwrap_result_return_type.rs
@@ -0,0 +1,943 @@
+use ide_db::helpers::{for_each_tail_expr, node_ext::walk_expr, FamousDefs};
+use syntax::{
+    ast::{self, Expr},
+    match_ast, AstNode,
+};
+
+use crate::{AssistContext, AssistId, AssistKind, Assists};
+
+// Assist: unwrap_result_return_type
+//
+// Unwrap the function's return type.
+//
+// ```
+// # //- minicore: result
+// fn foo() -> Result<i32>$0 { Ok(42i32) }
+// ```
+// ->
+// ```
+// fn foo() -> i32 { 42i32 }
+// ```
+pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
+    let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
+    let parent = ret_type.syntax().parent()?;
+    let body = match_ast! {
+        match parent {
+            ast::Fn(func) => func.body()?,
+            ast::ClosureExpr(closure) => match closure.body()? {
+                Expr::BlockExpr(block) => block,
+                // closures require a block when a return type is specified
+                _ => return None,
+            },
+            _ => return None,
+        }
+    };
+
+    let type_ref = &ret_type.ty()?;
+    let ty = ctx.sema.resolve_type(type_ref).and_then(|ty| ty.as_adt());
+    let result_enum =
+        FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax()).krate()).core_result_Result()?;
+
+    if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
+        return None;
+    }
+
+    acc.add(
+        AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite),
+        "Unwrap Result return type",
+        type_ref.syntax().text_range(),
+        |builder| {
+            let body = ast::Expr::BlockExpr(body);
+
+            let mut exprs_to_unwrap = Vec::new();
+            let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_unwrap, e);
+            walk_expr(&body, &mut |expr| {
+                if let Expr::ReturnExpr(ret_expr) = expr {
+                    if let Some(ret_expr_arg) = &ret_expr.expr() {
+                        for_each_tail_expr(ret_expr_arg, tail_cb);
+                    }
+                }
+            });
+            for_each_tail_expr(&body, tail_cb);
+
+            for ret_expr_arg in exprs_to_unwrap {
+                let new_ret_expr = ret_expr_arg.to_string();
+                let new_ret_expr =
+                    new_ret_expr.trim_start_matches("Ok(").trim_start_matches("Err(");
+                builder.replace(
+                    ret_expr_arg.syntax().text_range(),
+                    new_ret_expr.strip_suffix(')').unwrap_or(new_ret_expr),
+                )
+            }
+
+            if let Some((_, inner_type)) = type_ref.to_string().split_once('<') {
+                let inner_type = match inner_type.split_once(',') {
+                    Some((success_inner_type, _)) => success_inner_type,
+                    None => inner_type,
+                };
+                builder.replace(
+                    type_ref.syntax().text_range(),
+                    inner_type.strip_suffix('>').unwrap_or(inner_type),
+                )
+            }
+        },
+    )
+}
+
+fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
+    match e {
+        Expr::BreakExpr(break_expr) => {
+            if let Some(break_expr_arg) = break_expr.expr() {
+                for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e))
+            }
+        }
+        Expr::ReturnExpr(ret_expr) => {
+            if let Some(ret_expr_arg) = &ret_expr.expr() {
+                for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e));
+            }
+        }
+        e => acc.push(e.clone()),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::tests::{check_assist, check_assist_not_applicable};
+
+    use super::*;
+
+    #[test]
+    fn unwrap_result_return_type_simple() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i3$02> {
+    let test = "test";
+    return Ok(42i32);
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    let test = "test";
+    return 42i32;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_return_type_break_split_tail() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i3$02, String> {
+    loop {
+        break if true {
+            Ok(1)
+        } else {
+            Ok(0)
+        };
+    }
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    loop {
+        break if true {
+            1
+        } else {
+            0
+        };
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_closure() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() {
+    || -> Result<i32$0> {
+        let test = "test";
+        return Ok(42i32);
+    };
+}
+"#,
+            r#"
+fn foo() {
+    || -> i32 {
+        let test = "test";
+        return 42i32;
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_return_type_bad_cursor() {
+        check_assist_not_applicable(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> i32 {
+    let test = "test";$0
+    return 42i32;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_return_type_bad_cursor_closure() {
+        check_assist_not_applicable(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() {
+    || -> i32 {
+        let test = "test";$0
+        return 42i32;
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_closure_non_block() {
+        check_assist_not_applicable(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() { || -> i$032 3; }
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_return_type_already_not_result_std() {
+        check_assist_not_applicable(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> i32$0 {
+    let test = "test";
+    return 42i32;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_return_type_already_not_result_closure() {
+        check_assist_not_applicable(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() {
+    || -> i32$0 {
+        let test = "test";
+        return 42i32;
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_tail() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() ->$0 Result<i32> {
+    let test = "test";
+    Ok(42i32)
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    let test = "test";
+    42i32
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_tail_closure() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() {
+    || ->$0 Result<i32, String> {
+        let test = "test";
+        Ok(42i32)
+    };
+}
+"#,
+            r#"
+fn foo() {
+    || -> i32 {
+        let test = "test";
+        42i32
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_tail_only() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32$0> { Ok(42i32) }
+"#,
+            r#"
+fn foo() -> i32 { 42i32 }
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_tail_block_like() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32>$0 {
+    if true {
+        Ok(42i32)
+    } else {
+        Ok(24i32)
+    }
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    if true {
+        42i32
+    } else {
+        24i32
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_without_block_closure() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() {
+    || -> Result<i32, String>$0 {
+        if true {
+            Ok(42i32)
+        } else {
+            Ok(24i32)
+        }
+    };
+}
+"#,
+            r#"
+fn foo() {
+    || -> i32 {
+        if true {
+            42i32
+        } else {
+            24i32
+        }
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_nested_if() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32>$0 {
+    if true {
+        if false {
+            Ok(1)
+        } else {
+            Ok(2)
+        }
+    } else {
+        Ok(24i32)
+    }
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    if true {
+        if false {
+            1
+        } else {
+            2
+        }
+    } else {
+        24i32
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_await() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+async fn foo() -> Result<i$032> {
+    if true {
+        if false {
+            Ok(1.await)
+        } else {
+            Ok(2.await)
+        }
+    } else {
+        Ok(24i32.await)
+    }
+}
+"#,
+            r#"
+async fn foo() -> i32 {
+    if true {
+        if false {
+            1.await
+        } else {
+            2.await
+        }
+    } else {
+        24i32.await
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_array() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<[i32; 3]$0> { Ok([1, 2, 3]) }
+"#,
+            r#"
+fn foo() -> [i32; 3] { [1, 2, 3] }
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_cast() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -$0> Result<i32> {
+    if true {
+        if false {
+            Ok(1 as i32)
+        } else {
+            Ok(2 as i32)
+        }
+    } else {
+        Ok(24 as i32)
+    }
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    if true {
+        if false {
+            1 as i32
+        } else {
+            2 as i32
+        }
+    } else {
+        24 as i32
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_tail_block_like_match() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32$0> {
+    let my_var = 5;
+    match my_var {
+        5 => Ok(42i32),
+        _ => Ok(24i32),
+    }
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    let my_var = 5;
+    match my_var {
+        5 => 42i32,
+        _ => 24i32,
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_loop_with_tail() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32$0> {
+    let my_var = 5;
+    loop {
+        println!("test");
+        5
+    }
+    Ok(my_var)
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    let my_var = 5;
+    loop {
+        println!("test");
+        5
+    }
+    my_var
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_loop_in_let_stmt() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32$0> {
+    let my_var = let x = loop {
+        break 1;
+    };
+    Ok(my_var)
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    let my_var = let x = loop {
+        break 1;
+    };
+    my_var
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_tail_block_like_match_return_expr() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32>$0 {
+    let my_var = 5;
+    let res = match my_var {
+        5 => 42i32,
+        _ => return Ok(24i32),
+    };
+    Ok(res)
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    let my_var = 5;
+    let res = match my_var {
+        5 => 42i32,
+        _ => return 24i32,
+    };
+    res
+}
+"#,
+        );
+
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32$0> {
+    let my_var = 5;
+    let res = if my_var == 5 {
+        42i32
+    } else {
+        return Ok(24i32);
+    };
+    Ok(res)
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    let my_var = 5;
+    let res = if my_var == 5 {
+        42i32
+    } else {
+        return 24i32;
+    };
+    res
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_tail_block_like_match_deeper() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32$0> {
+    let my_var = 5;
+    match my_var {
+        5 => {
+            if true {
+                Ok(42i32)
+            } else {
+                Ok(25i32)
+            }
+        },
+        _ => {
+            let test = "test";
+            if test == "test" {
+                return Ok(bar());
+            }
+            Ok(53i32)
+        },
+    }
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    let my_var = 5;
+    match my_var {
+        5 => {
+            if true {
+                42i32
+            } else {
+                25i32
+            }
+        },
+        _ => {
+            let test = "test";
+            if test == "test" {
+                return bar();
+            }
+            53i32
+        },
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_tail_block_like_early_return() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32$0> {
+    let test = "test";
+    if test == "test" {
+        return Ok(24i32);
+    }
+    Ok(53i32)
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    let test = "test";
+    if test == "test" {
+        return 24i32;
+    }
+    53i32
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_closure() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo(the_field: u32) -> Result<u32$0> {
+    let true_closure = || { return true; };
+    if the_field < 5 {
+        let mut i = 0;
+        if true_closure() {
+            return Ok(99);
+        } else {
+            return Ok(0);
+        }
+    }
+    Ok(the_field)
+}
+"#,
+            r#"
+fn foo(the_field: u32) -> u32 {
+    let true_closure = || { return true; };
+    if the_field < 5 {
+        let mut i = 0;
+        if true_closure() {
+            return 99;
+        } else {
+            return 0;
+        }
+    }
+    the_field
+}
+"#,
+        );
+
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo(the_field: u32) -> Result<u32$0> {
+    let true_closure = || {
+        return true;
+    };
+    if the_field < 5 {
+        let mut i = 0;
+
+
+        if true_closure() {
+            return Ok(99);
+        } else {
+            return Ok(0);
+        }
+    }
+    let t = None;
+
+    Ok(t.unwrap_or_else(|| the_field))
+}
+"#,
+            r#"
+fn foo(the_field: u32) -> u32 {
+    let true_closure = || {
+        return true;
+    };
+    if the_field < 5 {
+        let mut i = 0;
+
+
+        if true_closure() {
+            return 99;
+        } else {
+            return 0;
+        }
+    }
+    let t = None;
+
+    t.unwrap_or_else(|| the_field)
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_result_return_type_simple_with_weird_forms() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<i32$0> {
+    let test = "test";
+    if test == "test" {
+        return Ok(24i32);
+    }
+    let mut i = 0;
+    loop {
+        if i == 1 {
+            break Ok(55);
+        }
+        i += 1;
+    }
+}
+"#,
+            r#"
+fn foo() -> i32 {
+    let test = "test";
+    if test == "test" {
+        return 24i32;
+    }
+    let mut i = 0;
+    loop {
+        if i == 1 {
+            break 55;
+        }
+        i += 1;
+    }
+}
+"#,
+        );
+
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo(the_field: u32) -> Result<u32$0> {
+    if the_field < 5 {
+        let mut i = 0;
+        loop {
+            if i > 5 {
+                return Ok(55u32);
+            }
+            i += 3;
+        }
+        match i {
+            5 => return Ok(99),
+            _ => return Ok(0),
+        };
+    }
+    Ok(the_field)
+}
+"#,
+            r#"
+fn foo(the_field: u32) -> u32 {
+    if the_field < 5 {
+        let mut i = 0;
+        loop {
+            if i > 5 {
+                return 55u32;
+            }
+            i += 3;
+        }
+        match i {
+            5 => return 99,
+            _ => return 0,
+        };
+    }
+    the_field
+}
+"#,
+        );
+
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo(the_field: u32) -> Result<u32$0> {
+    if the_field < 5 {
+        let mut i = 0;
+        match i {
+            5 => return Ok(99),
+            _ => return Ok(0),
+        }
+    }
+    Ok(the_field)
+}
+"#,
+            r#"
+fn foo(the_field: u32) -> u32 {
+    if the_field < 5 {
+        let mut i = 0;
+        match i {
+            5 => return 99,
+            _ => return 0,
+        }
+    }
+    the_field
+}
+"#,
+        );
+
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo(the_field: u32) -> Result<u32$0> {
+    if the_field < 5 {
+        let mut i = 0;
+        if i == 5 {
+            return Ok(99)
+        } else {
+            return Ok(0)
+        }
+    }
+    Ok(the_field)
+}
+"#,
+            r#"
+fn foo(the_field: u32) -> u32 {
+    if the_field < 5 {
+        let mut i = 0;
+        if i == 5 {
+            return 99
+        } else {
+            return 0
+        }
+    }
+    the_field
+}
+"#,
+        );
+
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo(the_field: u32) -> Result<u3$02> {
+    if the_field < 5 {
+        let mut i = 0;
+        if i == 5 {
+            return Ok(99);
+        } else {
+            return Ok(0);
+        }
+    }
+    Ok(the_field)
+}
+"#,
+            r#"
+fn foo(the_field: u32) -> u32 {
+    if the_field < 5 {
+        let mut i = 0;
+        if i == 5 {
+            return 99;
+        } else {
+            return 0;
+        }
+    }
+    the_field
+}
+"#,
+        );
+    }
+}
diff --git a/crates/ide_assists/src/lib.rs b/crates/ide_assists/src/lib.rs
index ea2c19b5087..dccd071dccf 100644
--- a/crates/ide_assists/src/lib.rs
+++ b/crates/ide_assists/src/lib.rs
@@ -178,6 +178,7 @@ mod handlers {
     mod toggle_ignore;
     mod unmerge_use;
     mod unwrap_block;
+    mod unwrap_result_return_type;
     mod wrap_return_type_in_result;
 
     pub(crate) fn all() -> &'static [Handler] {
@@ -259,6 +260,7 @@ mod handlers {
             toggle_ignore::toggle_ignore,
             unmerge_use::unmerge_use,
             unwrap_block::unwrap_block,
+            unwrap_result_return_type::unwrap_result_return_type,
             wrap_return_type_in_result::wrap_return_type_in_result,
             // These are manually sorted for better priorities. By default,
             // priority is determined by the size of the target range (smaller
diff --git a/crates/ide_assists/src/tests/generated.rs b/crates/ide_assists/src/tests/generated.rs
index 25acd534824..539f81622b1 100644
--- a/crates/ide_assists/src/tests/generated.rs
+++ b/crates/ide_assists/src/tests/generated.rs
@@ -1965,6 +1965,20 @@ fn foo() {
 }
 
 #[test]
+fn doctest_unwrap_result_return_type() {
+    check_doc_test(
+        "unwrap_result_return_type",
+        r#####"
+//- minicore: result
+fn foo() -> Result<i32>$0 { Ok(42i32) }
+"#####,
+        r#####"
+fn foo() -> i32 { 42i32 }
+"#####,
+    )
+}
+
+#[test]
 fn doctest_wrap_return_type_in_result() {
     check_doc_test(
         "wrap_return_type_in_result",