about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-assists/src/handlers/generate_function.rs102
1 files changed, 84 insertions, 18 deletions
diff --git a/crates/ide-assists/src/handlers/generate_function.rs b/crates/ide-assists/src/handlers/generate_function.rs
index 850be21c300..c579f6780db 100644
--- a/crates/ide-assists/src/handlers/generate_function.rs
+++ b/crates/ide-assists/src/handlers/generate_function.rs
@@ -291,12 +291,9 @@ impl FunctionBuilder {
         let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
         let is_async = await_expr.is_some();
 
-        let (ret_type, should_focus_return_type) = make_return_type(
-            ctx,
-            &ast::Expr::CallExpr(call.clone()),
-            target_module,
-            &mut necessary_generic_params,
-        );
+        let expr_for_ret_ty = await_expr.map_or_else(|| call.clone().into(), |it| it.into());
+        let (ret_type, should_focus_return_type) =
+            make_return_type(ctx, &expr_for_ret_ty, target_module, &mut necessary_generic_params);
 
         let (generic_param_list, where_clause) =
             fn_generic_params(ctx, necessary_generic_params, &target)?;
@@ -338,12 +335,9 @@ impl FunctionBuilder {
         let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
         let is_async = await_expr.is_some();
 
-        let (ret_type, should_focus_return_type) = make_return_type(
-            ctx,
-            &ast::Expr::MethodCallExpr(call.clone()),
-            target_module,
-            &mut necessary_generic_params,
-        );
+        let expr_for_ret_ty = await_expr.map_or_else(|| call.clone().into(), |it| it.into());
+        let (ret_type, should_focus_return_type) =
+            make_return_type(ctx, &expr_for_ret_ty, target_module, &mut necessary_generic_params);
 
         let (generic_param_list, where_clause) =
             fn_generic_params(ctx, necessary_generic_params, &target)?;
@@ -429,12 +423,12 @@ impl FunctionBuilder {
 /// user can change the `todo!` function body.
 fn make_return_type(
     ctx: &AssistContext<'_>,
-    call: &ast::Expr,
+    expr: &ast::Expr,
     target_module: Module,
     necessary_generic_params: &mut FxHashSet<hir::GenericParam>,
 ) -> (Option<ast::RetType>, bool) {
     let (ret_ty, should_focus_return_type) = {
-        match ctx.sema.type_of_expr(call).map(TypeInfo::original) {
+        match ctx.sema.type_of_expr(expr).map(TypeInfo::original) {
             Some(ty) if ty.is_unknown() => (Some(make::ty_placeholder()), true),
             None => (Some(make::ty_placeholder()), true),
             Some(ty) if ty.is_unit() => (None, false),
@@ -2268,13 +2262,13 @@ impl Foo {
         check_assist(
             generate_function,
             r"
-fn foo() {
-    $0bar(42).await();
+async fn foo() {
+    $0bar(42).await;
 }
 ",
             r"
-fn foo() {
-    bar(42).await();
+async fn foo() {
+    bar(42).await;
 }
 
 async fn bar(arg: i32) ${0:-> _} {
@@ -2285,6 +2279,28 @@ async fn bar(arg: i32) ${0:-> _} {
     }
 
     #[test]
+    fn return_type_for_async_fn() {
+        check_assist(
+            generate_function,
+            r"
+//- minicore: result
+async fn foo() {
+    if Err(()) = $0bar(42).await {}
+}
+",
+            r"
+async fn foo() {
+    if Err(()) = bar(42).await {}
+}
+
+async fn bar(arg: i32) -> Result<_, ()> {
+    ${0:todo!()}
+}
+",
+        );
+    }
+
+    #[test]
     fn create_method() {
         check_assist(
             generate_function,
@@ -2402,6 +2418,31 @@ fn foo() {S.bar();}
     }
 
     #[test]
+    fn create_async_method() {
+        check_assist(
+            generate_function,
+            r"
+//- minicore: result
+struct S;
+async fn foo() {
+    if let Err(()) = S.$0bar(42).await {}
+}
+",
+            r"
+struct S;
+impl S {
+    async fn bar(&self, arg: i32) -> Result<_, ()> {
+        ${0:todo!()}
+    }
+}
+async fn foo() {
+    if let Err(()) = S.bar(42).await {}
+}
+",
+        )
+    }
+
+    #[test]
     fn create_static_method() {
         check_assist(
             generate_function,
@@ -2422,6 +2463,31 @@ fn foo() {S::bar();}
     }
 
     #[test]
+    fn create_async_static_method() {
+        check_assist(
+            generate_function,
+            r"
+//- minicore: result
+struct S;
+async fn foo() {
+    if let Err(()) = S::$0bar(42).await {}
+}
+",
+            r"
+struct S;
+impl S {
+    async fn bar(arg: i32) -> Result<_, ()> {
+        ${0:todo!()}
+    }
+}
+async fn foo() {
+    if let Err(()) = S::bar(42).await {}
+}
+",
+        )
+    }
+
+    #[test]
     fn create_generic_static_method() {
         check_assist(
             generate_function,