about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-01-06 19:22:59 +0000
committerbors <bors@rust-lang.org>2024-01-06 19:22:59 +0000
commit0e5dc8e630038338ac4432c5f7553573ef15a61c (patch)
tree33d07baa52325105bdccf69840553a02b8fb68ce
parent788094c23570d2f2b6feb883372d5bc23fc93c6c (diff)
parenteae2317977acb0d42a93740ae93122a1287d8090 (diff)
downloadrust-0e5dc8e630038338ac4432c5f7553573ef15a61c.tar.gz
rust-0e5dc8e630038338ac4432c5f7553573ef15a61c.zip
Auto merge of #11883 - J-ZhengLi:issue11642, r=dswij
improve [`cast_sign_loss`], to skip warning on always positive expressions

fixes: #11642

changelog: improve [`cast_sign_loss`] to skip warning on always positive expressions

Turns out this is change became quite big, and I still can't cover all the cases, like method calls such as `POSITIVE_NUM.mul(POSITIVE_NUM)`, or `NEGATIVE_NUM.div(NEGATIVE_NUM)`... but well, if I do, I'm scared that this will goes forever, so I stopped, unless it needs to be done, lol.
-rw-r--r--clippy_lints/src/casts/cast_sign_loss.rs141
-rw-r--r--tests/ui/cast.rs49
-rw-r--r--tests/ui/cast.stderr74
3 files changed, 238 insertions, 26 deletions
diff --git a/clippy_lints/src/casts/cast_sign_loss.rs b/clippy_lints/src/casts/cast_sign_loss.rs
index bd12ee40628..1df5a25f674 100644
--- a/clippy_lints/src/casts/cast_sign_loss.rs
+++ b/clippy_lints/src/casts/cast_sign_loss.rs
@@ -1,12 +1,14 @@
 use clippy_utils::consts::{constant, Constant};
 use clippy_utils::diagnostics::span_lint;
-use clippy_utils::{method_chain_args, sext};
-use rustc_hir::{Expr, ExprKind};
+use clippy_utils::{clip, method_chain_args, sext};
+use rustc_hir::{BinOpKind, Expr, ExprKind};
 use rustc_lint::LateContext;
-use rustc_middle::ty::{self, Ty};
+use rustc_middle::ty::{self, Ty, UintTy};
 
 use super::CAST_SIGN_LOSS;
 
