about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorGiga Bowser <45986823+Giga-Bowser@users.noreply.github.com>2024-10-24 17:54:58 -0400
committerGiga Bowser <45986823+Giga-Bowser@users.noreply.github.com>2024-10-24 17:54:58 -0400
commitd0de3fa7ca728b41db8aaf3735561895aaeeae5c (patch)
treec3b58520f0395bcd6cda9efadd2f9d38622ca92d /src
parenta00b4c2a529089b9eeeba140c9a82b872025d801 (diff)
downloadrust-d0de3fa7ca728b41db8aaf3735561895aaeeae5c.tar.gz
rust-d0de3fa7ca728b41db8aaf3735561895aaeeae5c.zip
Rework `generate_fn_type_alias`
Diffstat (limited to 'src')
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs98
1 files changed, 43 insertions, 55 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs
index 597d94d3fc6..f4b4c22d98d 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs
@@ -1,5 +1,4 @@
 use either::Either;
-use hir::HirDisplay;
 use ide_db::assists::{AssistId, AssistKind, GroupLabel};
 use syntax::{
     ast::{self, edit::IndentLevel, make, HasGenericParams, HasName},
@@ -39,23 +38,16 @@ use crate::{AssistContext, Assists};
 pub(crate) fn generate_fn_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
     let name = ctx.find_node_at_offset::<ast::Name>()?;
     let func = &name.syntax().parent()?;
-    let item = func.ancestors().find_map(ast::Item::cast)?;
-    let assoc_owner =
-        item.syntax().ancestors().nth(2).and_then(Either::<ast::Trait, ast::Impl>::cast);
-    let node = assoc_owner.as_ref().map_or_else(
-        || item.syntax(),
-        |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax),
-    );
     let func_node = ast::Fn::cast(func.clone())?;
     let param_list = func_node.param_list()?;
 
-    for style in ParamStyle::ALL {
-        let generic_params = func_node.generic_param_list();
-        let module = match ctx.sema.scope(node) {
-            Some(scope) => scope.module(),
-            None => continue,
-        };
+    let assoc_owner = func.ancestors().nth(2).and_then(Either::<ast::Trait, ast::Impl>::cast);
+    // This is where we'll insert the type alias, since type aliases in `impl`s or `trait`s are not supported
+    let insertion_node = assoc_owner
+        .as_ref()
+        .map_or_else(|| func, |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax));
 
+    for style in ParamStyle::ALL {
         acc.add_group(
             &GroupLabel("Generate a type alias for function...".into()),
             style.assist_id(),
@@ -66,51 +58,53 @@ pub(crate) fn generate_fn_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>)
 
                 let alias_name = format!("{}Fn", stdx::to_camel_case(&name.to_string()));
 
-                let fn_abi = match func_node.abi() {
-                    Some(abi) => format!("{} ", abi),
-                    None => "".into(),
-                };
-
-                let fn_unsafe = if func_node.unsafe_token().is_some() { "unsafe " } else { "" };
+                let mut fn_params_vec = Vec::new();
 
-                let fn_qualifiers = format!("{fn_unsafe}{fn_abi}");
+                if let Some(self_ty) =
+                    param_list.self_param().and_then(|p| ctx.sema.type_of_self(&p))
+                {
+                    let is_ref = self_ty.is_reference();
+                    let is_mut = self_ty.is_mutable_reference();
 
-                let fn_type = return_type(&func_node);
+                    if let Some(adt) = self_ty.strip_references().as_adt() {
+                        let inner_type = make::ty(adt.name(ctx.db()).as_str());
 
-                let mut fn_params_vec = Vec::new();
+                        let ast_self_ty =
+                            if is_ref { make::ty_ref(inner_type, is_mut) } else { inner_type };
 
-                if let Some(self_param) = param_list.self_param() {
-                    if let Some(local) = ctx.sema.to_def(&self_param) {
-                        let ty = local.ty(ctx.db());
-                        if let Ok(s) = ty.display_source_code(ctx.db(), module.into(), false) {
-                            fn_params_vec.push(s)
-                        }
+                        fn_params_vec.push(make::unnamed_param(ast_self_ty));
                     }
                 }
 
-                match style {
-                    ParamStyle::Named => {
-                        fn_params_vec.extend(param_list.params().map(|p| p.to_string()))
-                    }
-                    ParamStyle::Unnamed => fn_params_vec.extend(
-                        param_list.params().filter_map(|p| p.ty()).map(|ty| ty.to_string()),
-                    ),
-                };
+                fn_params_vec.extend(param_list.params().filter_map(|p| match style {
+                    ParamStyle::Named => Some(p),
+                    ParamStyle::Unnamed => p.ty().map(make::unnamed_param),
+                }));
 
-                let fn_params = fn_params_vec.join(", ");
+                let generic_params = func_node.generic_param_list();
 
-                // FIXME: sometime in the far future when we have `make::ty_func`, we should use that
-                let ty = make::ty(&format!("{fn_qualifiers}fn({fn_params}){fn_type}"))
-                    .clone_for_update();
+                let is_unsafe = func_node.unsafe_token().is_some();
+                let ty = make::ty_fn_ptr(
+                    None,
+                    is_unsafe,
+                    func_node.abi(),
+                    fn_params_vec.into_iter(),
+                    func_node.ret_type(),
+                );
 
                 // Insert new alias
-                let ty_alias =
-                    make::ty_alias(&alias_name, generic_params, None, None, Some((ty, None)))
-                        .clone_for_update();
-
-                let indent = IndentLevel::from_node(node);
+                let ty_alias = make::ty_alias(
+                    &alias_name,
+                    generic_params,
+                    None,
+                    None,
+                    Some((ast::Type::FnPtrType(ty), None)),
+                )
+                .clone_for_update();
+
+                let indent = IndentLevel::from_node(insertion_node);
                 edit.insert_all(
-                    syntax_editor::Position::before(node),
+                    syntax_editor::Position::before(insertion_node),
                     vec![
                         ty_alias.syntax().clone().into(),
                         make::tokens::whitespace(&format!("\n\n{indent}")).into(),
@@ -156,12 +150,6 @@ impl ParamStyle {
     }
 }
 
-fn return_type(func: &ast::Fn) -> String {
-    func.ret_type()
-        .and_then(|ret_type| ret_type.ty())
-        .map_or("".into(), |ty| format!(" -> {} ", ty))
-}
-
 #[cfg(test)]
 mod tests {
     use crate::tests::check_assist_by_label;
@@ -233,7 +221,7 @@ extern "FooABI" fn foo(param: u32) -> i32 { return 42; }
     }
 
     #[test]
-    fn generate_fn_alias_unnamed_unnamed_unsafe_extern_abi() {
+    fn generate_fn_alias_unnamed_unsafe_extern_abi() {
         check_assist_by_label(
             generate_fn_type_alias,
             r#"
@@ -369,7 +357,7 @@ extern "FooABI" fn foo(param: u32) -> i32 { return 42; }
     }
 
     #[test]
-    fn generate_fn_alias_named_named_unsafe_extern_abi() {
+    fn generate_fn_alias_named_unsafe_extern_abi() {
         check_assist_by_label(
             generate_fn_type_alias,
             r#"