about summary refs log tree commit diff
diff options
context:
space:
mode:
authorhkalbasi <hamidrezakalbasi@protonmail.com>2023-05-02 12:57:34 +0330
committerhkalbasi <hamidrezakalbasi@protonmail.com>2023-05-02 12:57:34 +0330
commit38544f56ab284152cc84363b81b0653d39db2990 (patch)
tree2fac4a2a116b4160639a5debf00dac24d17bbeca
parent266ceb7b4de6864480877e2dbcb8463d622ac257 (diff)
downloadrust-38544f56ab284152cc84363b81b0653d39db2990.tar.gz
rust-38544f56ab284152cc84363b81b0653d39db2990.zip
Catch overflow in shift binop evaluation
-rw-r--r--crates/hir-ty/src/consteval/tests.rs6
-rw-r--r--crates/hir-ty/src/mir/eval.rs41
2 files changed, 27 insertions, 20 deletions
diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs
index b700864f7dd..5a850f6d570 100644
--- a/crates/hir-ty/src/consteval/tests.rs
+++ b/crates/hir-ty/src/consteval/tests.rs
@@ -97,8 +97,10 @@ fn bit_op() {
     check_number(r#"const GOAL: u8 = !0 & !(!0 >> 1)"#, 128);
     check_number(r#"const GOAL: i8 = !0 & !(!0 >> 1)"#, 0);
     check_number(r#"const GOAL: i8 = 1 << 7"#, (1i8 << 7) as i128);
-    // FIXME: report panic here
-    check_number(r#"const GOAL: i8 = 1 << 8"#, 0);
+    check_number(r#"const GOAL: i8 = -1 << 2"#, (-1i8 << 2) as i128);
+    check_fail(r#"const GOAL: i8 = 1 << 8"#, |e| {
+        e == ConstEvalError::MirEvalError(MirEvalError::Panic("Overflow in Shl".to_string()))
+    });
 }
 
 #[test]
diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs
index 9811cd9192b..7ff68774bc9 100644
--- a/crates/hir-ty/src/mir/eval.rs
+++ b/crates/hir-ty/src/mir/eval.rs
@@ -860,6 +860,16 @@ impl Evaluator<'_> {
                     let is_signed = matches!(ty.as_builtin(), Some(BuiltinType::Int(_)));
                     let l128 = i128::from_le_bytes(pad16(lc, is_signed));
                     let r128 = i128::from_le_bytes(pad16(rc, is_signed));
+                    let check_overflow = |r: i128| {
+                        // FIXME: this is not very correct, and only catches the basic cases.
+                        let r = r.to_le_bytes();
+                        for &k in &r[lc.len()..] {
+                            if k != 0 && (k != 255 || !is_signed) {
+                                return Err(MirEvalError::Panic(format!("Overflow in {op:?}")));
+                            }
+                        }
+                        Ok(Owned(r[0..lc.len()].into()))
+                    };
                     match op {
                         BinOp::Ge | BinOp::Gt | BinOp::Le | BinOp::Lt | BinOp::Eq | BinOp::Ne => {
                             let r = op.run_compare(l128, r128) as u8;
@@ -888,28 +898,23 @@ impl Evaluator<'_> {
                                 BinOp::BitXor => l128 ^ r128,
                                 _ => unreachable!(),
                             };
-                            let r = r.to_le_bytes();
-                            for &k in &r[lc.len()..] {
-                                if k != 0 && (k != 255 || !is_signed) {
-                                    return Err(MirEvalError::Panic(format!("Overflow in {op:?}")));
-                                }
-                            }
-                            Owned(r[0..lc.len()].into())
+                            check_overflow(r)?
                         }
                         BinOp::Shl | BinOp::Shr => {
-                            let shift_amount = if r128 < 0 {
-                                return Err(MirEvalError::Panic(format!("Overflow in {op:?}")));
-                            } else if r128 > 128 {
+                            let r = 'b: {
+                                if let Ok(shift_amount) = u32::try_from(r128) {
+                                    let r = match op {
+                                        BinOp::Shl => l128.checked_shl(shift_amount),
+                                        BinOp::Shr => l128.checked_shr(shift_amount),
+                                        _ => unreachable!(),
+                                    };
+                                    if let Some(r) = r {
+                                        break 'b r;
+                                    }
+                                };
                                 return Err(MirEvalError::Panic(format!("Overflow in {op:?}")));
-                            } else {
-                                r128 as u8
-                            };
-                            let r = match op {
-                                BinOp::Shl => l128 << shift_amount,
-                                BinOp::Shr => l128 >> shift_amount,
-                                _ => unreachable!(),
                             };
-                            Owned(r.to_le_bytes()[0..lc.len()].into())
+                            check_overflow(r)?
                         }
                         BinOp::Offset => not_supported!("offset binop"),
                     }