about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs2
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs119
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs110
3 files changed, 168 insertions, 63 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs
index 62700ab1809..04d63f5bc8f 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs
@@ -189,7 +189,7 @@ pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
 /// This will create a turbofish generic arg list corresponding to the number of arguments
 fn get_fish_head(make: &SyntaxFactory, number_of_arguments: usize) -> ast::GenericArgList {
     let args = (0..number_of_arguments).map(|_| make::type_arg(make::ty_placeholder()).into());
-    make.turbofish_generic_arg_list(args)
+    make.generic_arg_list(args, true)
 }
 
 #[cfg(test)]
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs
index 658600cd2d0..0b145dcb06b 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs
@@ -6,10 +6,9 @@ use ide_db::{
     famous_defs::FamousDefs,
     syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
 };
-use itertools::Itertools;
 use syntax::{
-    ast::{self, make, Expr, HasGenericParams},
-    match_ast, ted, AstNode, ToSmolStr,
+    ast::{self, syntax_factory::SyntaxFactory, Expr, HasGenericArgs, HasGenericParams},
+    match_ast, AstNode,
 };
 
 use crate::{AssistContext, AssistId, AssistKind, Assists};
@@ -43,11 +42,11 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
 pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
     let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
     let parent = ret_type.syntax().parent()?;
