about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--clippy_lints/src/copies.rs64
-rw-r--r--clippy_utils/src/lib.rs2
-rw-r--r--clippy_utils/src/visitors.rs70
-rw-r--r--tests/ui/branches_sharing_code/false_positives.rs54
4 files changed, 182 insertions, 8 deletions
diff --git a/clippy_lints/src/copies.rs b/clippy_lints/src/copies.rs
index 1deff9684a1..0e3d9317590 100644
--- a/clippy_lints/src/copies.rs
+++ b/clippy_lints/src/copies.rs
@@ -1,13 +1,16 @@
 use clippy_utils::diagnostics::{span_lint_and_note, span_lint_and_then};
 use clippy_utils::source::{first_line_of_span, indent_of, reindent_multiline, snippet, snippet_opt};
+use clippy_utils::ty::needs_ordered_drop;
+use clippy_utils::visitors::for_each_expr;
 use clippy_utils::{
-    eq_expr_value, get_enclosing_block, hash_expr, hash_stmt, if_sequence, is_else_clause, is_lint_allowed,
-    search_same, ContainsName, HirEqInterExpr, SpanlessEq,
+    capture_local_usage, eq_expr_value, get_enclosing_block, hash_expr, hash_stmt, if_sequence, is_else_clause,
+    is_lint_allowed, path_to_local, search_same, ContainsName, HirEqInterExpr, SpanlessEq,
 };
 use core::iter;
+use core::ops::ControlFlow;
 use rustc_errors::Applicability;
 use rustc_hir::intravisit;
-use rustc_hir::{BinOpKind, Block, Expr, ExprKind, HirId, Stmt, StmtKind};
+use rustc_hir::{BinOpKind, Block, Expr, ExprKind, HirId, HirIdSet, Stmt, StmtKind};
 use rustc_lint::{LateContext, LateLintPass};
 use rustc_session::{declare_lint_pass, declare_tool_lint};
 use rustc_span::hygiene::walk_chain;
