about summary refs log tree commit diff
path: root/clippy_utils/src/ast_utils.rs
diff options
context:
space:
mode:
Diffstat (limited to 'clippy_utils/src/ast_utils.rs')
-rw-r--r--clippy_utils/src/ast_utils.rs23
1 files changed, 18 insertions, 5 deletions
diff --git a/clippy_utils/src/ast_utils.rs b/clippy_utils/src/ast_utils.rs
index a2c61e07b70..e36f2fa87a7 100644
--- a/clippy_utils/src/ast_utils.rs
+++ b/clippy_utils/src/ast_utils.rs
@@ -188,7 +188,7 @@ pub fn eq_expr(l: &Expr, r: &Expr) -> bool {
             Closure(box ast::Closure {
                 binder: lb,
                 capture_clause: lc,
-                asyncness: la,
+                coroutine_kind: la,
                 movability: lm,
                 fn_decl: lf,
                 body: le,
@@ -197,7 +197,7 @@ pub fn eq_expr(l: &Expr, r: &Expr) -> bool {
             Closure(box ast::Closure {
                 binder: rb,
                 capture_clause: rc,
-                asyncness: ra,
+                coroutine_kind: ra,
                 movability: rm,
                 fn_decl: rf,
                 body: re,
@@ -206,7 +206,7 @@ pub fn eq_expr(l: &Expr, r: &Expr) -> bool {
         ) => {
             eq_closure_binder(lb, rb)
                 && lc == rc
-                && la.is_async() == ra.is_async()
+                && la.map_or(false, CoroutineKind::is_async) == ra.map_or(false, CoroutineKind::is_async)
                 && lm == rm
                 && eq_fn_decl(lf, rf)
                 && eq_expr(le, re)
@@ -236,7 +236,7 @@ pub fn eq_field(l: &ExprField, r: &ExprField) -> bool {
 pub fn eq_arm(l: &Arm, r: &Arm) -> bool {
     l.is_placeholder == r.is_placeholder
         && eq_pat(&l.pat, &r.pat)
-        && eq_expr(&l.body, &r.body)
+        && eq_expr_opt(&l.body, &r.body)
         && eq_expr_opt(&l.guard, &r.guard)
         && over(&l.attrs, &r.attrs, eq_attr)
 }
@@ -563,9 +563,22 @@ pub fn eq_fn_sig(l: &FnSig, r: &FnSig) -> bool {
     eq_fn_decl(&l.decl, &r.decl) && eq_fn_header(&l.header, &r.header)
 }
 
+fn eq_opt_coroutine_kind(l: Option<CoroutineKind>, r: Option<CoroutineKind>) -> bool {
+    matches!(
+        (l, r),
+        (Some(CoroutineKind::Async { .. }), Some(CoroutineKind::Async { .. }))
+            | (Some(CoroutineKind::Gen { .. }), Some(CoroutineKind::Gen { .. }))
+            | (
+                Some(CoroutineKind::AsyncGen { .. }),
+                Some(CoroutineKind::AsyncGen { .. })
+            )
+            | (None, None)
+    )
+}
+
 pub fn eq_fn_header(l: &FnHeader, r: &FnHeader) -> bool {
     matches!(l.unsafety, Unsafe::No) == matches!(r.unsafety, Unsafe::No)
-        && l.asyncness.is_async() == r.asyncness.is_async()
+        && eq_opt_coroutine_kind(l.coroutine_kind, r.coroutine_kind)
         && matches!(l.constness, Const::No) == matches!(r.constness, Const::No)
         && eq_ext(&l.ext, &r.ext)
 }