+const METHODS_RET_POSITIVE: &[&str] = &["abs", "checked_abs", "rem_euclid", "checked_rem_euclid"];
+
 pub(super) fn check(cx: &LateContext<'_>, expr: &Expr<'_>, cast_op: &Expr<'_>, cast_from: Ty<'_>, cast_to: Ty<'_>) {
     if should_lint(cx, cast_op, cast_from, cast_to) {
         span_lint(
@@ -25,33 +27,28 @@ fn should_lint(cx: &LateContext<'_>, cast_op: &Expr<'_>, cast_from: Ty<'_>, cast
                 return false;
             }
 
-            // Don't lint for positive constants.
-            let const_val = constant(cx, cx.typeck_results(), cast_op);
-            if let Some(Constant::Int(n)) = const_val
-                && let ty::Int(ity) = *cast_from.kind()
-                && sext(cx.tcx, n, ity) >= 0
-            {
+            // Don't lint if `cast_op` is known to be positive.
+            if let Sign::ZeroOrPositive = expr_sign(cx, cast_op, cast_from) {
                 return false;
             }
 
-            // Don't lint for the result of methods that always return non-negative values.
-            if let ExprKind::MethodCall(path, ..) = cast_op.kind {
-                let mut method_name = path.ident.name.as_str();
-                let allowed_methods = ["abs", "checked_abs", "rem_euclid", "checked_rem_euclid"];
-
-                if method_name == "unwrap"
-                    && let Some(arglist) = method_chain_args(cast_op, &["unwrap"])
-                    && let ExprKind::MethodCall(inner_path, ..) = &arglist[0].0.kind
-                {
-                    method_name = inner_path.ident.name.as_str();
-                }
-
-                if allowed_methods.iter().any(|&name| method_name == name) {
-                    return false;
-                }
+            let (mut uncertain_count, mut negative_count) = (0, 0);
+            // Peel off possible binary expressions, e.g. x * x * y => [x, x, y]
+            let Some(exprs) = exprs_with_selected_binop_peeled(cast_op) else {
+                // Assume cast sign lose if we cannot determine the sign of `cast_op`
+                return true;
+            };
+            for expr in exprs {
+                let ty = cx.typeck_results().expr_ty(expr);
+                match expr_sign(cx, expr, ty) {
+                    Sign::Negative => negative_count += 1,
+                    Sign::Uncertain => uncertain_count += 1,
+                    Sign::ZeroOrPositive => (),
+                };
             }
 
-            true
+            // Lint if there are odd number of uncertain or negative results
+            uncertain_count % 2 == 1 || negative_count % 2 == 1
         },
 
         (false, true) => !cast_to.is_signed(),
@@ -59,3 +56,97 @@ fn should_lint(cx: &LateContext<'_>, cast_op: &Expr<'_>, cast_from: Ty<'_>, cast
         (_, _) => false,
     }
 }
+
+fn get_const_int_eval(cx: &LateContext<'_>, expr: &Expr<'_>, ty: Ty<'_>) -> Option<i128> {
+    if let Constant::Int(n) = constant(cx, cx.typeck_results(), expr)?
+        && let ty::Int(ity) = *ty.kind()
+    {
+        return Some(sext(cx.tcx, n, ity));
+    }
+    None
+}
+
+enum Sign {
+    ZeroOrPositive,
+    Negative,
+    Uncertain,
+}
+
+fn expr_sign(cx: &LateContext<'_>, expr: &Expr<'_>, ty: Ty<'_>) -> Sign {
+    // Try evaluate this expr first to see if it's positive
+    if let Some(val) = get_const_int_eval(cx, expr, ty) {
+        return if val >= 0 { Sign::ZeroOrPositive } else { Sign::Negative };
+    }
+    // Calling on methods that always return non-negative values.
+    if let ExprKind::MethodCall(path, caller, args, ..) = expr.kind {
+        let mut method_name = path.ident.name.as_str();
+
+        if method_name == "unwrap"
+            && let Some(arglist) = method_chain_args(expr, &["unwrap"])
+            && let ExprKind::MethodCall(inner_path, ..) = &arglist[0].0.kind
+        {
+            method_name = inner_path.ident.name.as_str();
+        }
+
+        if method_name == "pow"
+            && let [arg] = args
+        {
+            return pow_call_result_sign(cx, caller, arg);
+        } else if METHODS_RET_POSITIVE.iter().any(|&name| method_name == name) {
+            return Sign::ZeroOrPositive;
+        }
+    }
+
+    Sign::Uncertain
+}
+
+/// Return the sign of the `pow` call's result.
+///
+/// If the caller is a positive number, the result is always positive,
+/// If the `power_of` is a even number, the result is always positive as well,
+/// Otherwise a [`Sign::Uncertain`] will be returned.
+fn pow_call_result_sign(cx: &LateContext<'_>, caller: &Expr<'_>, power_of: &Expr<'_>) -> Sign {
+    let caller_ty = cx.typeck_results().expr_ty(caller);
+    if let Some(caller_val) = get_const_int_eval(cx, caller, caller_ty)
+        && caller_val >= 0
+    {
+        return Sign::ZeroOrPositive;
+    }
+
+    if let Some(Constant::Int(n)) = constant(cx, cx.typeck_results(), power_of)
+        && clip(cx.tcx, n, UintTy::U32) % 2 == 0
+    {
+        return Sign::ZeroOrPositive;
+    }
+
+    Sign::Uncertain
+}
+
+/// Peels binary operators such as [`BinOpKind::Mul`], [`BinOpKind::Div`] or [`BinOpKind::Rem`],
+/// which the result could always be positive under certain condition.
+///
+/// Other operators such as `+`/`-` causing the result's sign hard to determine, which we will
+/// return `None`
+fn exprs_with_selected_binop_peeled<'a>(expr: &'a Expr<'_>) -> Option<Vec<&'a Expr<'a>>> {
+    #[inline]
+    fn collect_operands<'a>(expr: &'a Expr<'a>, operands: &mut Vec<&'a Expr<'a>>) -> Option<()> {
+        match expr.kind {
+            ExprKind::Binary(op, lhs, rhs) => {
+                if matches!(op.node, BinOpKind::Mul | BinOpKind::Div | BinOpKind::Rem) {
+                    collect_operands(lhs, operands);
+                    operands.push(rhs);
+                } else {
+                    // Things are complicated when there are other binary ops exist,
+                    // abort checking by returning `None` for now.
+                    return None;
+                }
+            },
+            _ => operands.push(expr),
+        }
+        Some(())
+    }
+
+    let mut res = vec![];
+    collect_operands(expr, &mut res)?;
+    Some(res)
+}
diff --git a/tests/ui/cast.rs b/tests/ui/cast.rs
index 1ca18170f8a..e9476c80ccb 100644
--- a/tests/ui/cast.rs
+++ b/tests/ui/cast.rs
@@ -365,3 +365,52 @@ fn avoid_subtract_overflow(q: u32) {
 fn issue11426() {
     (&42u8 >> 0xa9008fb6c9d81e42_0e25730562a601c8_u128) as usize;
 }
+
+fn issue11642() {
+    fn square(x: i16) -> u32 {
+        let x = x as i32;
+        (x * x) as u32;
+        x.pow(2) as u32;
+        (-2_i32).pow(2) as u32
+    }
+
+    let _a = |x: i32| -> u32 { (x * x * x * x) as u32 };
+
+    (-2_i32).pow(3) as u32;
+    //~^ ERROR: casting `i32` to `u32` may lose the sign of the value
+
+    let x: i32 = 10;
+    (x * x) as u32;
+    (x * x * x) as u32;
+    //~^ ERROR: casting `i32` to `u32` may lose the sign of the value
+
+    let y: i16 = -2;
+    (y * y * y * y * -2) as u16;
+    //~^ ERROR: casting `i16` to `u16` may lose the sign of the value
+    (y * y * y * y * 2) as u16;
+    (y * y * y * 2) as u16;
+    //~^ ERROR: casting `i16` to `u16` may lose the sign of the value
+    (y * y * y * -2) as u16;
+    //~^ ERROR: casting `i16` to `u16` may lose the sign of the value
+
+    fn foo(a: i32, b: i32, c: i32) -> u32 {
+        (a * a * b * b * c * c) as u32;
+        (a * b * c) as u32;
+        //~^ ERROR: casting `i32` to `u32` may lose the sign of the value
+        (a * -b * c) as u32;
+        //~^ ERROR: casting `i32` to `u32` may lose the sign of the value
+        (a * b * c * c) as u32;
+        (a * -2) as u32;
+        //~^ ERROR: casting `i32` to `u32` may lose the sign of the value
+        (a * b * c * -2) as u32;
+        //~^ ERROR: casting `i32` to `u32` may lose the sign of the value
+        (a / b) as u32;
+        (a / b * c) as u32;
+        //~^ ERROR: casting `i32` to `u32` may lose the sign of the value
+        (a / b + b * c) as u32;
+        //~^ ERROR: casting `i32` to `u32` may lose the sign of the value
+        a.pow(3) as u32;
+        //~^ ERROR: casting `i32` to `u32` may lose the sign of the value
+        (a.abs() * b.pow(2) / c.abs()) as u32
+    }
+}
diff --git a/tests/ui/cast.stderr b/tests/ui/cast.stderr
index bc74f7b728e..4e37af7f378 100644
--- a/tests/ui/cast.stderr
+++ b/tests/ui/cast.stderr
@@ -444,5 +444,77 @@ help: ... or use `try_from` and handle the error accordingly
 LL |     let c = u8::try_from(q / 1000);
    |             ~~~~~~~~~~~~~~~~~~~~~~
 
-error: aborting due to 51 previous errors
+error: casting `i32` to `u32` may lose the sign of the value
+  --> $DIR/cast.rs:379:5
+   |
+LL |     (-2_i32).pow(3) as u32;
+   |     ^^^^^^^^^^^^^^^^^^^^^^
+
+error: casting `i32` to `u32` may lose the sign of the value
+  --> $DIR/cast.rs:384:5
+   |
+LL |     (x * x * x) as u32;
+   |     ^^^^^^^^^^^^^^^^^^
+
+error: casting `i16` to `u16` may lose the sign of the value
+  --> $DIR/cast.rs:388:5
+   |
+LL |     (y * y * y * y * -2) as u16;
+   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+error: casting `i16` to `u16` may lose the sign of the value
+  --> $DIR/cast.rs:391:5
+   |
+LL |     (y * y * y * 2) as u16;
+   |     ^^^^^^^^^^^^^^^^^^^^^^
+
+error: casting `i16` to `u16` may lose the sign of the value
+  --> $DIR/cast.rs:393:5
+   |
+LL |     (y * y * y * -2) as u16;
+   |     ^^^^^^^^^^^^^^^^^^^^^^^
+
+error: casting `i32` to `u32` may lose the sign of the value
+  --> $DIR/cast.rs:398:9
+   |
+LL |         (a * b * c) as u32;
+   |         ^^^^^^^^^^^^^^^^^^
+
+error: casting `i32` to `u32` may lose the sign of the value
+  --> $DIR/cast.rs:400:9
+   |
+LL |         (a * -b * c) as u32;
+   |         ^^^^^^^^^^^^^^^^^^^
+
+error: casting `i32` to `u32` may lose the sign of the value
+  --> $DIR/cast.rs:403:9
+   |
+LL |         (a * -2) as u32;
+   |         ^^^^^^^^^^^^^^^
+
+error: casting `i32` to `u32` may lose the sign of the value
+  --> $DIR/cast.rs:405:9
+   |
+LL |         (a * b * c * -2) as u32;
+   |         ^^^^^^^^^^^^^^^^^^^^^^^
+
+error: casting `i32` to `u32` may lose the sign of the value
+  --> $DIR/cast.rs:408:9
+   |
+LL |         (a / b * c) as u32;
+   |         ^^^^^^^^^^^^^^^^^^
+
+error: casting `i32` to `u32` may lose the sign of the value
+  --> $DIR/cast.rs:410:9
+   |
+LL |         (a / b + b * c) as u32;
+   |         ^^^^^^^^^^^^^^^^^^^^^^
+
+error: casting `i32` to `u32` may lose the sign of the value
+  --> $DIR/cast.rs:412:9
+   |
+LL |         a.pow(3) as u32;
+   |         ^^^^^^^^^^^^^^^
+
+error: aborting due to 63 previous errors