about summary refs log tree commit diff
path: root/compiler/rustc_const_eval
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-04-19 17:00:28 +0000
committerbors <bors@rust-lang.org>2024-04-19 17:00:28 +0000
commitce3263e60e73f4018592cbfba540cf8bef4399de (patch)
tree36c27a6ec5815a6d7e4fe105d078874eac0d3d22 /compiler/rustc_const_eval
parentd1a0fa5ed3ffe52d72f761d3c95cbeb0a9cdfe66 (diff)
parentd3f927db874aaa82480cd2120e15a507e7cb9b15 (diff)
downloadrust-ce3263e60e73f4018592cbfba540cf8bef4399de.tar.gz
rust-ce3263e60e73f4018592cbfba540cf8bef4399de.zip
Auto merge of #124113 - RalfJung:interpret-scalar-ops, r=oli-obk
interpret: use ScalarInt for bin-ops; avoid PartialOrd for ScalarInt

Best reviewed commit-by-commit

r? `@oli-obk`
Diffstat (limited to 'compiler/rustc_const_eval')
-rw-r--r--compiler/rustc_const_eval/src/interpret/discriminant.rs3
-rw-r--r--compiler/rustc_const_eval/src/interpret/operand.rs24
-rw-r--r--compiler/rustc_const_eval/src/interpret/operator.rs113
3 files changed, 82 insertions, 58 deletions
diff --git a/compiler/rustc_const_eval/src/interpret/discriminant.rs b/compiler/rustc_const_eval/src/interpret/discriminant.rs
index 704f597cfdb..caacc6f57d3 100644
--- a/compiler/rustc_const_eval/src/interpret/discriminant.rs
+++ b/compiler/rustc_const_eval/src/interpret/discriminant.rs
@@ -295,8 +295,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                         &niche_start_val,
                     )?
                     .to_scalar()
-                    .try_to_int()
-                    .unwrap();
+                    .assert_int();
                 Ok(Some((tag, tag_field)))
             }
         }
diff --git a/compiler/rustc_const_eval/src/interpret/operand.rs b/compiler/rustc_const_eval/src/interpret/operand.rs
index 75a672785ea..718c91b2f76 100644
--- a/compiler/rustc_const_eval/src/interpret/operand.rs
+++ b/compiler/rustc_const_eval/src/interpret/operand.rs
@@ -6,9 +6,10 @@ use std::assert_matches::assert_matches;
 use either::{Either, Left, Right};
 
 use rustc_hir::def::Namespace;
+use rustc_middle::mir::interpret::ScalarSizeMismatch;
 use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
 use rustc_middle::ty::print::{FmtPrinter, PrettyPrinter};
-use rustc_middle::ty::{ConstInt, Ty, TyCtxt};
+use rustc_middle::ty::{ConstInt, ScalarInt, Ty, TyCtxt};
 use rustc_middle::{mir, ty};
 use rustc_target::abi::{self, Abi, HasDataLayout, Size};
 
@@ -211,6 +212,12 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
     }
 
     #[inline]
+    pub fn from_scalar_int(s: ScalarInt, layout: TyAndLayout<'tcx>) -> Self {
+        assert_eq!(s.size(), layout.size);
+        Self::from_scalar(Scalar::from(s), layout)
+    }
+
+    #[inline]
     pub fn try_from_uint(i: impl Into<u128>, layout: TyAndLayout<'tcx>) -> Option<Self> {
         Some(Self::from_scalar(Scalar::try_from_uint(i, layout.size)?, layout))
     }
@@ -223,7 +230,6 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
     pub fn try_from_int(i: impl Into<i128>, layout: TyAndLayout<'tcx>) -> Option<Self> {
         Some(Self::from_scalar(Scalar::try_from_int(i, layout.size)?, layout))
     }
-
     #[inline]
     pub fn from_int(i: impl Into<i128>, layout: TyAndLayout<'tcx>) -> Self {
         Self::from_scalar(Scalar::from_int(i, layout.size), layout)
@@ -242,6 +248,20 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
         Self::from_scalar(Scalar::from_i8(c as i8), layout)
     }
 
