about summary refs log tree commit diff
diff options
context:
space:
mode:
authorGuillaume Gomez <guillaume1.gomez@gmail.com>2023-07-18 21:45:28 +0200
committerGuillaume Gomez <guillaume1.gomez@gmail.com>2023-07-18 22:21:02 +0200
commit8b0540bb46467cb4af35eed437994753957a43c6 (patch)
treeb9d15eff2ec47988818e3a0ed7d5b7a930c51860
parent9f0cbfd7df87bb95585774ffbf44329940323535 (diff)
downloadrust-8b0540bb46467cb4af35eed437994753957a43c6.tar.gz
rust-8b0540bb46467cb4af35eed437994753957a43c6.zip
Fix async functions handling for `needless_pass_by_ref_mut` lint
-rw-r--r--clippy_lints/src/needless_pass_by_ref_mut.rs89
1 files changed, 67 insertions, 22 deletions
diff --git a/clippy_lints/src/needless_pass_by_ref_mut.rs b/clippy_lints/src/needless_pass_by_ref_mut.rs
index cdb7ee48305..c399b43e251 100644
--- a/clippy_lints/src/needless_pass_by_ref_mut.rs
+++ b/clippy_lints/src/needless_pass_by_ref_mut.rs
@@ -3,14 +3,16 @@ use clippy_utils::diagnostics::span_lint_and_then;
 use clippy_utils::source::snippet;
 use clippy_utils::{is_from_proc_macro, is_self};
 use if_chain::if_chain;
+use rustc_data_structures::fx::FxHashSet;
 use rustc_errors::Applicability;
 use rustc_hir::intravisit::FnKind;
 use rustc_hir::{Body, FnDecl, HirId, HirIdMap, HirIdSet, Impl, ItemKind, Mutability, Node, PatKind};
 use rustc_hir_typeck::expr_use_visitor as euv;
 use rustc_infer::infer::TyCtxtInferExt;
 use rustc_lint::{LateContext, LateLintPass};
+use rustc_middle::hir::map::associated_body;
 use rustc_middle::mir::FakeReadCause;
-use rustc_middle::ty::{self, Ty};
+use rustc_middle::ty::{self, Ty, UpvarId, UpvarPath};
 use rustc_session::{declare_tool_lint, impl_lint_pass};
 use rustc_span::def_id::LocalDefId;
 use rustc_span::symbol::kw;
@@ -102,17 +104,17 @@ impl<'tcx> LateLintPass<'tcx> for NeedlessPassByRefMut {
         }
 
         let hir_id = cx.tcx.hir().local_def_id_to_hir_id(fn_def_id);