@@ -214,7 +217,7 @@ fn lint_if_same_then_else(cx: &LateContext<'_>, conds: &[&Expr<'_>], blocks: &[&
 fn lint_branches_sharing_code<'tcx>(
     cx: &LateContext<'tcx>,
     conds: &[&'tcx Expr<'_>],
-    blocks: &[&Block<'tcx>],
+    blocks: &[&'tcx Block<'_>],
     expr: &'tcx Expr<'_>,
 ) {
     // We only lint ifs with multiple blocks
@@ -340,6 +343,21 @@ fn eq_binding_names(s: &Stmt<'_>, names: &[(HirId, Symbol)]) -> bool {
     }
 }
 
+/// Checks if the statement modifies or moves any of the given locals.
+fn modifies_any_local<'tcx>(cx: &LateContext<'tcx>, s: &'tcx Stmt<'_>, locals: &HirIdSet) -> bool {
+    for_each_expr(s, |e| {
+        if let Some(id) = path_to_local(e)
+            && locals.contains(&id)
+            && !capture_local_usage(cx, e).is_imm_ref()
+        {
+            ControlFlow::Break(())
+        } else {
+            ControlFlow::Continue(())
+        }
+    })
+    .is_some()
+}
+
 /// Checks if the given statement should be considered equal to the statement in the same position
 /// for each block.
 fn eq_stmts(
@@ -365,18 +383,52 @@ fn eq_stmts(
         .all(|b| get_stmt(b).map_or(false, |s| eq.eq_stmt(s, stmt)))
 }
 
-fn scan_block_for_eq(cx: &LateContext<'_>, _conds: &[&Expr<'_>], block: &Block<'_>, blocks: &[&Block<'_>]) -> BlockEq {
+#[expect(clippy::too_many_lines)]
+fn scan_block_for_eq<'tcx>(
+    cx: &LateContext<'tcx>,
+    conds: &[&'tcx Expr<'_>],
+    block: &'tcx Block<'_>,
+    blocks: &[&'tcx Block<'_>],
+) -> BlockEq {
     let mut eq = SpanlessEq::new(cx);
     let mut eq = eq.inter_expr();
     let mut moved_locals = Vec::new();
 
+    let mut cond_locals = HirIdSet::default();
+    for &cond in conds {
+        let _: Option<!> = for_each_expr(cond, |e| {
+            if let Some(id) = path_to_local(e) {
+                cond_locals.insert(id);
+            }
+            ControlFlow::Continue(())
+        });
+    }
+
+    let mut local_needs_ordered_drop = false;
     let start_end_eq = block
         .stmts
         .iter()
         .enumerate()
-        .find(|&(i, stmt)| !eq_stmts(stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals))
+        .find(|&(i, stmt)| {
+            if let StmtKind::Local(l) = stmt.kind
+                && needs_ordered_drop(cx, cx.typeck_results().node_type(l.hir_id))
+            {
+                local_needs_ordered_drop = true;
+                return true;
+            }
+            modifies_any_local(cx, stmt, &cond_locals)
+                || !eq_stmts(stmt, blocks, |b| b.stmts.get(i), &mut eq, &mut moved_locals)
+        })
         .map_or(block.stmts.len(), |(i, _)| i);
 
+    if local_needs_ordered_drop {
+        return BlockEq {
+            start_end_eq,
+            end_begin_eq: None,
+            moved_locals,
+        };
+    }
+
     // Walk backwards through the final expression/statements so long as their hashes are equal. Note
     // `SpanlessHash` treats all local references as equal allowing locals declared earlier in the block
     // to match those in other blocks. e.g. If each block ends with the following the hash value will be
diff --git a/clippy_utils/src/lib.rs b/clippy_utils/src/lib.rs
index 0e739303683..fd0c6869929 100644
--- a/clippy_utils/src/lib.rs
+++ b/clippy_utils/src/lib.rs
@@ -890,7 +890,7 @@ pub fn capture_local_usage<'tcx>(cx: &LateContext<'tcx>, e: &Expr<'_>) -> Captur
             Node::Expr(e) => match e.kind {
                 ExprKind::AddrOf(_, mutability, _) => return CaptureKind::Ref(mutability),
                 ExprKind::Index(..) | ExprKind::Unary(UnOp::Deref, _) => capture = CaptureKind::Ref(Mutability::Not),
-                ExprKind::Assign(lhs, ..) | ExprKind::Assign(_, lhs, _) if lhs.hir_id == child_id => {
+                ExprKind::Assign(lhs, ..) | ExprKind::AssignOp(_, lhs, _) if lhs.hir_id == child_id => {
                     return CaptureKind::Ref(Mutability::Mut);
                 },
                 ExprKind::Field(..) => {
diff --git a/clippy_utils/src/visitors.rs b/clippy_utils/src/visitors.rs
index e89a46b8538..0b9f75238a6 100644
--- a/clippy_utils/src/visitors.rs
+++ b/clippy_utils/src/visitors.rs
@@ -5,7 +5,7 @@ use rustc_hir as hir;
 use rustc_hir::def::{CtorKind, DefKind, Res};
 use rustc_hir::intravisit::{self, walk_block, walk_expr, Visitor};
 use rustc_hir::{
-    Arm, Block, BlockCheckMode, Body, BodyId, Expr, ExprKind, HirId, ItemId, ItemKind, Let, QPath, Stmt, UnOp,
+    Arm, Block, BlockCheckMode, Body, BodyId, Expr, ExprKind, HirId, ItemId, ItemKind, Let, Pat, QPath, Stmt, UnOp,
     UnsafeSource, Unsafety,
 };
 use rustc_lint::LateContext;
@@ -13,6 +13,74 @@ use rustc_middle::hir::map::Map;
 use rustc_middle::hir::nested_filter;
 use rustc_middle::ty::adjustment::Adjust;
 use rustc_middle::ty::{self, Ty, TypeckResults};
+use rustc_span::Span;
+
+mod internal {
+    /// Trait for visitor functions to control whether or not to descend to child nodes. Implemented
+    /// for only two types. `()` always descends. `Descend` allows controlled descent.
+    pub trait Continue {
+        fn descend(&self) -> bool;
+    }
+}
+use internal::Continue;
+
+impl Continue for () {
+    fn descend(&self) -> bool {
+        true
+    }
+}
+
+/// Allows for controlled descent whe using visitor functions. Use `()` instead when always
+/// descending into child nodes.
+#[derive(Clone, Copy)]
+pub enum Descend {
+    Yes,
+    No,
+}
+impl From<bool> for Descend {
+    fn from(from: bool) -> Self {
+        if from { Self::Yes } else { Self::No }
+    }
+}
+impl Continue for Descend {
+    fn descend(&self) -> bool {
+        matches!(self, Self::Yes)
+    }
+}
+
+/// Calls the given function once for each expression contained. This does not enter any bodies or
+/// nested items.
+pub fn for_each_expr<'tcx, B, C: Continue>(
+    node: impl Visitable<'tcx>,
+    f: impl FnMut(&'tcx Expr<'tcx>) -> ControlFlow<B, C>,
+) -> Option<B> {
+    struct V<B, F> {
+        f: F,
+        res: Option<B>,
+    }
+    impl<'tcx, B, C: Continue, F: FnMut(&'tcx Expr<'tcx>) -> ControlFlow<B, C>> Visitor<'tcx> for V<B, F> {
+        fn visit_expr(&mut self, e: &'tcx Expr<'tcx>) {
+            if self.res.is_some() {
+                return;
+            }
+            match (self.f)(e) {
+                ControlFlow::Continue(c) if c.descend() => walk_expr(self, e),
+                ControlFlow::Break(b) => self.res = Some(b),
+                ControlFlow::Continue(_) => (),
+            }
+        }
+
+        // Avoid unnecessary `walk_*` calls.
+        fn visit_ty(&mut self, _: &'tcx hir::Ty<'tcx>) {}
+        fn visit_pat(&mut self, _: &'tcx Pat<'tcx>) {}
+        fn visit_qpath(&mut self, _: &'tcx QPath<'tcx>, _: HirId, _: Span) {}
+        // Avoid monomorphising all `visit_*` functions.
+        fn visit_nested_item(&mut self, _: ItemId) {}
+    }
+    let mut v = V { f, res: None };
+    node.visit(&mut v);
+    v.res
+}
 
 /// Convenience method for creating a `Visitor` with just `visit_expr` overridden and nested
 /// bodies (i.e. closures) are visited.
diff --git a/tests/ui/branches_sharing_code/false_positives.rs b/tests/ui/branches_sharing_code/false_positives.rs
index 06448200951..5e3a1a29693 100644
--- a/tests/ui/branches_sharing_code/false_positives.rs
+++ b/tests/ui/branches_sharing_code/false_positives.rs
@@ -1,6 +1,8 @@
 #![allow(dead_code)]
 #![deny(clippy::if_same_then_else, clippy::branches_sharing_code)]
 
+use std::sync::Mutex;
+
 // ##################################
 // # Issue clippy#7369
 // ##################################
@@ -38,4 +40,56 @@ fn main() {
         let (y, x) = x;
         foo(x, y)
     };
+
+    let m = Mutex::new(0u32);
+    let l = m.lock().unwrap();
+    let _ = if true {
+        drop(l);
+        println!("foo");
+        m.lock().unwrap();
+        0
+    } else if *l == 0 {
+        drop(l);
+        println!("foo");
+        println!("bar");
+        m.lock().unwrap();
+        1
+    } else {
+        drop(l);
+        println!("foo");
+        println!("baz");
+        m.lock().unwrap();
+        2
+    };
+
+    if true {
+        let _guard = m.lock();
+        println!("foo");
+    } else {
+        println!("foo");
+    }
+
+    if true {
+        let _guard = m.lock();
+        println!("foo");
+        println!("bar");
+    } else {
+        let _guard = m.lock();
+        println!("foo");
+        println!("baz");
+    }
+
+    let mut c = 0;
+    for _ in 0..5 {
+        if c == 0 {
+            c += 1;
+            println!("0");
+        } else if c == 1 {
+            c += 1;
+            println!("1");
+        } else {
+            c += 1;
+            println!("more");
+        }
+    }
 }