about summary refs log tree commit diff
diff options
context:
space:
mode:
authorEthiraric <ethiraric@gmail.com>2022-03-13 21:56:34 +0100
committerEthiraric <ethiraric@gmail.com>2022-04-03 14:34:08 +0200
commit520557d45c15375f32cb838ed2e2030b86d11fd2 (patch)
treef36401858fec1952088a37042bc4baa4a5e3a395
parentf8a21e4c70bba61d63472c730f7e0b16324b5806 (diff)
downloadrust-520557d45c15375f32cb838ed2e2030b86d11fd2.tar.gz
rust-520557d45c15375f32cb838ed2e2030b86d11fd2.zip
feat: assist to remove unneeded `async`s
-rw-r--r--crates/ide_assists/src/handlers/unnecessary_async.rs257
-rw-r--r--crates/ide_assists/src/lib.rs2
-rw-r--r--crates/ide_assists/src/tests/generated.rs15
3 files changed, 274 insertions, 0 deletions
diff --git a/crates/ide_assists/src/handlers/unnecessary_async.rs b/crates/ide_assists/src/handlers/unnecessary_async.rs
new file mode 100644
index 00000000000..d90fee7809e
--- /dev/null
+++ b/crates/ide_assists/src/handlers/unnecessary_async.rs
@@ -0,0 +1,257 @@
+use ide_db::{
+    assists::{AssistId, AssistKind},
+    base_db::FileId,
+    defs::Definition,
+    search::FileReference,
+    syntax_helpers::node_ext::full_path_of_name_ref,
+};
+use syntax::{
+    ast::{self, NameLike, NameRef},
+    AstNode, SyntaxKind, TextRange,
+};
+
+use crate::{AssistContext, Assists};
+
+// Assist: unnecessary_async
+//
+// Removes the `async` mark from functions which have no `.await` in their body.
+// Looks for calls to the functions and removes the `.await` on the call site.
+//
+// ```
+// pub async f$0n foo() {}
+// pub async fn bar() { foo().await }
+// ```
+// ->
+// ```
+// pub fn foo() {}
+// pub async fn bar() { foo() }
+// ```
+pub(crate) fn unnecessary_async(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
+    let function: ast::Fn = ctx.find_node_at_offset()?;
+
+    // Do nothing if the cursor is not on the prototype. This is so that the check does not pollute
+    // when the user asks us for assists when in the middle of the function body.
+    // We consider the prototype to be anything that is before the body of the function.
+    let cursor_position = ctx.offset();
+    if cursor_position >= function.body()?.syntax().text_range().start() {
+        return None;
+    }
+    // Do nothing if the function isn't async.
+    if let None = function.async_token() {
+        return None;
+    }
+    // Do nothing if the function has an `await` expression in its body.
+    if function.body()?.syntax().descendants().find_map(ast::AwaitExpr::cast).is_some() {
+        return None;
+    }
+
+    // Remove the `async` keyword plus whitespace after it, if any.
+    let async_range = {
+        let async_token = function.async_token()?;
+        let next_token = async_token.next_token()?;
+        if matches!(next_token.kind(), SyntaxKind::WHITESPACE) {
+            TextRange::new(async_token.text_range().start(), next_token.text_range().end())
+        } else {
+            async_token.text_range()
+        }
+    };
+
+    // Otherwise, we may remove the `async` keyword.
+    acc.add(
+        AssistId("unnecessary_async", AssistKind::QuickFix),
+        "Remove unnecessary async",
+        async_range,
+        |edit| {
+            // Remove async on the function definition.
+            edit.replace(async_range, "");
+
+            // Remove all `.await`s from calls to the function we remove `async` from.
+            if let Some(fn_def) = ctx.sema.to_def(&function) {
+                for await_expr in find_all_references(ctx, &Definition::Function(fn_def))
+                    // Keep only references that correspond NameRefs.
+                    .filter_map(|(_, reference)| match reference.name {
+                        NameLike::NameRef(nameref) => Some(nameref),
+                        _ => None,
+                    })
+                    // Keep only references that correspond to await expressions
+                    .filter_map(|nameref| find_await_expression(ctx, &nameref))
+                {
+                    if let Some(await_token) = &await_expr.await_token() {
+                        edit.replace(await_token.text_range(), "");
+                    }
+                    if let Some(dot_token) = &await_expr.dot_token() {
+                        edit.replace(dot_token.text_range(), "");
+                    }
+                }
+            }
+        },
+    )
+}
+
+fn find_all_references(
+    ctx: &AssistContext,
+    def: &Definition,
+) -> impl Iterator<Item = (FileId, FileReference)> {
+    def.usages(&ctx.sema).all().into_iter().flat_map(|(file_id, references)| {
+        references.into_iter().map(move |reference| (file_id, reference))
+    })
+}
+
+/// Finds the await expression for the given `NameRef`.
+/// If no await expression is found, returns None.
+fn find_await_expression(ctx: &AssistContext, nameref: &NameRef) -> Option<ast::AwaitExpr> {
+    // From the nameref, walk up the tree to the await expression.
+    let await_expr = if let Some(path) = full_path_of_name_ref(&nameref) {
+        // Function calls.
+        path.syntax()
+            .parent()
+            .and_then(ast::PathExpr::cast)?
+            .syntax()
+            .parent()
+            .and_then(ast::CallExpr::cast)?
+            .syntax()
+            .parent()
+            .and_then(ast::AwaitExpr::cast)
+    } else {
+        // Method calls.
+        nameref
+            .syntax()
+            .parent()
+            .and_then(ast::MethodCallExpr::cast)?
+            .syntax()
+            .parent()
+            .and_then(ast::AwaitExpr::cast)
+    };
+
+    ctx.sema.original_ast_node(await_expr?)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use crate::tests::{check_assist, check_assist_not_applicable};
+
+    #[test]
+    fn applies_on_empty_function() {
+        check_assist(unnecessary_async, "pub async f$0n f() {}", "pub fn f() {}")
+    }
+
+    #[test]
+    fn applies_and_removes_whitespace() {
+        check_assist(unnecessary_async, "pub async       f$0n f() {}", "pub fn f() {}")
+    }
+
+    #[test]
+    fn does_not_apply_on_non_async_function() {
+        check_assist_not_applicable(unnecessary_async, "pub f$0n f() {}")
+    }
+
+    #[test]
+    fn applies_on_function_with_a_non_await_expr() {
+        check_assist(unnecessary_async, "pub async f$0n f() { f2() }", "pub fn f() { f2() }")
+    }
+
+    #[test]
+    fn does_not_apply_on_function_with_an_await_expr() {
+        check_assist_not_applicable(unnecessary_async, "pub async f$0n f() { f2().await }")
+    }
+
+    #[test]
+    fn applies_and_removes_await_on_reference() {
+        check_assist(
+            unnecessary_async,
+            r#"
+pub async fn f4() { }
+pub async f$0n f2() { }
+pub async fn f() { f2().await }
+pub async fn f3() { f2().await }"#,
+            r#"
+pub async fn f4() { }
+pub fn f2() { }
+pub async fn f() { f2() }
+pub async fn f3() { f2() }"#,
+        )
+    }
+
+    #[test]
+    fn applies_and_removes_await_from_within_module() {
+        check_assist(
+            unnecessary_async,
+            r#"
+pub async fn f4() { }
+mod a { pub async f$0n f2() { } }
+pub async fn f() { a::f2().await }
+pub async fn f3() { a::f2().await }"#,
+            r#"
+pub async fn f4() { }
+mod a { pub fn f2() { } }
+pub async fn f() { a::f2() }
+pub async fn f3() { a::f2() }"#,
+        )
+    }
+
+    #[test]
+    fn applies_and_removes_await_on_inner_await() {
+        check_assist(
+            unnecessary_async,
+            // Ensure that it is the first await on the 3rd line that is removed
+            r#"
+pub async fn f() { f2().await }
+pub async f$0n f2() -> i32 { 1 }
+pub async fn f3() { f4(f2().await).await }
+pub async fn f4(i: i32) { }"#,
+            r#"
+pub async fn f() { f2() }
+pub fn f2() -> i32 { 1 }
+pub async fn f3() { f4(f2()).await }
+pub async fn f4(i: i32) { }"#,
+        )
+    }
+
+    #[test]
+    fn applies_and_removes_await_on_outer_await() {
+        check_assist(
+            unnecessary_async,
+            // Ensure that it is the second await on the 3rd line that is removed
+            r#"
+pub async fn f() { f2().await }
+pub async f$0n f2(i: i32) { }
+pub async fn f3() { f2(f4().await).await }
+pub async fn f4() -> i32 { 1 }"#,
+            r#"
+pub async fn f() { f2() }
+pub fn f2(i: i32) { }
+pub async fn f3() { f2(f4().await) }
+pub async fn f4() -> i32 { 1 }"#,
+        )
+    }
+
+    #[test]
+    fn applies_on_method_call() {
+        check_assist(
+            unnecessary_async,
+            r#"
+pub struct S { }
+impl S { pub async f$0n f2(&self) { } }
+pub async fn f(s: &S) { s.f2().await }"#,
+            r#"
+pub struct S { }
+impl S { pub fn f2(&self) { } }
+pub async fn f(s: &S) { s.f2() }"#,
+        )
+    }
+
+    #[test]
+    fn does_not_apply_on_function_with_a_nested_await_expr() {
+        check_assist_not_applicable(
+            unnecessary_async,
+            "async f$0n f() { if true { loop { f2().await } } }",
+        )
+    }
+
+    #[test]
+    fn does_not_apply_when_not_on_prototype() {
+        check_assist_not_applicable(unnecessary_async, "pub async fn f() { $0f2() }")
+    }
+}
diff --git a/crates/ide_assists/src/lib.rs b/crates/ide_assists/src/lib.rs
index 6eff8871e8a..ef4aa1c62bd 100644
--- a/crates/ide_assists/src/lib.rs
+++ b/crates/ide_assists/src/lib.rs
@@ -183,6 +183,7 @@ mod handlers {
     mod sort_items;
     mod toggle_ignore;
     mod unmerge_use;
+    mod unnecessary_async;
     mod unwrap_block;
     mod unwrap_result_return_type;
     mod wrap_return_type_in_result;
@@ -268,6 +269,7 @@ mod handlers {
             split_import::split_import,
             toggle_ignore::toggle_ignore,
             unmerge_use::unmerge_use,
+            unnecessary_async::unnecessary_async,
             unwrap_block::unwrap_block,
             unwrap_result_return_type::unwrap_result_return_type,
             wrap_return_type_in_result::wrap_return_type_in_result,
diff --git a/crates/ide_assists/src/tests/generated.rs b/crates/ide_assists/src/tests/generated.rs
index 282374b3cf3..8a1e95d8947 100644
--- a/crates/ide_assists/src/tests/generated.rs
+++ b/crates/ide_assists/src/tests/generated.rs
@@ -2107,6 +2107,21 @@ use std::fmt::Display;
 }
 
 #[test]
+fn doctest_unnecessary_async() {
+    check_doc_test(
+        "unnecessary_async",
+        r#####"
+pub async f$0n foo() {}
+pub async fn bar() { foo().await }
+"#####,
+        r#####"
+pub fn foo() {}
+pub async fn bar() { foo() }
+"#####,
+    )
+}
+
+#[test]
 fn doctest_unwrap_block() {
     check_doc_test(
         "unwrap_block",