-
-        match kind {
+        let is_async = match kind {
             FnKind::ItemFn(.., header) => {
                 let attrs = cx.tcx.hir().attrs(hir_id);
                 if header.abi != Abi::Rust || requires_exact_signature(attrs) {
                     return;
                 }
+                header.is_async()
             },
-            FnKind::Method(..) => (),
+            FnKind::Method(.., sig) => sig.header.is_async(),
             FnKind::Closure => return,
-        }
+        };
 
         // Exclude non-inherent impls
         if let Some(Node::Item(item)) = cx.tcx.hir().find_parent(hir_id) {
@@ -128,13 +130,14 @@ impl<'tcx> LateLintPass<'tcx> for NeedlessPassByRefMut {
         let fn_sig = cx.tcx.liberate_late_bound_regions(fn_def_id.to_def_id(), fn_sig);
 
         // If there are no `&mut` argument, no need to go any further.
-        if !decl
+        let mut it = decl
             .inputs
             .iter()
             .zip(fn_sig.inputs())
             .zip(body.params)
-            .any(|((&input, &ty), arg)| !should_skip(cx, input, ty, arg))
-        {
+            .filter(|((&input, &ty), arg)| !should_skip(cx, input, ty, arg))
+            .peekable();
+        if it.peek().is_none() {
             return;
         }
 
@@ -144,19 +147,25 @@ impl<'tcx> LateLintPass<'tcx> for NeedlessPassByRefMut {
             let mut ctx = MutablyUsedVariablesCtxt::default();
             let infcx = cx.tcx.infer_ctxt().build();
             euv::ExprUseVisitor::new(&mut ctx, &infcx, fn_def_id, cx.param_env, cx.typeck_results()).consume_body(body);
+            if is_async {
+                let closures = ctx.async_closures.clone();
+                let hir = cx.tcx.hir();
+                for closure in closures {
+                    ctx.prev_bind = None;
+                    ctx.prev_move_to_closure.clear();
+                    if let Some(body) = hir
+                        .find_by_def_id(closure)
+                        .and_then(associated_body)
+                        .map(|(_, body_id)| hir.body(body_id))
+                    {
+                        euv::ExprUseVisitor::new(&mut ctx, &infcx, closure, cx.param_env, cx.typeck_results())
+                            .consume_body(body);
+                    }
+                }
+            }
             ctx
         };
 
-        let mut it = decl
-            .inputs
-            .iter()
-            .zip(fn_sig.inputs())
-            .zip(body.params)
-            .filter(|((&input, &ty), arg)| !should_skip(cx, input, ty, arg))
-            .peekable();
-        if it.peek().is_none() {
-            return;
-        }
         let show_semver_warning = self.avoid_breaking_exported_api && cx.effective_visibilities.is_exported(fn_def_id);
         for ((&input, &_), arg) in it {
             // Only take `&mut` arguments.
@@ -197,7 +206,9 @@ impl<'tcx> LateLintPass<'tcx> for NeedlessPassByRefMut {
 struct MutablyUsedVariablesCtxt {
     mutably_used_vars: HirIdSet,
     prev_bind: Option<HirId>,
+    prev_move_to_closure: HirIdSet,
     aliases: HirIdMap<HirId>,
+    async_closures: FxHashSet<LocalDefId>,
 }
 
 impl MutablyUsedVariablesCtxt {
@@ -213,16 +224,27 @@ impl MutablyUsedVariablesCtxt {
 impl<'tcx> euv::Delegate<'tcx> for MutablyUsedVariablesCtxt {
     fn consume(&mut self, cmt: &euv::PlaceWithHirId<'tcx>, _id: HirId) {
         if let euv::Place {
-            base: euv::PlaceBase::Local(vid),
+            base:
+                euv::PlaceBase::Local(vid)
+                | euv::PlaceBase::Upvar(UpvarId {
+                    var_path: UpvarPath { hir_id: vid },
+                    ..
+                }),
             base_ty,
             ..
         } = &cmt.place
         {
             if let Some(bind_id) = self.prev_bind.take() {
-                self.aliases.insert(bind_id, *vid);
-            } else if matches!(base_ty.ref_mutability(), Some(Mutability::Mut)) {
+                if bind_id != *vid {
+                    self.aliases.insert(bind_id, *vid);
+                }
+            } else if !self.prev_move_to_closure.contains(vid)
+                && matches!(base_ty.ref_mutability(), Some(Mutability::Mut))
+            {
                 self.add_mutably_used_var(*vid);
             }
+            self.prev_bind = None;
+            self.prev_move_to_closure.remove(vid);
         }
     }
 
@@ -265,7 +287,30 @@ impl<'tcx> euv::Delegate<'tcx> for MutablyUsedVariablesCtxt {
         self.prev_bind = None;
     }
 
-    fn fake_read(&mut self, _: &rustc_hir_typeck::expr_use_visitor::PlaceWithHirId<'tcx>, _: FakeReadCause, _: HirId) {}
+    fn fake_read(
+        &mut self,
+        cmt: &rustc_hir_typeck::expr_use_visitor::PlaceWithHirId<'tcx>,
+        cause: FakeReadCause,
+        _id: HirId,
+    ) {
+        if let euv::Place {
+            base:
+                euv::PlaceBase::Upvar(UpvarId {
+                    var_path: UpvarPath { hir_id: vid },
+                    ..
+                }),
+            ..
+        } = &cmt.place
+        {
+            if let FakeReadCause::ForLet(Some(inner)) = cause {
+                // Seems like we are inside an async function. We need to store the closure `DefId`
+                // to go through it afterwards.
+                self.async_closures.insert(inner);
+                self.aliases.insert(cmt.hir_id, *vid);
+                self.prev_move_to_closure.insert(*vid);
+            }
+        }
+    }
 
     fn bind(&mut self, _cmt: &euv::PlaceWithHirId<'tcx>, id: HirId) {
         self.prev_bind = Some(id);