about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorShoyu Vanilla <modulo641@gmail.com>2024-08-05 03:07:24 +0900
committerShoyu Vanilla <modulo641@gmail.com>2024-08-05 21:47:57 +0900
commit3e809f8f1f80d6d7e02fb43ab2da720c7dd34873 (patch)
treec0a2d6b394d4b488ff9c81f1505364cbe7854f7a /src
parent6da26c15580d5c41ee2bb29ae4372b8c52f96174 (diff)
downloadrust-3e809f8f1f80d6d7e02fb43ab2da720c7dd34873.tar.gz
rust-3e809f8f1f80d6d7e02fb43ab2da720c7dd34873.zip
feat: Implement diagnostic for `await` outside of `async`
Diffstat (limited to 'src')
-rw-r--r--src/tools/rust-analyzer/crates/hir-def/src/body.rs1
-rw-r--r--src/tools/rust-analyzer/crates/hir-def/src/body/lower.rs90
-rw-r--r--src/tools/rust-analyzer/crates/hir/src/diagnostics.rs7
-rw-r--r--src/tools/rust-analyzer/crates/hir/src/lib.rs3
-rw-r--r--src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/await_outside_of_async.rs101
-rw-r--r--src/tools/rust-analyzer/crates/ide-diagnostics/src/lib.rs2
6 files changed, 187 insertions, 17 deletions
diff --git a/src/tools/rust-analyzer/crates/hir-def/src/body.rs b/src/tools/rust-analyzer/crates/hir-def/src/body.rs
index d3c134f3266..58812479ddf 100644
--- a/src/tools/rust-analyzer/crates/hir-def/src/body.rs
+++ b/src/tools/rust-analyzer/crates/hir-def/src/body.rs
@@ -118,6 +118,7 @@ pub enum BodyDiagnostic {
     MacroError { node: InFile<AstPtr<ast::MacroCall>>, err: ExpandError },
     UnresolvedMacroCall { node: InFile<AstPtr<ast::MacroCall>>, path: ModPath },
     UnreachableLabel { node: InFile<AstPtr<ast::Lifetime>>, name: Name },
+    AwaitOutsideOfAsync { node: InFile<AstPtr<ast::AwaitExpr>>, location: String },
     UndeclaredLabel { node: InFile<AstPtr<ast::Lifetime>>, name: Name },
 }
 
diff --git a/src/tools/rust-analyzer/crates/hir-def/src/body/lower.rs b/src/tools/rust-analyzer/crates/hir-def/src/body/lower.rs
index 9e30aff8fe9..abf78958292 100644
--- a/src/tools/rust-analyzer/crates/hir-def/src/body/lower.rs
+++ b/src/tools/rust-analyzer/crates/hir-def/src/body/lower.rs
@@ -72,6 +72,7 @@ pub(super) fn lower(
         is_lowering_coroutine: false,
         label_ribs: Vec::new(),
         current_binding_owner: None,
+        awaitable_context: None,
     }
     .collect(params, body, is_async_fn)
 }
@@ -100,6 +101,8 @@ struct ExprCollector<'a> {
     // resolution
     label_ribs: Vec<LabelRib>,
     current_binding_owner: Option<ExprId>,
+
+    awaitable_context: Option<Awaitable>,
 }
 
 #[derive(Clone, Debug)]
@@ -135,6 +138,11 @@ impl RibKind {
     }
 }
 
