about summary refs log tree commit diff
diff options
context:
space:
mode:
authorunexge <unexge@gmail.com>2023-04-26 22:07:06 +0100
committerunexge <unexge@gmail.com>2023-04-26 22:07:06 +0100
commit734fe66f71a237f9bfdf3b8d235fceda8a114c64 (patch)
tree0080da29111323087227cfe695041491295cce4d
parent797c2f1dde0905afa24f567160ed23ba2bc79a81 (diff)
downloadrust-734fe66f71a237f9bfdf3b8d235fceda8a114c64.tar.gz
rust-734fe66f71a237f9bfdf3b8d235fceda8a114c64.zip
Handle nested types in `unwrap_result_return_type` assist
-rw-r--r--crates/ide-assists/src/handlers/unwrap_result_return_type.rs122
1 files changed, 99 insertions, 23 deletions
diff --git a/crates/ide-assists/src/handlers/unwrap_result_return_type.rs b/crates/ide-assists/src/handlers/unwrap_result_return_type.rs
index 9ef4ae047ef..8b6c614219d 100644
--- a/crates/ide-assists/src/handlers/unwrap_result_return_type.rs
+++ b/crates/ide-assists/src/handlers/unwrap_result_return_type.rs
@@ -5,7 +5,7 @@ use ide_db::{
 use itertools::Itertools;
 use syntax::{
     ast::{self, Expr},
-    match_ast, AstNode, TextRange, TextSize,
+    match_ast, AstNode, NodeOrToken, SyntaxKind, TextRange, TextSize,
 };
 
 use crate::{AssistContext, AssistId, AssistKind, Assists};
@@ -38,14 +38,15 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
     };
 
     let type_ref = &ret_type.ty()?;
-    let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
+    let Some(hir::Adt::Enum(ret_enum)) = ctx.sema.resolve_type(type_ref)?.as_adt() else { return None; };
     let result_enum =
         FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?;
-
-    if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
+    if ret_enum != result_enum {
         return None;
     }
 
+    let Some(ok_type) = unwrap_result_type(type_ref) else { return None; };
+
     acc.add(
         AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite),
         "Unwrap Result return type",
@@ -64,26 +65,22 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
             });
             for_each_tail_expr(&body, tail_cb);
 
-            let mut is_unit_type = false;
-            if let Some((_, inner_type)) = type_ref.to_string().split_once('<') {
-                let inner_type = match inner_type.split_once(',') {
-                    Some((success_inner_type, _)) => success_inner_type,
-                    None => inner_type,
-                };
-                let new_ret_type = inner_type.strip_suffix('>').unwrap_or(inner_type);
-                if new_ret_type == "()" {
-                    is_unit_type = true;
-                    let text_range = TextRange::new(
-                        ret_type.syntax().text_range().start(),
-                        ret_type.syntax().text_range().end() + TextSize::from(1u32),
-                    );
-                    builder.delete(text_range)
-                } else {
-                    builder.replace(
-                        type_ref.syntax().text_range(),
-                        inner_type.strip_suffix('>').unwrap_or(inner_type),
-                    )
+            let is_unit_type = is_unit_type(&ok_type);
+            if is_unit_type {
+                let mut text_range = ret_type.syntax().text_range();
+
+                if let Some(NodeOrToken::Token(token)) = ret_type.syntax().next_sibling_or_token() {
+                    if token.kind() == SyntaxKind::WHITESPACE {
+                        text_range = TextRange::new(
+                            text_range.start(),
+                            text_range.end() + TextSize::from(1u32),
+                        );
+                    }
                 }
+
+                builder.delete(text_range);
+            } else {
+                builder.replace(type_ref.syntax().text_range(), ok_type.syntax().text());
             }
 
             for ret_expr_arg in exprs_to_unwrap {
@@ -134,6 +131,22 @@ fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
     }
 }
 
+// Tries to extract `T` from `Result<T, E>`.
+fn unwrap_result_type(ty: &ast::Type) -> Option<ast::Type> {
+    let ast::Type::PathType(path_ty) = ty else { return None; };
+    let Some(path) = path_ty.path() else { return None; };
+    let Some(segment) = path.first_segment() else { return None; };
+    let Some(generic_arg_list) = segment.generic_arg_list() else { return None; };
+    let generic_args: Vec<_> = generic_arg_list.generic_args().collect();
+    let Some(ast::GenericArg::TypeArg(ok_type)) = generic_args.first() else { return None; };
+    ok_type.ty()
+}
+
+fn is_unit_type(ty: &ast::Type) -> bool {
+    let ast::Type::TupleType(tuple) = ty else { return false };
+    tuple.fields().next().is_none()
+}
+
 #[cfg(test)]
 mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
@@ -175,6 +188,21 @@ fn foo() {
 }
 "#,
         );
+
+        // Unformatted return type
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result
+fn foo() -> Result<(), Box<dyn Error$0>>{
+    Ok(())
+}
+"#,
+            r#"
+fn foo() {
+}
+"#,
+        );
     }
 
     #[test]
@@ -1017,4 +1045,52 @@ fn foo(the_field: u32) -> u32 {
 "#,
         );
     }
+
+    #[test]
+    fn unwrap_result_return_type_nested_type() {
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result, option
+fn foo() -> Result<Option<i32$0>, ()> {
+    Ok(Some(42))
+}
+"#,
+            r#"
+fn foo() -> Option<i32> {
+    Some(42)
+}
+"#,
+        );
+
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result, option
+fn foo() -> Result<Option<Result<i32$0, ()>>, ()> {
+    Ok(None)
+}
+"#,
+            r#"
+fn foo() -> Option<Result<i32, ()>> {
+    None
+}
+"#,
+        );
+
+        check_assist(
+            unwrap_result_return_type,
+            r#"
+//- minicore: result, option, iterators
+fn foo() -> Result<impl Iterator<Item = i32>$0, ()> {
+    Ok(Some(42).into_iter())
+}
+"#,
+            r#"
+fn foo() -> impl Iterator<Item = i32> {
+    Some(42).into_iter()
+}
+"#,
+        );
+    }
 }