-    let body = match_ast! {
+    let body_expr = match_ast! {
         match parent {
-            ast::Fn(func) => func.body()?,
+            ast::Fn(func) => func.body()?.into(),
             ast::ClosureExpr(closure) => match closure.body()? {
-                Expr::BlockExpr(block) => block,
+                Expr::BlockExpr(block) => block.into(),
                 // closures require a block when a return type is specified
                 _ => return None,
             },
@@ -75,56 +74,65 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
             kind.assist_id(),
             kind.label(),
             type_ref.syntax().text_range(),
-            |edit| {
-                let alias = wrapper_alias(ctx, &core_wrapper, type_ref, kind.symbol());
-                let new_return_ty =
-                    alias.unwrap_or_else(|| kind.wrap_type(type_ref)).clone_for_update();
-
-                let body = edit.make_mut(ast::Expr::BlockExpr(body.clone()));
+            |builder| {
+                let mut editor = builder.make_editor(&parent);
+                let make = SyntaxFactory::new();
+                let alias = wrapper_alias(ctx, &make, &core_wrapper, type_ref, kind.symbol());
+                let new_return_ty = alias.unwrap_or_else(|| match kind {
+                    WrapperKind::Option => make.ty_option(type_ref.clone()),
+                    WrapperKind::Result => make.ty_result(type_ref.clone(), make.ty_infer().into()),
+                });
 
                 let mut exprs_to_wrap = Vec::new();
                 let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e);
-                walk_expr(&body, &mut |expr| {
+                walk_expr(&body_expr, &mut |expr| {
                     if let Expr::ReturnExpr(ret_expr) = expr {
                         if let Some(ret_expr_arg) = &ret_expr.expr() {
                             for_each_tail_expr(ret_expr_arg, tail_cb);
                         }
                     }
                 });
-                for_each_tail_expr(&body, tail_cb);
+                for_each_tail_expr(&body_expr, tail_cb);
 
                 for ret_expr_arg in exprs_to_wrap {
-                    let happy_wrapped = make::expr_call(
-                        make::expr_path(make::ext::ident_path(kind.happy_ident())),
-                        make::arg_list(iter::once(ret_expr_arg.clone())),
-                    )
-                    .clone_for_update();
-                    ted::replace(ret_expr_arg.syntax(), happy_wrapped.syntax());
+                    let happy_wrapped = make.expr_call(
+                        make.expr_path(make.ident_path(kind.happy_ident())),
+                        make.arg_list(iter::once(ret_expr_arg.clone())),
+                    );
+                    editor.replace(ret_expr_arg.syntax(), happy_wrapped.syntax());
                 }
 
-                let old_return_ty = edit.make_mut(type_ref.clone());
-                ted::replace(old_return_ty.syntax(), new_return_ty.syntax());
+                editor.replace(type_ref.syntax(), new_return_ty.syntax());
 
                 if let WrapperKind::Result = kind {
                     // Add a placeholder snippet at the first generic argument that doesn't equal the return type.
                     // This is normally the error type, but that may not be the case when we inserted a type alias.
-                    let args =
-                        new_return_ty.syntax().descendants().find_map(ast::GenericArgList::cast);
-                    let error_type_arg = args.and_then(|list| {
-                        list.generic_args().find(|arg| match arg {
-                            ast::GenericArg::TypeArg(_) => {
-                                arg.syntax().text() != type_ref.syntax().text()
-                            }
-                            ast::GenericArg::LifetimeArg(_) => false,
-                            _ => true,
-                        })
+                    let args = new_return_ty
+                        .path()
+                        .unwrap()
+                        .segment()
+                        .unwrap()
+                        .generic_arg_list()
+                        .unwrap();
+                    let error_type_arg = args.generic_args().find(|arg| match arg {
+                        ast::GenericArg::TypeArg(_) => {
+                            arg.syntax().text() != type_ref.syntax().text()
+                        }
+                        ast::GenericArg::LifetimeArg(_) => false,
+                        _ => true,
                     });
                     if let Some(error_type_arg) = error_type_arg {
                         if let Some(cap) = ctx.config.snippet_cap {
-                            edit.add_placeholder_snippet(cap, error_type_arg);
+                            editor.add_annotation(
+                                error_type_arg.syntax(),
+                                builder.make_placeholder_snippet(cap),
+                            );
                         }
                     }
                 }
+
+                editor.add_mappings(make.finish_with_mappings());
+                builder.add_file_edits(ctx.file_id(), editor);
             },
         );
     }
@@ -176,22 +184,16 @@ impl WrapperKind {
             WrapperKind::Result => hir::sym::Result.clone(),
         }
     }
-
-    fn wrap_type(&self, type_ref: &ast::Type) -> ast::Type {
-        match self {
-            WrapperKind::Option => make::ext::ty_option(type_ref.clone()),
-            WrapperKind::Result => make::ext::ty_result(type_ref.clone(), make::ty_placeholder()),
-        }
-    }
 }
 
 // Try to find an wrapper type alias in the current scope (shadowing the default).
 fn wrapper_alias(
     ctx: &AssistContext<'_>,
+    make: &SyntaxFactory,
     core_wrapper: &hir::Enum,
     ret_type: &ast::Type,
     wrapper: hir::Symbol,
-) -> Option<ast::Type> {
+) -> Option<ast::PathType> {
     let wrapper_path = hir::ModPath::from_segments(
         hir::PathKind::Plain,
         iter::once(hir::Name::new_symbol_root(wrapper)),
@@ -207,25 +209,28 @@ fn wrapper_alias(
         })
         .find_map(|alias| {
             let mut inserted_ret_type = false;
-            let generic_params = alias
-                .source(ctx.db())?
-                .value
-                .generic_param_list()?
-                .generic_params()
-                .map(|param| match param {
-                    // Replace the very first type parameter with the functions return type.
-                    ast::GenericParam::TypeParam(_) if !inserted_ret_type => {
-                        inserted_ret_type = true;
-                        ret_type.to_smolstr()
+            let generic_args =
+                alias.source(ctx.db())?.value.generic_param_list()?.generic_params().map(|param| {
+                    match param {
+                        // Replace the very first type parameter with the function's return type.
+                        ast::GenericParam::TypeParam(_) if !inserted_ret_type => {
+                            inserted_ret_type = true;
+                            make.type_arg(ret_type.clone()).into()
+                        }
+                        ast::GenericParam::LifetimeParam(_) => {
+                            make.lifetime_arg(make.lifetime("'_")).into()
+                        }
+                        _ => make.type_arg(make.ty_infer().into()).into(),
                     }
-                    ast::GenericParam::LifetimeParam(_) => make::lifetime("'_").to_smolstr(),
-                    _ => make::ty_placeholder().to_smolstr(),
-                })
-                .join(", ");
+                });
 
             let name = alias.name(ctx.db());
-            let name = name.as_str();
-            Some(make::ty(&format!("{name}<{generic_params}>")))
+            let generic_arg_list = make.generic_arg_list(generic_args, false);
+            let path = make.path_unqualified(
+                make.path_segment_generics(make.name_ref(name.as_str()), generic_arg_list),
+            );
+
+            Some(make.ty_path(path))
         })
     })
 }
diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
index d62c01ba761..af7b3c81581 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs
@@ -1,6 +1,9 @@
 //! Wrappers over [`make`] constructors
 use crate::{
-    ast::{self, make, HasGenericArgs, HasGenericParams, HasName, HasTypeBounds, HasVisibility},
+    ast::{
+        self, make, HasArgList, HasGenericArgs, HasGenericParams, HasName, HasTypeBounds,
+        HasVisibility,
+    },
     syntax_editor::SyntaxMappingBuilder,
     AstNode, NodeOrToken, SyntaxKind, SyntaxNode, SyntaxToken,
 };
@@ -16,6 +19,10 @@ impl SyntaxFactory {
         make::name_ref(name).clone_for_update()
     }
 
+    pub fn lifetime(&self, text: &str) -> ast::Lifetime {
+        make::lifetime(text).clone_for_update()
+    }
+
     pub fn ty(&self, text: &str) -> ast::Type {
         make::ty(text).clone_for_update()
     }
@@ -28,6 +35,20 @@ impl SyntaxFactory {
         ast
     }
 
+    pub fn ty_path(&self, path: ast::Path) -> ast::PathType {
+        let ast::Type::PathType(ast) = make::ty_path(path.clone()).clone_for_update() else {
+            unreachable!()
+        };
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+            builder.map_node(path.syntax().clone(), ast.path().unwrap().syntax().clone());
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
     pub fn type_param(
         &self,
         name: ast::Name,
@@ -253,6 +274,37 @@ impl SyntaxFactory {
         ast
     }
 
+    pub fn expr_call(&self, expr: ast::Expr, arg_list: ast::ArgList) -> ast::CallExpr {
+        // FIXME: `make::expr_call`` should return a `CallExpr`, not just an `Expr`
+        let ast::Expr::CallExpr(ast) =
+            make::expr_call(expr.clone(), arg_list.clone()).clone_for_update()
+        else {
+            unreachable!()
+        };
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+            builder.map_node(expr.syntax().clone(), ast.expr().unwrap().syntax().clone());
+            builder.map_node(arg_list.syntax().clone(), ast.arg_list().unwrap().syntax().clone());
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
+    pub fn arg_list(&self, args: impl IntoIterator<Item = ast::Expr>) -> ast::ArgList {
+        let (args, input) = iterator_input(args);
+        let ast = make::arg_list(args).clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax.clone());
+            builder.map_children(input.into_iter(), ast.args().map(|it| it.syntax().clone()));
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
     pub fn expr_ref(&self, expr: ast::Expr, exclusive: bool) -> ast::Expr {
         let ast::Expr::RefExpr(ast) = make::expr_ref(expr.clone(), exclusive).clone_for_update()
         else {
@@ -428,6 +480,30 @@ impl SyntaxFactory {
         ast
     }
 
+    pub fn type_arg(&self, ty: ast::Type) -> ast::TypeArg {
+        let ast = make::type_arg(ty.clone()).clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+            builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone());
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
+    pub fn lifetime_arg(&self, lifetime: ast::Lifetime) -> ast::LifetimeArg {
+        let ast = make::lifetime_arg(lifetime.clone()).clone_for_update();
+
+        if let Some(mut mapping) = self.mappings() {
+            let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+            builder.map_node(lifetime.syntax().clone(), ast.lifetime().unwrap().syntax().clone());
+            builder.finish(&mut mapping);
+        }
+
+        ast
+    }
+
     pub fn item_const(
         &self,
         visibility: Option<ast::Visibility>,
@@ -495,12 +571,17 @@ impl SyntaxFactory {
         ast
     }
 
-    pub fn turbofish_generic_arg_list(
+    pub fn generic_arg_list(
         &self,
         generic_args: impl IntoIterator<Item = ast::GenericArg>,
+        is_turbo: bool,
     ) -> ast::GenericArgList {
         let (generic_args, input) = iterator_input(generic_args);
-        let ast = make::turbofish_generic_arg_list(generic_args.clone()).clone_for_update();
+        let ast = if is_turbo {
+            make::turbofish_generic_arg_list(generic_args).clone_for_update()
+        } else {
+            make::generic_arg_list(generic_args).clone_for_update()
+        };
 
         if let Some(mut mapping) = self.mappings() {
             let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
@@ -753,12 +834,31 @@ impl SyntaxFactory {
 
 // `ext` constructors
 impl SyntaxFactory {
+    pub fn ident_path(&self, ident: &str) -> ast::Path {
+        self.path_unqualified(self.path_segment(self.name_ref(ident)))
+    }
+
     pub fn expr_unit(&self) -> ast::Expr {
         self.expr_tuple([]).into()
     }
 
-    pub fn ident_path(&self, ident: &str) -> ast::Path {
-        self.path_unqualified(self.path_segment(self.name_ref(ident)))
+    pub fn ty_option(&self, t: ast::Type) -> ast::PathType {
+        let generic_arg_list = self.generic_arg_list([self.type_arg(t).into()], false);
+        let path = self.path_unqualified(
+            self.path_segment_generics(self.name_ref("Option"), generic_arg_list),
+        );
+
+        self.ty_path(path)
+    }
+
+    pub fn ty_result(&self, t: ast::Type, e: ast::Type) -> ast::PathType {
+        let generic_arg_list =
+            self.generic_arg_list([self.type_arg(t).into(), self.type_arg(e).into()], false);
+        let path = self.path_unqualified(
+            self.path_segment_generics(self.name_ref("Result"), generic_arg_list),
+        );
+
+        self.ty_path(path)
     }
 }