+enum Awaitable {
+    Yes,
+    No(&'static str),
+}
+
 #[derive(Debug, Default)]
 struct BindingList {
     map: FxHashMap<Name, BindingId>,
@@ -180,6 +188,18 @@ impl ExprCollector<'_> {
         body: Option<ast::Expr>,
         is_async_fn: bool,
     ) -> (Body, BodySourceMap) {
+        self.awaitable_context.replace(if is_async_fn {
+            Awaitable::Yes
+        } else {
+            match self.owner {
+                DefWithBodyId::FunctionId(..) => Awaitable::No("non-async function"),
+                DefWithBodyId::StaticId(..) => Awaitable::No("static"),
+                DefWithBodyId::ConstId(..) | DefWithBodyId::InTypeConstId(..) => {
+                    Awaitable::No("constant")
+                }
+                DefWithBodyId::VariantId(..) => Awaitable::No("enum variant"),
+            }
+        });
         if let Some((param_list, mut attr_enabled)) = param_list {
             let mut params = vec![];
             if let Some(self_param) =
@@ -280,31 +300,40 @@ impl ExprCollector<'_> {
                 }
                 Some(ast::BlockModifier::Async(_)) => {
                     self.with_label_rib(RibKind::Closure, |this| {
-                        this.collect_block_(e, |id, statements, tail| Expr::Async {
-                            id,
-                            statements,
-                            tail,
+                        this.with_awaitable_block(Awaitable::Yes, |this| {
+                            this.collect_block_(e, |id, statements, tail| Expr::Async {
+                                id,
+                                statements,
+                                tail,
+                            })
                         })
                     })
                 }
                 Some(ast::BlockModifier::Const(_)) => {
                     self.with_label_rib(RibKind::Constant, |this| {
-                        let (result_expr_id, prev_binding_owner) =
-                            this.initialize_binding_owner(syntax_ptr);
-                        let inner_expr = this.collect_block(e);
-                        let it = this.db.intern_anonymous_const(ConstBlockLoc {
-                            parent: this.owner,
-                            root: inner_expr,
-                        });
-                        this.body.exprs[result_expr_id] = Expr::Const(it);
-                        this.current_binding_owner = prev_binding_owner;
-                        result_expr_id
+                        this.with_awaitable_block(Awaitable::No("constant block"), |this| {
+                            let (result_expr_id, prev_binding_owner) =
+                                this.initialize_binding_owner(syntax_ptr);
+                            let inner_expr = this.collect_block(e);
+                            let it = this.db.intern_anonymous_const(ConstBlockLoc {
+                                parent: this.owner,
+                                root: inner_expr,
+                            });
+                            this.body.exprs[result_expr_id] = Expr::Const(it);
+                            this.current_binding_owner = prev_binding_owner;
+                            result_expr_id
+                        })
                     })
                 }
                 // FIXME
-                Some(ast::BlockModifier::AsyncGen(_)) | Some(ast::BlockModifier::Gen(_)) | None => {
-                    self.collect_block(e)
+                Some(ast::BlockModifier::AsyncGen(_)) => {
+                    self.with_awaitable_block(Awaitable::Yes, |this| this.collect_block(e))
                 }
+                Some(ast::BlockModifier::Gen(_)) => self
+                    .with_awaitable_block(Awaitable::No("non-async gen block"), |this| {
+                        this.collect_block(e)
+                    }),
+                None => self.collect_block(e),
             },
             ast::Expr::LoopExpr(e) => {
                 let label = e.label().map(|label| self.collect_label(label));
@@ -469,6 +498,12 @@ impl ExprCollector<'_> {
             }
             ast::Expr::AwaitExpr(e) => {
                 let expr = self.collect_expr_opt(e.expr());
+                if let Awaitable::No(location) = self.is_lowering_awaitable_block() {
+                    self.source_map.diagnostics.push(BodyDiagnostic::AwaitOutsideOfAsync {
+                        node: InFile::new(self.expander.current_file_id(), AstPtr::new(&e)),
+                        location: location.to_string(),
+                    });
+                }
                 self.alloc_expr(Expr::Await { expr }, syntax_ptr)
             }
             ast::Expr::TryExpr(e) => self.collect_try_operator(syntax_ptr, e),
@@ -527,7 +562,13 @@ impl ExprCollector<'_> {
                 let prev_is_lowering_coroutine = mem::take(&mut this.is_lowering_coroutine);
                 let prev_try_block_label = this.current_try_block_label.take();
 
-                let body = this.collect_expr_opt(e.body());
+                let awaitable = if e.async_token().is_some() {
+                    Awaitable::Yes
+                } else {
+                    Awaitable::No("non-async closure")
+                };
+                let body =
+                    this.with_awaitable_block(awaitable, |this| this.collect_expr_opt(e.body()));
 
                 let closure_kind = if this.is_lowering_coroutine {
                     let movability = if e.static_token().is_some() {
@@ -2082,6 +2123,21 @@ impl ExprCollector<'_> {
     fn alloc_label_desugared(&mut self, label: Label) -> LabelId {
         self.body.labels.alloc(label)
     }
+
+    fn is_lowering_awaitable_block(&self) -> &Awaitable {
+        self.awaitable_context.as_ref().unwrap_or(&Awaitable::No("unknown"))
+    }
+
+    fn with_awaitable_block<T>(
+        &mut self,
+        awaitable: Awaitable,
+        f: impl FnOnce(&mut Self) -> T,
+    ) -> T {
+        let orig = self.awaitable_context.replace(awaitable);
+        let res = f(self);
+        self.awaitable_context = orig;
+        res
+    }
 }
 
 fn comma_follows_token(t: Option<syntax::SyntaxToken>) -> bool {
diff --git a/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs b/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs
index 4bb8c140a1f..ffb972475f8 100644
--- a/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs
+++ b/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs
@@ -48,6 +48,7 @@ macro_rules! diagnostics {
 // ]
 
 diagnostics![
+    AwaitOutsideOfAsync,
     BreakOutsideOfLoop,
     ExpectedFunction,
     InactiveCode,
@@ -135,6 +136,12 @@ pub struct UnreachableLabel {
     pub name: Name,
 }
 
+#[derive(Debug)]
+pub struct AwaitOutsideOfAsync {
+    pub node: InFile<AstPtr<ast::AwaitExpr>>,
+    pub location: String,
+}
+
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub struct UndeclaredLabel {
     pub node: InFile<AstPtr<ast::Lifetime>>,
diff --git a/src/tools/rust-analyzer/crates/hir/src/lib.rs b/src/tools/rust-analyzer/crates/hir/src/lib.rs
index 266ef2a55c5..1c5e4ce4b53 100644
--- a/src/tools/rust-analyzer/crates/hir/src/lib.rs
+++ b/src/tools/rust-analyzer/crates/hir/src/lib.rs
@@ -1828,6 +1828,9 @@ impl DefWithBody {
                     is_bang: true,
                 }
                 .into(),
+                BodyDiagnostic::AwaitOutsideOfAsync { node, location } => {
+                    AwaitOutsideOfAsync { node: *node, location: location.clone() }.into()
+                }
                 BodyDiagnostic::UnreachableLabel { node, name } => {
                     UnreachableLabel { node: *node, name: name.clone() }.into()
                 }
diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/await_outside_of_async.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/await_outside_of_async.rs
new file mode 100644
index 00000000000..92b6e748ca5
--- /dev/null
+++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/await_outside_of_async.rs
@@ -0,0 +1,101 @@
+use crate::{adjusted_display_range, Diagnostic, DiagnosticsContext};
+
+// Diagnostic: await-outside-of-async
+//
+// This diagnostic is triggered if the `await` keyword is used outside of an async function or block
+pub(crate) fn await_outside_of_async(
+    ctx: &DiagnosticsContext<'_>,
+    d: &hir::AwaitOutsideOfAsync,
+) -> Diagnostic {
+    let display_range =
+        adjusted_display_range(ctx, d.node, &|node| Some(node.await_token()?.text_range()));
+    Diagnostic::new(
+        crate::DiagnosticCode::RustcHardError("E0728"),
+        format!("`await` is used inside {}, which is not an `async` context", d.location),
+        display_range,
+    )
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::tests::check_diagnostics;
+
+    #[test]
+    fn await_inside_non_async_fn() {
+        check_diagnostics(
+            r#"
+async fn foo() {}
+
+fn bar() {
+    foo().await;
+        //^^^^^ error: `await` is used inside non-async function, which is not an `async` context
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn await_inside_async_fn() {
+        check_diagnostics(
+            r#"
+async fn foo() {}
+
+async fn bar() {
+    foo().await;
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn await_inside_closure() {
+        check_diagnostics(
+            r#"
+async fn foo() {}
+
+async fn bar() {
+    let _a = || { foo().await };
+                      //^^^^^ error: `await` is used inside non-async closure, which is not an `async` context
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn await_inside_async_block() {
+        check_diagnostics(
+            r#"
+async fn foo() {}
+
+fn bar() {
+    let _a = async { foo().await };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn await_in_complex_context() {
+        check_diagnostics(
+            r#"
+async fn foo() {}
+
+fn bar() {
+    async fn baz() {
+        let a = foo().await;
+    }
+
+    let x = || {
+        let y = async {
+            baz().await;
+            let z = || {
+                baz().await;
+                    //^^^^^ error: `await` is used inside non-async closure, which is not an `async` context
+            };
+        };
+    };
+}
+"#,
+        );
+    }
+}
diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/lib.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/lib.rs
index 263ab747559..7d9589dab22 100644
--- a/src/tools/rust-analyzer/crates/ide-diagnostics/src/lib.rs
+++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/lib.rs
@@ -24,6 +24,7 @@
 //! don't yet have a great pattern for how to do them properly.
 
 mod handlers {
+    pub(crate) mod await_outside_of_async;
     pub(crate) mod break_outside_of_loop;
     pub(crate) mod expected_function;
     pub(crate) mod inactive_code;
@@ -348,6 +349,7 @@ pub fn diagnostics(
 
     for diag in diags {
         let d = match diag {
+            AnyDiagnostic::AwaitOutsideOfAsync(d) => handlers::await_outside_of_async::await_outside_of_async(&ctx, &d),
             AnyDiagnostic::ExpectedFunction(d) => handlers::expected_function::expected_function(&ctx, &d),
             AnyDiagnostic::InactiveCode(d) => match handlers::inactive_code::inactive_code(&ctx, &d) {
                 Some(it) => it,