about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--clippy_utils/src/lib.rs31
1 files changed, 20 insertions, 11 deletions
diff --git a/clippy_utils/src/lib.rs b/clippy_utils/src/lib.rs
index 8d22d434ebc..0f4a42cec0c 100644
--- a/clippy_utils/src/lib.rs
+++ b/clippy_utils/src/lib.rs
@@ -106,10 +106,10 @@ use rustc_hir::hir_id::{HirIdMap, HirIdSet};
 use rustc_hir::intravisit::{FnKind, Visitor, walk_expr};
 use rustc_hir::{
     self as hir, Arm, BindingMode, Block, BlockCheckMode, Body, ByRef, Closure, ConstArgKind, ConstContext,
-    Destination, Expr, ExprField, ExprKind, FnDecl, FnRetTy, GenericArg, GenericArgs, HirId, Impl, ImplItem,
-    ImplItemKind, ImplItemRef, Item, ItemKind, LangItem, LetStmt, MatchSource, Mutability, Node, OwnerId, OwnerNode,
-    Param, Pat, PatExpr, PatExprKind, PatKind, Path, PathSegment, PrimTy, QPath, Stmt, StmtKind, TraitFn, TraitItem,
-    TraitItemKind, TraitItemRef, TraitRef, TyKind, UnOp, def,
+    CoroutineDesugaring, CoroutineKind, Destination, Expr, ExprField, ExprKind, FnDecl, FnRetTy, GenericArg,
+    GenericArgs, HirId, Impl, ImplItem, ImplItemKind, ImplItemRef, Item, ItemKind, LangItem, LetStmt, MatchSource,
+    Mutability, Node, OwnerId, OwnerNode, Param, Pat, PatExpr, PatExprKind, PatKind, Path, PathSegment, PrimTy, QPath,
+    Stmt, StmtKind, TraitFn, TraitItem, TraitItemKind, TraitItemRef, TraitRef, TyKind, UnOp, def,
 };
 use rustc_lexer::{TokenKind, tokenize};
 use rustc_lint::{LateContext, Level, Lint, LintContext};
@@ -2134,15 +2134,18 @@ pub fn is_async_fn(kind: FnKind<'_>) -> bool {
     }
 }
 
-/// Peels away all the compiler generated code surrounding the body of an async function,
-pub fn get_async_fn_body<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'_>) -> Option<&'tcx Expr<'tcx>> {
-    if let ExprKind::Closure(&Closure { body, .. }) = body.value.kind
+/// Peels away all the compiler generated code surrounding the body of an async closure.
+pub fn get_async_closure_expr<'tcx>(tcx: TyCtxt<'tcx>, expr: &Expr<'_>) -> Option<&'tcx Expr<'tcx>> {
+    if let ExprKind::Closure(&Closure {
+        body,
+        kind: hir::ClosureKind::Coroutine(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)),
+        ..
+    }) = expr.kind
         && let ExprKind::Block(
             Block {
-                stmts: [],
                 expr:
                     Some(Expr {
-                        kind: ExprKind::DropTemps(expr),
+                        kind: ExprKind::DropTemps(inner_expr),
                         ..
                     }),
                 ..
@@ -2150,9 +2153,15 @@ pub fn get_async_fn_body<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'_>) -> Option<&'t
             _,
         ) = tcx.hir_body(body).value.kind
     {
-        return Some(expr);
+        Some(inner_expr)
+    } else {
+        None
     }
-    None
+}
+
+/// Peels away all the compiler generated code surrounding the body of an async function,
+pub fn get_async_fn_body<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'_>) -> Option<&'tcx Expr<'tcx>> {
+    get_async_closure_expr(tcx, body.value)
 }
 
 // check if expr is calling method or function with #[must_use] attribute