about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2024-10-30 09:36:55 +0000
committerGitHub <noreply@github.com>2024-10-30 09:36:55 +0000
commit5fe280e681e0a222fbe40ac736ad05a30b4080c7 (patch)
tree51f3158ad463e2b3c40ba72f99174d96e8cb7b24
parent3ab7a69e4cb2a9fa313f5a6b2356179be75c1b7d (diff)
parentd0de3fa7ca728b41db8aaf3735561895aaeeae5c (diff)
downloadrust-5fe280e681e0a222fbe40ac736ad05a30b4080c7.tar.gz
rust-5fe280e681e0a222fbe40ac736ad05a30b4080c7.zip
Merge pull request #18385 from Giga-Bowser/master
feat: Add assist to generate a type alias for a function
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs430
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/lib.rs2
-rw-r--r--src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs30
-rw-r--r--src/tools/rust-analyzer/crates/syntax/src/ast/make.rs42
4 files changed, 503 insertions, 1 deletions
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs
new file mode 100644
index 00000000000..f4b4c22d98d
--- /dev/null
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs
@@ -0,0 +1,430 @@
+use either::Either;
+use ide_db::assists::{AssistId, AssistKind, GroupLabel};
+use syntax::{
+    ast::{self, edit::IndentLevel, make, HasGenericParams, HasName},
+    syntax_editor, AstNode,
+};
+
+use crate::{AssistContext, Assists};
+
+// Assist: generate_fn_type_alias_named
+//
+// Generate a type alias for the function with named parameters.
+//
+// ```
+// unsafe fn fo$0o(n: i32) -> i32 { 42i32 }
+// ```
+// ->
+// ```
+// type ${0:FooFn} = unsafe fn(n: i32) -> i32;
+//
+// unsafe fn foo(n: i32) -> i32 { 42i32 }
+// ```
+
+// Assist: generate_fn_type_alias_unnamed
+//
+// Generate a type alias for the function with unnamed parameters.
+//
+// ```
+// unsafe fn fo$0o(n: i32) -> i32 { 42i32 }
+// ```
+// ->
+// ```
+// type ${0:FooFn} = unsafe fn(i32) -> i32;
+//
+// unsafe fn foo(n: i32) -> i32 { 42i32 }
+// ```
+
+pub(crate) fn generate_fn_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
+    let name = ctx.find_node_at_offset::<ast::Name>()?;
+    let func = &name.syntax().parent()?;
+    let func_node = ast::Fn::cast(func.clone())?;
+    let param_list = func_node.param_list()?;
+
+    let assoc_owner = func.ancestors().nth(2).and_then(Either::<ast::Trait, ast::Impl>::cast);
+    // This is where we'll insert the type alias, since type aliases in `impl`s or `trait`s are not supported
+    let insertion_node = assoc_owner
+        .as_ref()
+        .map_or_else(|| func, |impl_| impl_.as_ref().either(AstNode::syntax, AstNode::syntax));
+
+    for style in ParamStyle::ALL {
+        acc.add_group(
+            &GroupLabel("Generate a type alias for function...".into()),
+            style.assist_id(),
+            style.label(),
+            func_node.syntax().text_range(),
+            |builder| {
+                let mut edit = builder.make_editor(func);
+
+                let alias_name = format!("{}Fn", stdx::to_camel_case(&name.to_string()));
+
+                let mut fn_params_vec = Vec::new();
+
+                if let Some(self_ty) =
+                    param_list.self_param().and_then(|p| ctx.sema.type_of_self(&p))
+                {
+                    let is_ref = self_ty.is_reference();
+                    let is_mut = self_ty.is_mutable_reference();
+
+                    if let Some(adt) = self_ty.strip_references().as_adt() {
+                        let inner_type = make::ty(adt.name(ctx.db()).as_str());
+
+                        let ast_self_ty =
+                            if is_ref { make::ty_ref(inner_type, is_mut) } else { inner_type };
+
+                        fn_params_vec.push(make::unnamed_param(ast_self_ty));
+                    }
+                }
+
+                fn_params_vec.extend(param_list.params().filter_map(|p| match style {
+                    ParamStyle::Named => Some(p),
+                    ParamStyle::Unnamed => p.ty().map(make::unnamed_param),
+                }));
+
+                let generic_params = func_node.generic_param_list();
+
+                let is_unsafe = func_node.unsafe_token().is_some();
+                let ty = make::ty_fn_ptr(
+                    None,
+                    is_unsafe,
+                    func_node.abi(),
+                    fn_params_vec.into_iter(),
+                    func_node.ret_type(),
+                );
+
+                // Insert new alias
+                let ty_alias = make::ty_alias(
+                    &alias_name,
+                    generic_params,
+                    None,
+                    None,
+                    Some((ast::Type::FnPtrType(ty), None)),
+                )
+                .clone_for_update();
+
+                let indent = IndentLevel::from_node(insertion_node);
+                edit.insert_all(
+                    syntax_editor::Position::before(insertion_node),
+                    vec![
+                        ty_alias.syntax().clone().into(),
+                        make::tokens::whitespace(&format!("\n\n{indent}")).into(),
+                    ],
+                );
+
+                if let Some(cap) = ctx.config.snippet_cap {
+                    if let Some(name) = ty_alias.name() {
+                        edit.add_annotation(name.syntax(), builder.make_placeholder_snippet(cap));
+                    }
+                }
+
+                builder.add_file_edits(ctx.file_id(), edit);
+            },
+        );
+    }
+
+    Some(())
+}
+
+enum ParamStyle {
+    Named,
+    Unnamed,
+}
+
+impl ParamStyle {
+    const ALL: &'static [ParamStyle] = &[ParamStyle::Named, ParamStyle::Unnamed];
+
+    fn assist_id(&self) -> AssistId {
+        let s = match self {
+            ParamStyle::Named => "generate_fn_type_alias_named",
+            ParamStyle::Unnamed => "generate_fn_type_alias_unnamed",
+        };
+
+        AssistId(s, AssistKind::Generate)
+    }
+
+    fn label(&self) -> &'static str {
+        match self {
+            ParamStyle::Named => "Generate a type alias for function with named params",
+            ParamStyle::Unnamed => "Generate a type alias for function with unnamed params",
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::tests::check_assist_by_label;
+
+    use super::*;
+
+    #[test]
+    fn generate_fn_alias_unnamed_simple() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+fn fo$0o(param: u32) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn} = fn(u32) -> i32;
+
+fn foo(param: u32) -> i32 { return 42; }
+"#,
+            ParamStyle::Unnamed.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_unnamed_unsafe() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+unsafe fn fo$0o(param: u32) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn} = unsafe fn(u32) -> i32;
+
+unsafe fn foo(param: u32) -> i32 { return 42; }
+"#,
+            ParamStyle::Unnamed.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_unnamed_extern() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+extern fn fo$0o(param: u32) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn} = extern fn(u32) -> i32;
+
+extern fn foo(param: u32) -> i32 { return 42; }
+"#,
+            ParamStyle::Unnamed.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_type_unnamed_extern_abi() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+extern "FooABI" fn fo$0o(param: u32) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn} = extern "FooABI" fn(u32) -> i32;
+
+extern "FooABI" fn foo(param: u32) -> i32 { return 42; }
+"#,
+            ParamStyle::Unnamed.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_unnamed_unsafe_extern_abi() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+unsafe extern "FooABI" fn fo$0o(param: u32) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn} = unsafe extern "FooABI" fn(u32) -> i32;
+
+unsafe extern "FooABI" fn foo(param: u32) -> i32 { return 42; }
+"#,
+            ParamStyle::Unnamed.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_unnamed_generics() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+fn fo$0o<A, B>(a: A, b: B) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn}<A, B> = fn(A, B) -> i32;
+
+fn foo<A, B>(a: A, b: B) -> i32 { return 42; }
+"#,
+            ParamStyle::Unnamed.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_unnamed_generics_bounds() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+fn fo$0o<A: Trait, B: Trait>(a: A, b: B) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn}<A: Trait, B: Trait> = fn(A, B) -> i32;
+
+fn foo<A: Trait, B: Trait>(a: A, b: B) -> i32 { return 42; }
+"#,
+            ParamStyle::Unnamed.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_unnamed_self() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+struct S;
+
+impl S {
+    fn fo$0o(&mut self, param: u32) -> i32 { return 42; }
+}
+"#,
+            r#"
+struct S;
+
+type ${0:FooFn} = fn(&mut S, u32) -> i32;
+
+impl S {
+    fn foo(&mut self, param: u32) -> i32 { return 42; }
+}
+"#,
+            ParamStyle::Unnamed.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_named_simple() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+fn fo$0o(param: u32) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn} = fn(param: u32) -> i32;
+
+fn foo(param: u32) -> i32 { return 42; }
+"#,
+            ParamStyle::Named.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_named_unsafe() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+unsafe fn fo$0o(param: u32) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn} = unsafe fn(param: u32) -> i32;
+
+unsafe fn foo(param: u32) -> i32 { return 42; }
+"#,
+            ParamStyle::Named.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_named_extern() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+extern fn fo$0o(param: u32) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn} = extern fn(param: u32) -> i32;
+
+extern fn foo(param: u32) -> i32 { return 42; }
+"#,
+            ParamStyle::Named.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_type_named_extern_abi() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+extern "FooABI" fn fo$0o(param: u32) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn} = extern "FooABI" fn(param: u32) -> i32;
+
+extern "FooABI" fn foo(param: u32) -> i32 { return 42; }
+"#,
+            ParamStyle::Named.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_named_unsafe_extern_abi() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+unsafe extern "FooABI" fn fo$0o(param: u32) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn} = unsafe extern "FooABI" fn(param: u32) -> i32;
+
+unsafe extern "FooABI" fn foo(param: u32) -> i32 { return 42; }
+"#,
+            ParamStyle::Named.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_named_generics() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+fn fo$0o<A, B>(a: A, b: B) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn}<A, B> = fn(a: A, b: B) -> i32;
+
+fn foo<A, B>(a: A, b: B) -> i32 { return 42; }
+"#,
+            ParamStyle::Named.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_named_generics_bounds() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+fn fo$0o<A: Trait, B: Trait>(a: A, b: B) -> i32 { return 42; }
+"#,
+            r#"
+type ${0:FooFn}<A: Trait, B: Trait> = fn(a: A, b: B) -> i32;
+
+fn foo<A: Trait, B: Trait>(a: A, b: B) -> i32 { return 42; }
+"#,
+            ParamStyle::Named.label(),
+        );
+    }
+
+    #[test]
+    fn generate_fn_alias_named_self() {
+        check_assist_by_label(
+            generate_fn_type_alias,
+            r#"
+struct S;
+
+impl S {
+    fn fo$0o(&mut self, param: u32) -> i32 { return 42; }
+}
+"#,
+            r#"
+struct S;
+
+type ${0:FooFn} = fn(&mut S, param: u32) -> i32;
+
+impl S {
+    fn foo(&mut self, param: u32) -> i32 { return 42; }
+}
+"#,
+            ParamStyle::Named.label(),
+        );
+    }
+}
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs b/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs
index 8aaf5d6fff2..5c95b25f28d 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/lib.rs
@@ -161,6 +161,7 @@ mod handlers {
     mod generate_enum_is_method;
     mod generate_enum_projection_method;
     mod generate_enum_variant;
+    mod generate_fn_type_alias;
     mod generate_from_impl_for_enum;
     mod generate_function;
     mod generate_getter_or_setter;
@@ -289,6 +290,7 @@ mod handlers {
             generate_enum_projection_method::generate_enum_as_method,
             generate_enum_projection_method::generate_enum_try_into_method,
             generate_enum_variant::generate_enum_variant,
+            generate_fn_type_alias::generate_fn_type_alias,
             generate_from_impl_for_enum::generate_from_impl_for_enum,
             generate_function::generate_function,
             generate_impl::generate_impl,
diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs b/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs
index 933d45d7508..64b7ab1a123 100644
--- a/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs
+++ b/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs
@@ -1549,6 +1549,36 @@ fn main() {
 }
 
 #[test]
+fn doctest_generate_fn_type_alias_named() {
+    check_doc_test(
+        "generate_fn_type_alias_named",
+        r#####"
+unsafe fn fo$0o(n: i32) -> i32 { 42i32 }
+"#####,
+        r#####"
+type ${0:FooFn} = unsafe fn(n: i32) -> i32;
+
+unsafe fn foo(n: i32) -> i32 { 42i32 }
+"#####,
+    )
+}
+
+#[test]
+fn doctest_generate_fn_type_alias_unnamed() {
+    check_doc_test(
+        "generate_fn_type_alias_unnamed",
+        r#####"
+unsafe fn fo$0o(n: i32) -> i32 { 42i32 }
+"#####,
+        r#####"
+type ${0:FooFn} = unsafe fn(i32) -> i32;
+
+unsafe fn foo(n: i32) -> i32 { 42i32 }
+"#####,
+    )
+}
+
+#[test]
 fn doctest_generate_from_impl_for_enum() {
     check_doc_test(
         "generate_from_impl_for_enum",
diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs
index fcdc97ce327..2ec83d23b27 100644
--- a/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs
+++ b/src/tools/rust-analyzer/crates/syntax/src/ast/make.rs
@@ -15,7 +15,11 @@ use parser::{Edition, T};
 use rowan::NodeOrToken;
 use stdx::{format_to, format_to_acc, never};
 
-use crate::{ast, utils::is_raw_identifier, AstNode, SourceFile, SyntaxKind, SyntaxToken};
+use crate::{
+    ast::{self, Param},
+    utils::is_raw_identifier,
+    AstNode, SourceFile, SyntaxKind, SyntaxToken,
+};
 
 /// While the parent module defines basic atomic "constructors", the `ext`
 /// module defines shortcuts for common things.
@@ -198,6 +202,38 @@ pub fn ty_alias(
     ast_from_text(&s)
 }
 
+pub fn ty_fn_ptr<I: Iterator<Item = Param>>(
+    for_lifetime_list: Option<ast::GenericParamList>,
+    is_unsafe: bool,
+    abi: Option<ast::Abi>,
+    params: I,
+    ret_type: Option<ast::RetType>,
+) -> ast::FnPtrType {
+    let mut s = String::from("type __ = ");
+
+    if let Some(list) = for_lifetime_list {
+        format_to!(s, "for{} ", list);
+    }
+
+    if is_unsafe {
+        s.push_str("unsafe ");
+    }
+
+    if let Some(abi) = abi {
+        format_to!(s, "{} ", abi)
+    }
+
+    s.push_str("fn");
+
+    format_to!(s, "({})", params.map(|p| p.to_string()).join(", "));
+
+    if let Some(ret_type) = ret_type {
+        format_to!(s, " {}", ret_type);
+    }
+
+    ast_from_text(&s)
+}
+
 pub fn assoc_item_list() -> ast::AssocItemList {
     ast_from_text("impl C for D {}")
 }
@@ -862,6 +898,10 @@ pub fn item_const(
     ast_from_text(&format!("{visibility} const {name}: {ty} = {expr};"))
 }
 
+pub fn unnamed_param(ty: ast::Type) -> ast::Param {
+    ast_from_text(&format!("fn f({ty}) {{ }}"))
+}
+
 pub fn param(pat: ast::Pat, ty: ast::Type) -> ast::Param {
     ast_from_text(&format!("fn f({pat}: {ty}) {{ }}"))
 }