about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorChayim Refael Friedman <chayimfr@gmail.com>2025-06-23 00:45:40 +0300
committerChayim Refael Friedman <chayimfr@gmail.com>2025-06-23 00:45:40 +0300
commit47b29ea0c0df9bf30b16e5674cd5745554b6069a (patch)
treea3df2dffd086320c3262ddf7e9740c26db2b2836 /src
parent9a0434ec195f6cbe3b84fd6d6275f142414f41f7 (diff)
downloadrust-47b29ea0c0df9bf30b16e5674cd5745554b6069a.tar.gz
rust-47b29ea0c0df9bf30b16e5674cd5745554b6069a.zip
In "Wrap return type" assist, don't wrap exit points if they already have the right type
Diffstat (limited to 'src')
-rw-r--r--src/tools/rust-analyzer/crates/hir/src/lib.rs4
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs163
2 files changed, 133 insertions, 34 deletions
diff --git a/src/tools/rust-analyzer/crates/hir/src/lib.rs b/src/tools/rust-analyzer/crates/hir/src/lib.rs
index 3b39707cf60..46d2e881600 100644
--- a/src/tools/rust-analyzer/crates/hir/src/lib.rs
+++ b/src/tools/rust-analyzer/crates/hir/src/lib.rs
@@ -1727,10 +1727,10 @@ impl Adt {
     pub fn ty_with_args<'db>(
         self,
         db: &'db dyn HirDatabase,
-        args: impl Iterator<Item = Type<'db>>,
+        args: impl IntoIterator<Item = Type<'db>>,
     ) -> Type<'db> {
         let id = AdtId::from(self);
-        let mut it = args.map(|t| t.ty);
+        let mut it = args.into_iter().map(|t| t.ty);
         let ty = TyBuilder::def_ty(db, id.into(), None)
             .fill(|x| {
                 let r = it.next().unwrap_or_else(|| TyKind::Error.intern(Interner));
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs
index 9ea78719b20..d7189aa5dbb 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs
@@ -56,7 +56,8 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
     };
 
     let type_ref = &ret_type.ty()?;
-    let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
+    let ty = ctx.sema.resolve_type(type_ref)?;
+    let ty_adt = ty.as_adt();
     let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate());
 
     for kind in WrapperKind::ALL {
@@ -64,7 +65,7 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
             continue;
         };
 