+    /// Return the immediate as a `ScalarInt`. Ensures that it has the size that the layout of the
+    /// immediate indicates.
+    #[inline]
+    pub fn to_scalar_int(&self) -> InterpResult<'tcx, ScalarInt> {
+        let s = self.to_scalar().to_scalar_int()?;
+        if s.size() != self.layout.size {
+            throw_ub!(ScalarSizeMismatch(ScalarSizeMismatch {
+                target_size: self.layout.size.bytes(),
+                data_size: s.size().bytes(),
+            }));
+        }
+        Ok(s)
+    }
+
     #[inline]
     pub fn to_const_int(self) -> ConstInt {
         assert!(self.layout.ty.is_integral());
diff --git a/compiler/rustc_const_eval/src/interpret/operator.rs b/compiler/rustc_const_eval/src/interpret/operator.rs
index 5665bb4999f..9af755e40de 100644
--- a/compiler/rustc_const_eval/src/interpret/operator.rs
+++ b/compiler/rustc_const_eval/src/interpret/operator.rs
@@ -2,7 +2,7 @@ use rustc_apfloat::{Float, FloatConvert};
 use rustc_middle::mir;
 use rustc_middle::mir::interpret::{InterpResult, Scalar};
 use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
-use rustc_middle::ty::{self, FloatTy, Ty};
+use rustc_middle::ty::{self, FloatTy, ScalarInt, Ty};
 use rustc_span::symbol::sym;
 use rustc_target::abi::Abi;
 
@@ -146,14 +146,20 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
     fn binary_int_op(
         &self,
         bin_op: mir::BinOp,
-        // passing in raw bits
-        l: u128,
-        left_layout: TyAndLayout<'tcx>,
-        r: u128,
-        right_layout: TyAndLayout<'tcx>,
+        left: &ImmTy<'tcx, M::Provenance>,
+        right: &ImmTy<'tcx, M::Provenance>,
     ) -> InterpResult<'tcx, (ImmTy<'tcx, M::Provenance>, bool)> {
         use rustc_middle::mir::BinOp::*;
 
+        // This checks the size, so that we can just assert it below.
+        let l = left.to_scalar_int()?;
+        let r = right.to_scalar_int()?;
+        // Prepare to convert the values to signed or unsigned form.
+        let l_signed = || l.assert_int(left.layout.size);
+        let l_unsigned = || l.assert_uint(left.layout.size);
+        let r_signed = || r.assert_int(right.layout.size);
+        let r_unsigned = || r.assert_uint(right.layout.size);
+
         let throw_ub_on_overflow = match bin_op {
             AddUnchecked => Some(sym::unchecked_add),
             SubUnchecked => Some(sym::unchecked_sub),
@@ -165,69 +171,72 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
 
         // Shift ops can have an RHS with a different numeric type.
         if matches!(bin_op, Shl | ShlUnchecked | Shr | ShrUnchecked) {
-            let size = left_layout.size.bits();
+            let size = left.layout.size.bits();
             // The shift offset is implicitly masked to the type size. (This is the one MIR operator
             // that does *not* directly map to a single LLVM operation.) Compute how much we
             // actually shift and whether there was an overflow due to shifting too much.
-            let (shift_amount, overflow) = if right_layout.abi.is_signed() {
-                let shift_amount = self.sign_extend(r, right_layout) as i128;
+            let (shift_amount, overflow) = if right.layout.abi.is_signed() {
+                let shift_amount = r_signed();
                 let overflow = shift_amount < 0 || shift_amount >= i128::from(size);
+                // Deliberately wrapping `as` casts: shift_amount *can* be negative, but the result
+                // of the `as` will be equal modulo `size` (since it is a power of two).
                 let masked_amount = (shift_amount as u128) % u128::from(size);
-                debug_assert_eq!(overflow, shift_amount != (masked_amount as i128));
+                assert_eq!(overflow, shift_amount != (masked_amount as i128));
                 (masked_amount, overflow)
             } else {
-                let shift_amount = r;
+                let shift_amount = r_unsigned();
                 let masked_amount = shift_amount % u128::from(size);
                 (masked_amount, shift_amount != masked_amount)
             };
             let shift_amount = u32::try_from(shift_amount).unwrap(); // we masked so this will always fit
             // Compute the shifted result.
-            let result = if left_layout.abi.is_signed() {
-                let l = self.sign_extend(l, left_layout) as i128;
+            let result = if left.layout.abi.is_signed() {
+                let l = l_signed();
                 let result = match bin_op {
                     Shl | ShlUnchecked => l.checked_shl(shift_amount).unwrap(),
                     Shr | ShrUnchecked => l.checked_shr(shift_amount).unwrap(),
                     _ => bug!(),
                 };
-                result as u128
+                ScalarInt::truncate_from_int(result, left.layout.size).0
             } else {
-                match bin_op {
+                let l = l_unsigned();
+                let result = match bin_op {
                     Shl | ShlUnchecked => l.checked_shl(shift_amount).unwrap(),
                     Shr | ShrUnchecked => l.checked_shr(shift_amount).unwrap(),
                     _ => bug!(),
-                }
+                };
+                ScalarInt::truncate_from_uint(result, left.layout.size).0
             };
-            let truncated = self.truncate(result, left_layout);
 
             if overflow && let Some(intrinsic_name) = throw_ub_on_overflow {
                 throw_ub_custom!(
                     fluent::const_eval_overflow_shift,
-                    val = if right_layout.abi.is_signed() {
-                        (self.sign_extend(r, right_layout) as i128).to_string()
+                    val = if right.layout.abi.is_signed() {
+                        r_signed().to_string()
                     } else {
-                        r.to_string()
+                        r_unsigned().to_string()
                     },
                     name = intrinsic_name
                 );
             }
 
-            return Ok((ImmTy::from_uint(truncated, left_layout), overflow));
+            return Ok((ImmTy::from_scalar_int(result, left.layout), overflow));
         }
 
         // For the remaining ops, the types must be the same on both sides
-        if left_layout.ty != right_layout.ty {
+        if left.layout.ty != right.layout.ty {
             span_bug!(
                 self.cur_span(),
                 "invalid asymmetric binary op {bin_op:?}: {l:?} ({l_ty}), {r:?} ({r_ty})",
-                l_ty = left_layout.ty,
-                r_ty = right_layout.ty,
+                l_ty = left.layout.ty,
+                r_ty = right.layout.ty,
             )
         }
 
-        let size = left_layout.size;
+        let size = left.layout.size;
 
         // Operations that need special treatment for signed integers
-        if left_layout.abi.is_signed() {
+        if left.layout.abi.is_signed() {
             let op: Option<fn(&i128, &i128) -> bool> = match bin_op {
                 Lt => Some(i128::lt),
                 Le => Some(i128::le),
@@ -236,18 +245,14 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                 _ => None,
             };
             if let Some(op) = op {
-                let l = self.sign_extend(l, left_layout) as i128;
-                let r = self.sign_extend(r, right_layout) as i128;
-                return Ok((ImmTy::from_bool(op(&l, &r), *self.tcx), false));
+                return Ok((ImmTy::from_bool(op(&l_signed(), &r_signed()), *self.tcx), false));
             }
             if bin_op == Cmp {
-                let l = self.sign_extend(l, left_layout) as i128;
-                let r = self.sign_extend(r, right_layout) as i128;
-                return Ok(self.three_way_compare(l, r));
+                return Ok(self.three_way_compare(l_signed(), r_signed()));
             }
             let op: Option<fn(i128, i128) -> (i128, bool)> = match bin_op {
-                Div if r == 0 => throw_ub!(DivisionByZero),
-                Rem if r == 0 => throw_ub!(RemainderByZero),
+                Div if r.is_null() => throw_ub!(DivisionByZero),
+                Rem if r.is_null() => throw_ub!(RemainderByZero),
                 Div => Some(i128::overflowing_div),
                 Rem => Some(i128::overflowing_rem),
                 Add | AddUnchecked => Some(i128::overflowing_add),
@@ -256,8 +261,8 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                 _ => None,
             };
             if let Some(op) = op {
-                let l = self.sign_extend(l, left_layout) as i128;
-                let r = self.sign_extend(r, right_layout) as i128;
+                let l = l_signed();
+                let r = r_signed();
 
                 // We need a special check for overflowing Rem and Div since they are *UB*
                 // on overflow, which can happen with "int_min $OP -1".
@@ -272,17 +277,19 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                 }
 
                 let (result, oflo) = op(l, r);
-                // This may be out-of-bounds for the result type, so we have to truncate ourselves.
+                // This may be out-of-bounds for the result type, so we have to truncate.
                 // If that truncation loses any information, we have an overflow.
-                let result = result as u128;
-                let truncated = self.truncate(result, left_layout);
-                let overflow = oflo || self.sign_extend(truncated, left_layout) != result;
+                let (result, lossy) = ScalarInt::truncate_from_int(result, left.layout.size);
+                let overflow = oflo || lossy;
                 if overflow && let Some(intrinsic_name) = throw_ub_on_overflow {
                     throw_ub_custom!(fluent::const_eval_overflow, name = intrinsic_name);
                 }
-                return Ok((ImmTy::from_uint(truncated, left_layout), overflow));
+                return Ok((ImmTy::from_scalar_int(result, left.layout), overflow));
             }
         }
+        // From here on it's okay to treat everything as unsigned.
+        let l = l_unsigned();
+        let r = r_unsigned();
 
         if bin_op == Cmp {
             return Ok(self.three_way_compare(l, r));
@@ -297,12 +304,12 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
             Gt => ImmTy::from_bool(l > r, *self.tcx),
             Ge => ImmTy::from_bool(l >= r, *self.tcx),
 
-            BitOr => ImmTy::from_uint(l | r, left_layout),
-            BitAnd => ImmTy::from_uint(l & r, left_layout),
-            BitXor => ImmTy::from_uint(l ^ r, left_layout),
+            BitOr => ImmTy::from_uint(l | r, left.layout),
+            BitAnd => ImmTy::from_uint(l & r, left.layout),
+            BitXor => ImmTy::from_uint(l ^ r, left.layout),
 
             Add | AddUnchecked | Sub | SubUnchecked | Mul | MulUnchecked | Rem | Div => {
-                assert!(!left_layout.abi.is_signed());
+                assert!(!left.layout.abi.is_signed());
                 let op: fn(u128, u128) -> (u128, bool) = match bin_op {
                     Add | AddUnchecked => u128::overflowing_add,
                     Sub | SubUnchecked => u128::overflowing_sub,
@@ -316,21 +323,21 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                 let (result, oflo) = op(l, r);
                 // Truncate to target type.
                 // If that truncation loses any information, we have an overflow.
-                let truncated = self.truncate(result, left_layout);
-                let overflow = oflo || truncated != result;
+                let (result, lossy) = ScalarInt::truncate_from_uint(result, left.layout.size);
+                let overflow = oflo || lossy;
                 if overflow && let Some(intrinsic_name) = throw_ub_on_overflow {
                     throw_ub_custom!(fluent::const_eval_overflow, name = intrinsic_name);
                 }
-                return Ok((ImmTy::from_uint(truncated, left_layout), overflow));
+                return Ok((ImmTy::from_scalar_int(result, left.layout), overflow));
             }
 
             _ => span_bug!(
                 self.cur_span(),
                 "invalid binary op {:?}: {:?}, {:?} (both {})",
                 bin_op,
-                l,
-                r,
-                right_layout.ty,
+                left,
+                right,
+                right.layout.ty,
             ),
         };
 
@@ -427,9 +434,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
                     right.layout.ty
                 );
 
-                let l = left.to_scalar().to_bits(left.layout.size)?;
-                let r = right.to_scalar().to_bits(right.layout.size)?;
-                self.binary_int_op(bin_op, l, left.layout, r, right.layout)
+                self.binary_int_op(bin_op, left, right)
             }
             _ if left.layout.ty.is_any_ptr() => {
                 // The RHS type must be a `pointer` *or an integer type* (for `Offset`).