about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/type_mismatch.rs437
1 files changed, 375 insertions, 62 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/type_mismatch.rs
index 93fe9374a3e..8994ab50e67 100644
--- a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/type_mismatch.rs
+++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/type_mismatch.rs
@@ -1,14 +1,17 @@
 use either::Either;
-use hir::{db::ExpandDatabase, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type};
-use ide_db::text_edit::TextEdit;
-use ide_db::{famous_defs::FamousDefs, source_change::SourceChange};
+use hir::{db::ExpandDatabase, CallableKind, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type};
+use ide_db::{
+    famous_defs::FamousDefs,
+    source_change::{SourceChange, SourceChangeBuilder},
+    text_edit::TextEdit,
+};
 use syntax::{
     ast::{
         self,
         edit::{AstNodeEdit, IndentLevel},
-        BlockExpr, Expr, ExprStmt,
+        make, BlockExpr, Expr, ExprStmt, HasArgList,
     },
-    AstNode, AstPtr, TextSize,
+    ted, AstNode, AstPtr, TextSize,
 };
 
 use crate::{adjusted_display_range, fix, Assist, Diagnostic, DiagnosticCode, DiagnosticsContext};
@@ -63,6 +66,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option<Vec<Assi
         let expr_ptr = &InFile { file_id: d.expr_or_pat.file_id, value: expr_ptr };
         add_reference(ctx, d, expr_ptr, &mut fixes);
         add_missing_ok_or_some(ctx, d, expr_ptr, &mut fixes);
+        remove_unnecessary_wrapper(ctx, d, expr_ptr, &mut fixes);
         remove_semicolon(ctx, d, expr_ptr, &mut fixes);
         str_ref_to_owned(ctx, d, expr_ptr, &mut fixes);
     }
@@ -184,6 +188,90 @@ fn add_missing_ok_or_some(
     Some(())
 }
 
