about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/ide-assists/src/handlers/convert_nested_function_to_closure.rs209
-rw-r--r--crates/ide-assists/src/lib.rs4
-rw-r--r--crates/ide-assists/src/tests/generated.rs25
3 files changed, 237 insertions, 1 deletions
diff --git a/crates/ide-assists/src/handlers/convert_nested_function_to_closure.rs b/crates/ide-assists/src/handlers/convert_nested_function_to_closure.rs
new file mode 100644
index 00000000000..399f87c8f50
--- /dev/null
+++ b/crates/ide-assists/src/handlers/convert_nested_function_to_closure.rs
@@ -0,0 +1,209 @@
+use ide_db::assists::{AssistId, AssistKind};
+use syntax::ast::{self, HasGenericParams, HasName};
+use syntax::{AstNode, SyntaxKind};
+
+use crate::assist_context::{AssistContext, Assists};
+
+// Assist: convert_nested_function_to_closure
+//
+// Converts a function that is defined within the body of another function into a closure.
+//
+// ```
+// fn main() {
+//     fn fo$0o(label: &str, number: u64) {
+//         println!("{}: {}", label, number);
+//     }
+//
+//     foo("Bar", 100);
+// }
+// ```
+// ->
+// ```
+// fn main() {
+//     let foo = |label: &str, number: u64| {
+//         println!("{}: {}", label, number);
+//     };
+//
+//     foo("Bar", 100);
+// }
+// ```
+pub(crate) fn convert_nested_function_to_closure(
+    acc: &mut Assists,
+    ctx: &AssistContext<'_>,
+) -> Option<()> {
+    let name = ctx.find_node_at_offset::<ast::Name>()?;
+    let function = name.syntax().parent().and_then(ast::Fn::cast)?;
+
+    if !is_nested_function(&function) || is_generic(&function) || has_modifiers(&function) {
+        return None;
+    }
+
+    let target = function.syntax().text_range();
+    let body = function.body()?;
+    let name = function.name()?;
+    let param_list = function.param_list()?;
+
+    acc.add(
+        AssistId("convert_nested_function_to_closure", AssistKind::RefactorRewrite),
+        "Convert nested function to closure",
+        target,
+        |edit| {
+            let params = &param_list.syntax().text().to_string();
+            let params = params.strip_prefix("(").unwrap_or(params);
+            let params = params.strip_suffix(")").unwrap_or(params);
+
+            let mut body = body.to_string();
+            if !has_semicolon(&function) {
+                body.push(';');
+            }
+            edit.replace(target, format!("let {name} = |{params}| {body}"));
+        },
+    )
+}
+
+/// Returns whether the given function is nested within the body of another function.
+fn is_nested_function(function: &ast::Fn) -> bool {
+    function.syntax().ancestors().skip(1).find_map(ast::Item::cast).map_or(false, |it| {
+        matches!(it, ast::Item::Fn(_) | ast::Item::Static(_) | ast::Item::Const(_))
+    })
+}
+
+/// Returns whether the given nested function has generic parameters.
+fn is_generic(function: &ast::Fn) -> bool {
+    function.generic_param_list().is_some()
+}
+
+/// Returns whether the given nested function has any modifiers:
+///
+/// - `async`,
+/// - `const` or
+/// - `unsafe`
+fn has_modifiers(function: &ast::Fn) -> bool {
+    function.async_token().is_some()
+        || function.const_token().is_some()
+        || function.unsafe_token().is_some()
+}
+
+/// Returns whether the given nested function has a trailing semicolon.
+fn has_semicolon(function: &ast::Fn) -> bool {
+    function
+        .syntax()
+        .next_sibling_or_token()
+        .map(|t| t.kind() == SyntaxKind::SEMICOLON)
+        .unwrap_or(false)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::tests::{check_assist, check_assist_not_applicable};
+
+    use super::convert_nested_function_to_closure;
+
+    #[test]
+    fn convert_nested_function_to_closure_works() {
+        check_assist(
+            convert_nested_function_to_closure,
+            r#"
+fn main() {
+    fn $0foo(a: u64, b: u64) -> u64 {
+        2 * (a + b)
+    }
+
+    _ = foo(3, 4);
+}
+            "#,
+            r#"
+fn main() {
+    let foo = |a: u64, b: u64| {
+        2 * (a + b)
+    };
+
+    _ = foo(3, 4);
+}
+            "#,
+        );
+    }
+
+    #[test]
+    fn convert_nested_function_to_closure_works_with_existing_semicolon() {
+        check_assist(
+            convert_nested_function_to_closure,
+            r#"
+fn main() {
+    fn foo$0(a: u64, b: u64) -> u64 {
+        2 * (a + b)
+    };
+
+    _ = foo(3, 4);
+}
+            "#,
+            r#"
+fn main() {
+    let foo = |a: u64, b: u64| {
+        2 * (a + b)
+    };
+
+    _ = foo(3, 4);
+}
+            "#,
+        );
+    }
+
+    #[test]
+    fn convert_nested_function_to_closure_is_not_suggested_on_top_level_function() {
+        check_assist_not_applicable(
+            convert_nested_function_to_closure,
+            r#"
+fn ma$0in() {}
+            "#,
+        );
+    }
+
+    #[test]
+    fn convert_nested_function_to_closure_is_not_suggested_when_cursor_off_name() {
+        check_assist_not_applicable(
+            convert_nested_function_to_closure,
+            r#"
+fn main() {
+    fn foo(a: u64, $0b: u64) -> u64 {
+        2 * (a + b)
+    }
+
+    _ = foo(3, 4);
+}
+            "#,
+        );
+    }
+
+    #[test]
+    fn convert_nested_function_to_closure_is_not_suggested_if_function_has_generic_params() {
+        check_assist_not_applicable(
+            convert_nested_function_to_closure,
+            r#"
+fn main() {
+    fn fo$0o<S: Into<String>>(s: S) -> String {
+        s.into()
+    }
+
+    _ = foo("hello");
+}
+            "#,
+        );
+    }
+
+    #[test]
+    fn convert_nested_function_to_closure_is_not_suggested_if_function_has_modifier() {
+        check_assist_not_applicable(
+            convert_nested_function_to_closure,
+            r#"
+fn main() {
+    const fn fo$0o(s: String) -> String {
+        s
+    }
+
+    _ = foo("hello");
+}
+            "#,
+        );
+    }
+}
diff --git a/crates/ide-assists/src/lib.rs b/crates/ide-assists/src/lib.rs
index 8b07e29a587..fc03903e593 100644
--- a/crates/ide-assists/src/lib.rs
+++ b/crates/ide-assists/src/lib.rs
@@ -122,6 +122,7 @@ mod handlers {
     mod convert_iter_for_each_to_for;
     mod convert_let_else_to_match;
     mod convert_match_to_let_else;
+    mod convert_nested_function_to_closure;
     mod convert_tuple_struct_to_named_struct;
     mod convert_named_struct_to_tuple_struct;
     mod convert_to_guarded_return;
@@ -228,8 +229,9 @@ mod handlers {
             convert_iter_for_each_to_for::convert_iter_for_each_to_for,
             convert_iter_for_each_to_for::convert_for_loop_with_for_each,
             convert_let_else_to_match::convert_let_else_to_match,
-            convert_named_struct_to_tuple_struct::convert_named_struct_to_tuple_struct,
             convert_match_to_let_else::convert_match_to_let_else,
+            convert_named_struct_to_tuple_struct::convert_named_struct_to_tuple_struct,
+            convert_nested_function_to_closure::convert_nested_function_to_closure,
             convert_to_guarded_return::convert_to_guarded_return,
             convert_tuple_struct_to_named_struct::convert_tuple_struct_to_named_struct,
             convert_two_arm_bool_match_to_matches_macro::convert_two_arm_bool_match_to_matches_macro,
diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs
index e2c71d2565c..f093dfddfae 100644
--- a/crates/ide-assists/src/tests/generated.rs
+++ b/crates/ide-assists/src/tests/generated.rs
@@ -495,6 +495,31 @@ impl Point {
 }
 
 #[test]
+fn doctest_convert_nested_function_to_closure() {
+    check_doc_test(
+        "convert_nested_function_to_closure",
+        r#####"
+fn main() {
+    fn fo$0o(label: &str, number: u64) {
+        println!("{}: {}", label, number);
+    }
+
+    foo("Bar", 100);
+}
+"#####,
+        r#####"
+fn main() {
+    let foo = |label: &str, number: u64| {
+        println!("{}: {}", label, number);
+    };
+
+    foo("Bar", 100);
+}
+"#####,
+    )
+}
+
+#[test]
 fn doctest_convert_to_guarded_return() {
     check_doc_test(
         "convert_to_guarded_return",