-        if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == core_wrapper) {
+        if matches!(ty_adt, Some(hir::Adt::Enum(ret_type)) if ret_type == core_wrapper) {
             // The return type is already wrapped
             cov_mark::hit!(wrap_return_type_simple_return_type_already_wrapped);
             continue;
@@ -78,10 +79,23 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
             |builder| {
                 let mut editor = builder.make_editor(&parent);
                 let make = SyntaxFactory::with_mappings();
-                let alias = wrapper_alias(ctx, &make, &core_wrapper, type_ref, kind.symbol());
-                let new_return_ty = alias.unwrap_or_else(|| match kind {
-                    WrapperKind::Option => make.ty_option(type_ref.clone()),
-                    WrapperKind::Result => make.ty_result(type_ref.clone(), make.ty_infer().into()),
+                let alias = wrapper_alias(ctx, &make, core_wrapper, type_ref, &ty, kind.symbol());
+                let (ast_new_return_ty, semantic_new_return_ty) = alias.unwrap_or_else(|| {
+                    let (ast_ty, ty_constructor) = match kind {
+                        WrapperKind::Option => {
+                            (make.ty_option(type_ref.clone()), famous_defs.core_option_Option())
+                        }
+                        WrapperKind::Result => (
+                            make.ty_result(type_ref.clone(), make.ty_infer().into()),
+                            famous_defs.core_result_Result(),
+                        ),
+                    };
+                    let semantic_ty = ty_constructor
+                        .map(|ty_constructor| {
+                            hir::Adt::from(ty_constructor).ty_with_args(ctx.db(), [ty.clone()])
+                        })
+                        .unwrap_or_else(|| ty.clone());
+                    (ast_ty, semantic_ty)
                 });
 
                 let mut exprs_to_wrap = Vec::new();
@@ -96,6 +110,17 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
                 for_each_tail_expr(&body_expr, tail_cb);
 
                 for ret_expr_arg in exprs_to_wrap {
+                    if let Some(ty) = ctx.sema.type_of_expr(&ret_expr_arg) {
+                        if ty.adjusted().could_unify_with(ctx.db(), &semantic_new_return_ty) {
+                            // The type is already correct, don't wrap it.
+                            // We deliberately don't use `could_unify_with_deeply()`, because as long as the outer
+                            // enum matches it's okay for us, as we don't trigger the assist if the return type
+                            // is already `Option`/`Result`, so mismatched exact type is more likely a mistake
+                            // than something intended.
+                            continue;
+                        }
+                    }
+
                     let happy_wrapped = make.expr_call(
                         make.expr_path(make.ident_path(kind.happy_ident())),
                         make.arg_list(iter::once(ret_expr_arg.clone())),
@@ -103,12 +128,12 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
                     editor.replace(ret_expr_arg.syntax(), happy_wrapped.syntax());
                 }
 
-                editor.replace(type_ref.syntax(), new_return_ty.syntax());
+                editor.replace(type_ref.syntax(), ast_new_return_ty.syntax());
 
                 if let WrapperKind::Result = kind {
                     // Add a placeholder snippet at the first generic argument that doesn't equal the return type.
                     // This is normally the error type, but that may not be the case when we inserted a type alias.
-                    let args = new_return_ty
+                    let args = ast_new_return_ty
                         .path()
                         .unwrap()
                         .segment()
@@ -188,27 +213,28 @@ impl WrapperKind {
 }
 
 // Try to find an wrapper type alias in the current scope (shadowing the default).
-fn wrapper_alias(
-    ctx: &AssistContext<'_>,
+fn wrapper_alias<'db>(
+    ctx: &AssistContext<'db>,
     make: &SyntaxFactory,
-    core_wrapper: &hir::Enum,
-    ret_type: &ast::Type,
+    core_wrapper: hir::Enum,
+    ast_ret_type: &ast::Type,
+    semantic_ret_type: &hir::Type<'db>,
     wrapper: hir::Symbol,
-) -> Option<ast::PathType> {
+) -> Option<(ast::PathType, hir::Type<'db>)> {
     let wrapper_path = hir::ModPath::from_segments(
         hir::PathKind::Plain,
         iter::once(hir::Name::new_symbol_root(wrapper)),
     );
 
-    ctx.sema.resolve_mod_path(ret_type.syntax(), &wrapper_path).and_then(|def| {
+    ctx.sema.resolve_mod_path(ast_ret_type.syntax(), &wrapper_path).and_then(|def| {
         def.filter_map(|def| match def.into_module_def() {
             hir::ModuleDef::TypeAlias(alias) => {
                 let enum_ty = alias.ty(ctx.db()).as_adt()?.as_enum()?;
-                (&enum_ty == core_wrapper).then_some(alias)
+                (enum_ty == core_wrapper).then_some((alias, enum_ty))
             }
             _ => None,
         })
-        .find_map(|alias| {
+        .find_map(|(alias, enum_ty)| {
             let mut inserted_ret_type = false;
             let generic_args =
                 alias.source(ctx.db())?.value.generic_param_list()?.generic_params().map(|param| {
@@ -216,7 +242,7 @@ fn wrapper_alias(
                         // Replace the very first type parameter with the function's return type.
                         ast::GenericParam::TypeParam(_) if !inserted_ret_type => {
                             inserted_ret_type = true;
-                            make.type_arg(ret_type.clone()).into()
+                            make.type_arg(ast_ret_type.clone()).into()
                         }
                         ast::GenericParam::LifetimeParam(_) => {
                             make.lifetime_arg(make.lifetime("'_")).into()
@@ -231,7 +257,10 @@ fn wrapper_alias(
                 make.path_segment_generics(make.name_ref(name.as_str()), generic_arg_list),
             );
 
-            Some(make.ty_path(path))
+            let new_ty =
+                hir::Adt::from(enum_ty).ty_with_args(ctx.db(), [semantic_ret_type.clone()]);
+
+            Some((make.ty_path(path), new_ty))
         })
     })
 }
@@ -605,29 +634,39 @@ fn foo() -> Option<i32> {
         check_assist_by_label(
             wrap_return_type,
             r#"
-//- minicore: option
+//- minicore: option, future
+struct F(i32);
+impl core::future::Future for F {
+    type Output = i32;
+    fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
+}
 async fn foo() -> i$032 {
     if true {
         if false {
-            1.await
+            F(1).await
         } else {
-            2.await
+            F(2).await
         }
     } else {
-        24i32.await
+        F(24i32).await
     }
 }
 "#,
             r#"
+struct F(i32);
+impl core::future::Future for F {
+    type Output = i32;
+    fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
+}
 async fn foo() -> Option<i32> {
     if true {
         if false {
-            Some(1.await)
+            Some(F(1).await)
         } else {
-            Some(2.await)
+            Some(F(2).await)
         }
     } else {
-        Some(24i32.await)
+        Some(F(24i32).await)
     }
 }
 "#,
@@ -1666,29 +1705,39 @@ fn foo() -> Result<i32, ${0:_}> {
         check_assist_by_label(
             wrap_return_type,
             r#"
-//- minicore: result
+//- minicore: result, future
+struct F(i32);
+impl core::future::Future for F {
+    type Output = i32;
+    fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
+}
 async fn foo() -> i$032 {
     if true {
         if false {
-            1.await
+            F(1).await
         } else {
-            2.await
+            F(2).await
         }
     } else {
-        24i32.await
+        F(24i32).await
     }
 }
 "#,
             r#"
+struct F(i32);
+impl core::future::Future for F {
+    type Output = i32;
+    fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
+}
 async fn foo() -> Result<i32, ${0:_}> {
     if true {
         if false {
-            Ok(1.await)
+            Ok(F(1).await)
         } else {
-            Ok(2.await)
+            Ok(F(2).await)
         }
     } else {
-        Ok(24i32.await)
+        Ok(F(24i32).await)
     }
 }
 "#,
@@ -2460,4 +2509,54 @@ fn foo() -> Result<i32, ${0:_}> {
             WrapperKind::Result.label(),
         );
     }
+
+    #[test]
+    fn already_wrapped() {
+        check_assist_by_label(
+            wrap_return_type,
+            r#"
+//- minicore: option
+fn foo() -> i32$0 {
+    if false {
+        0
+    } else {
+        Some(1)
+    }
+}
+            "#,
+            r#"
+fn foo() -> Option<i32> {
+    if false {
+        Some(0)
+    } else {
+        Some(1)
+    }
+}
+            "#,
+            WrapperKind::Option.label(),
+        );
+        check_assist_by_label(
+            wrap_return_type,
+            r#"
+//- minicore: result
+fn foo() -> i32$0 {
+    if false {
+        0
+    } else {
+        Ok(1)
+    }
+}
+            "#,
+            r#"
+fn foo() -> Result<i32, ${0:_}> {
+    if false {
+        Ok(0)
+    } else {
+        Ok(1)
+    }
+}
+            "#,
+            WrapperKind::Result.label(),
+        );
+    }
 }