+fn remove_unnecessary_wrapper(
+    ctx: &DiagnosticsContext<'_>,
+    d: &hir::TypeMismatch,
+    expr_ptr: &InFile<AstPtr<ast::Expr>>,
+    acc: &mut Vec<Assist>,
+) -> Option<()> {
+    let db = ctx.sema.db;
+    let root = db.parse_or_expand(expr_ptr.file_id);
+    let expr = expr_ptr.value.to_node(&root);
+    let expr = ctx.sema.original_ast_node(expr.clone())?;
+
+    let Expr::CallExpr(call_expr) = expr else {
+        return None;
+    };
+
+    let callable = ctx.sema.resolve_expr_as_callable(&call_expr.expr()?)?;
+    let CallableKind::TupleEnumVariant(variant) = callable.kind() else {
+        return None;
+    };
+
+    let actual_enum = d.actual.as_adt()?.as_enum()?;
+    let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(call_expr.syntax())?.krate());
+    let core_option = famous_defs.core_option_Option();
+    let core_result = famous_defs.core_result_Result();
+    if Some(actual_enum) != core_option && Some(actual_enum) != core_result {
+        return None;
+    }
+
+    let inner_type = variant.fields(db).first()?.ty_with_args(db, d.actual.type_arguments());
+    if !d.expected.could_unify_with(db, &inner_type) {
+        return None;
+    }
+
+    let inner_arg = call_expr.arg_list()?.args().next()?;
+
+    let mut builder = SourceChangeBuilder::new(expr_ptr.file_id.original_file(ctx.sema.db));
+
+    match inner_arg {
+        // We're returning `()`
+        Expr::TupleExpr(tup) if tup.fields().next().is_none() => {
+            let parent = call_expr
+                .syntax()
+                .parent()
+                .and_then(Either::<ast::ReturnExpr, ast::StmtList>::cast)?;
+
+            match parent {
+                Either::Left(ret_expr) => {
+                    let old = builder.make_mut(ret_expr);
+                    let new = make::expr_return(None).clone_for_update();
+
+                    ted::replace(old.syntax(), new.syntax());
+                }
+                Either::Right(stmt_list) => {
+                    if stmt_list.statements().count() == 0 {
+                        let block = stmt_list.syntax().parent().and_then(ast::BlockExpr::cast)?;
+                        let old = builder.make_mut(block);
+                        let new = make::expr_empty_block().clone_for_update();
+
+                        ted::replace(old.syntax(), new.syntax());
+                    } else {
+                        let old = builder.make_syntax_mut(stmt_list.syntax().parent()?);
+                        let new = make::block_expr(stmt_list.statements(), None).clone_for_update();
+
+                        ted::replace(old, new.syntax());
+                    }
+                }
+            }
+        }
+        _ => {
+            let call_mut = builder.make_mut(call_expr.clone());
+            ted::replace(call_mut.syntax(), inner_arg.clone_for_update().syntax());
+        }
+    }
+
+    let name = format!("Remove unnecessary {}() wrapper", variant.name(db).as_str());
+    acc.push(fix(
+        "remove_unnecessary_wrapper",
+        &name,
+        builder.finish(),
+        call_expr.syntax().text_range(),
+    ));
+    Some(())
+}
+
 fn remove_semicolon(
     ctx: &DiagnosticsContext<'_>,
     d: &hir::TypeMismatch,
@@ -243,7 +331,7 @@ fn str_ref_to_owned(
 #[cfg(test)]
 mod tests {
     use crate::tests::{
-        check_diagnostics, check_diagnostics_with_disabled, check_fix, check_no_fix,
+        check_diagnostics, check_diagnostics_with_disabled, check_fix, check_has_fix, check_no_fix,
     };
 
     #[test]
@@ -260,7 +348,7 @@ fn test(_arg: &i32) {}
     }
 
     #[test]
-    fn test_add_reference_to_int() {
+    fn add_reference_to_int() {
         check_fix(
             r#"
 fn main() {
@@ -278,7 +366,7 @@ fn test(_arg: &i32) {}
     }
 
     #[test]
-    fn test_add_mutable_reference_to_int() {
+    fn add_mutable_reference_to_int() {
         check_fix(
             r#"
 fn main() {
@@ -296,7 +384,7 @@ fn test(_arg: &mut i32) {}
     }
 
     #[test]
-    fn test_add_reference_to_array() {
+    fn add_reference_to_array() {
         check_fix(
             r#"
 //- minicore: coerce_unsized
@@ -315,7 +403,7 @@ fn test(_arg: &[i32]) {}
     }
 
     #[test]
-    fn test_add_reference_with_autoderef() {
+    fn add_reference_with_autoderef() {
         check_fix(
             r#"
 //- minicore: coerce_unsized, deref
@@ -348,7 +436,7 @@ fn test(_arg: &Bar) {}
     }
 
     #[test]
-    fn test_add_reference_to_method_call() {
+    fn add_reference_to_method_call() {
         check_fix(
             r#"
 fn main() {
@@ -372,7 +460,7 @@ impl Test {
     }
 
     #[test]
-    fn test_add_reference_to_let_stmt() {
+    fn add_reference_to_let_stmt() {
         check_fix(
             r#"
 fn main() {
@@ -388,7 +476,7 @@ fn main() {
     }
 
     #[test]
-    fn test_add_reference_to_macro_call() {
+    fn add_reference_to_macro_call() {
         check_fix(
             r#"
 macro_rules! thousand {
@@ -416,7 +504,7 @@ fn main() {
     }
 
     #[test]
-    fn test_add_mutable_reference_to_let_stmt() {
+    fn add_mutable_reference_to_let_stmt() {
         check_fix(
             r#"
 fn main() {
@@ -432,29 +520,6 @@ fn main() {
     }
 
     #[test]
-    fn test_wrap_return_type_option() {
-        check_fix(
-            r#"
-//- minicore: option, result
-fn div(x: i32, y: i32) -> Option<i32> {
-    if y == 0 {
-        return None;
-    }
-    x / y$0
-}
-"#,
-            r#"
-fn div(x: i32, y: i32) -> Option<i32> {
-    if y == 0 {
-        return None;
-    }
-    Some(x / y)
-}
-"#,
-        );
-    }
-
-    #[test]
     fn const_generic_type_mismatch() {
         check_diagnostics(
             r#"
@@ -487,59 +552,82 @@ fn div(x: i32, y: i32) -> Option<i32> {
     }
 
     #[test]
-    fn test_wrap_return_type_option_tails() {
+    fn wrap_return_type() {
+        check_fix(
+            r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> Result<i32, ()> {
+    if y == 0 {
+        return Err(());
+    }
+    x / y$0
+}
+"#,
+            r#"
+fn div(x: i32, y: i32) -> Result<i32, ()> {
+    if y == 0 {
+        return Err(());
+    }
+    Ok(x / y)
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn wrap_return_type_option() {
         check_fix(
             r#"
 //- minicore: option, result
 fn div(x: i32, y: i32) -> Option<i32> {
     if y == 0 {
-        Some(0)
-    } else if true {
-        100$0
-    } else {
-        None
+        return None;
     }
+    x / y$0
 }
 "#,
             r#"
 fn div(x: i32, y: i32) -> Option<i32> {
     if y == 0 {
-        Some(0)
-    } else if true {
-        Some(100)
-    } else {
-        None
+        return None;
     }
+    Some(x / y)
 }
 "#,
         );
     }
 
     #[test]
-    fn test_wrap_return_type() {
+    fn wrap_return_type_option_tails() {
         check_fix(
             r#"
 //- minicore: option, result
-fn div(x: i32, y: i32) -> Result<i32, ()> {
+fn div(x: i32, y: i32) -> Option<i32> {
     if y == 0 {
-        return Err(());
+        Some(0)
+    } else if true {
+        100$0
+    } else {
+        None
     }
-    x / y$0
 }
 "#,
             r#"
-fn div(x: i32, y: i32) -> Result<i32, ()> {
+fn div(x: i32, y: i32) -> Option<i32> {
     if y == 0 {
-        return Err(());
+        Some(0)
+    } else if true {
+        Some(100)
+    } else {
+        None
     }
-    Ok(x / y)
 }
 "#,
         );
     }
 
     #[test]
-    fn test_wrap_return_type_handles_generic_functions() {
+    fn wrap_return_type_handles_generic_functions() {
         check_fix(
             r#"
 //- minicore: option, result
@@ -562,7 +650,7 @@ fn div<T>(x: T) -> Result<T, i32> {
     }
 
     #[test]
-    fn test_wrap_return_type_handles_type_aliases() {
+    fn wrap_return_type_handles_type_aliases() {
         check_fix(
             r#"
 //- minicore: option, result
@@ -589,7 +677,7 @@ fn div(x: i32, y: i32) -> MyResult<i32> {
     }
 
     #[test]
-    fn test_wrapped_unit_as_block_tail_expr() {
+    fn wrapped_unit_as_block_tail_expr() {
         check_fix(
             r#"
 //- minicore: result
@@ -619,7 +707,7 @@ fn foo() -> Result<(), ()> {
     }
 
     #[test]
-    fn test_wrapped_unit_as_return_expr() {
+    fn wrapped_unit_as_return_expr() {
         check_fix(
             r#"
 //- minicore: result
@@ -642,7 +730,7 @@ fn foo(b: bool) -> Result<(), String> {
     }
 
     #[test]
-    fn test_in_const_and_static() {
+    fn wrap_in_const_and_static() {
         check_fix(
             r#"
 //- minicore: option, result
@@ -664,7 +752,7 @@ const _: Option<()> = {Some(())};
     }
 
     #[test]
-    fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() {
+    fn wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() {
         check_no_fix(
             r#"
 //- minicore: option, result
@@ -674,7 +762,7 @@ fn foo() -> Result<(), i32> { 0$0 }
     }
 
     #[test]
-    fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() {
+    fn wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() {
         check_no_fix(
             r#"
 //- minicore: option, result
@@ -686,6 +774,231 @@ fn foo() -> SomeOtherEnum { 0$0 }
     }
 
     #[test]
+    fn unwrap_return_type() {
+        check_fix(
+            r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> i32 {
+    if y == 0 {
+        panic!();
+    }
+    Ok(x / y)$0
+}
+"#,
+            r#"
+fn div(x: i32, y: i32) -> i32 {
+    if y == 0 {
+        panic!();
+    }
+    x / y
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_return_type_option() {
+        check_fix(
+            r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> i32 {
+    if y == 0 {
+        panic!();
+    }
+    Some(x / y)$0
+}
+"#,
+            r#"
+fn div(x: i32, y: i32) -> i32 {
+    if y == 0 {
+        panic!();
+    }
+    x / y
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_return_type_option_tails() {
+        check_fix(
+            r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> i32 {
+    if y == 0 {
+        42
+    } else if true {
+        Some(100)$0
+    } else {
+        0
+    }
+}
+"#,
+            r#"
+fn div(x: i32, y: i32) -> i32 {
+    if y == 0 {
+        42
+    } else if true {
+        100
+    } else {
+        0
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_return_type_handles_generic_functions() {
+        check_fix(
+            r#"
+//- minicore: option, result
+fn div<T>(x: T) -> T {
+    if x == 0 {
+        panic!();
+    }
+    $0Ok(x)
+}
+"#,
+            r#"
+fn div<T>(x: T) -> T {
+    if x == 0 {
+        panic!();
+    }
+    x
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_return_type_handles_type_aliases() {
+        check_fix(
+            r#"
+//- minicore: option, result
+type MyResult<T> = T;
+
+fn div(x: i32, y: i32) -> MyResult<i32> {
+    if y == 0 {
+        panic!();
+    }
+    Ok(x $0/ y)
+}
+"#,
+            r#"
+type MyResult<T> = T;
+
+fn div(x: i32, y: i32) -> MyResult<i32> {
+    if y == 0 {
+        panic!();
+    }
+    x / y
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_tail_expr() {
+        check_fix(
+            r#"
+//- minicore: result
+fn foo() -> () {
+    println!("Hello, world!");
+    Ok(())$0
+}
+            "#,
+            r#"
+fn foo() -> () {
+    println!("Hello, world!");
+}
+            "#,
+        );
+    }
+
+    #[test]
+    fn unwrap_to_empty_block() {
+        check_fix(
+            r#"
+//- minicore: result
+fn foo() -> () {
+    Ok(())$0
+}
+            "#,
+            r#"
+fn foo() -> () {}
+            "#,
+        );
+    }
+
+    #[test]
+    fn unwrap_to_return_expr() {
+        check_has_fix(
+            r#"
+//- minicore: result
+fn foo(b: bool) -> () {
+    if b {
+        return $0Ok(());
+    }
+
+    panic!("oh dear");
+}"#,
+            r#"
+fn foo(b: bool) -> () {
+    if b {
+        return;
+    }
+
+    panic!("oh dear");
+}"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_in_const_and_static() {
+        check_fix(
+            r#"
+//- minicore: option, result
+static A: () = {Some(($0))};
+            "#,
+            r#"
+static A: () = {};
+            "#,
+        );
+        check_fix(
+            r#"
+//- minicore: option, result
+const _: () = {Some(($0))};
+            "#,
+            r#"
+const _: () = {};
+            "#,
+        );
+    }
+
+    #[test]
+    fn unwrap_return_type_not_applicable_when_inner_type_does_not_match_return_type() {
+        check_no_fix(
+            r#"
+//- minicore: result
+fn foo() -> i32 { $0Ok(()) }
+"#,
+        );
+    }
+
+    #[test]
+    fn unwrap_return_type_not_applicable_when_wrapper_type_is_not_result_or_option() {
+        check_no_fix(
+            r#"
+//- minicore: option, result
+enum SomeOtherEnum { Ok(i32), Err(String) }
+
+fn foo() -> i32 { SomeOtherEnum::Ok($042) }
+"#,
+        );
+    }
+
+    #[test]
     fn remove_semicolon() {
         check_fix(r#"fn f() -> i32 { 92$0; }"#, r#"fn f() -> i32 { 92 }"#);
     }