about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/miri/src/helpers.rs20
-rw-r--r--src/tools/miri/src/shims/x86/sse2.rs74
2 files changed, 54 insertions, 40 deletions
diff --git a/src/tools/miri/src/helpers.rs b/src/tools/miri/src/helpers.rs
index 7805fe25bcd..bbd905be0a9 100644
--- a/src/tools/miri/src/helpers.rs
+++ b/src/tools/miri/src/helpers.rs
@@ -14,7 +14,7 @@ use rustc_middle::mir;
 use rustc_middle::ty::{
     self,
     layout::{IntegerExt as _, LayoutOf, TyAndLayout},
-    Ty, TyCtxt,
+    IntTy, Ty, TyCtxt, UintTy,
 };
 use rustc_span::{def_id::CrateNum, sym, Span, Symbol};
 use rustc_target::abi::{Align, FieldIdx, FieldsShape, Integer, Size, Variants};
@@ -1067,6 +1067,24 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 ),
         }
     }
+
+    /// Returns an integer type that is twice wide as `ty`
+    fn get_twice_wide_int_ty(&self, ty: Ty<'tcx>) -> Ty<'tcx> {
+        let this = self.eval_context_ref();
+        match ty.kind() {
+            // Unsigned
+            ty::Uint(UintTy::U8) => this.tcx.types.u16,
+            ty::Uint(UintTy::U16) => this.tcx.types.u32,
+            ty::Uint(UintTy::U32) => this.tcx.types.u64,
+            ty::Uint(UintTy::U64) => this.tcx.types.u128,
+            // Signed
+            ty::Int(IntTy::I8) => this.tcx.types.i16,
+            ty::Int(IntTy::I16) => this.tcx.types.i32,
+            ty::Int(IntTy::I32) => this.tcx.types.i64,
+            ty::Int(IntTy::I64) => this.tcx.types.i128,
+            _ => span_bug!(this.cur_span(), "unexpected type: {ty:?}"),
+        }
+    }
 }
 
 impl<'mir, 'tcx> MiriMachine<'mir, 'tcx> {
diff --git a/src/tools/miri/src/shims/x86/sse2.rs b/src/tools/miri/src/shims/x86/sse2.rs
index 098409d6e35..28aebc8cba8 100644
--- a/src/tools/miri/src/shims/x86/sse2.rs
+++ b/src/tools/miri/src/shims/x86/sse2.rs
@@ -2,6 +2,7 @@ use rustc_apfloat::{
     ieee::{Double, Single},
     Float as _,
 };
+use rustc_middle::mir;
 use rustc_middle::ty::layout::LayoutOf as _;
 use rustc_middle::ty::Ty;
 use rustc_span::Symbol;
@@ -36,9 +37,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         // Intrinsincs sufixed with "epiX" or "epuX" operate with X-bit signed or unsigned
         // vectors.
         match unprefixed_name {
-            // Used to implement the _mm_avg_epu8 function.
-            // Averages packed unsigned 8-bit integers in `left` and `right`.
-            "pavg.b" => {
+            // Used to implement the _mm_avg_epu8 and _mm_avg_epu16 functions.
+            // Averages packed unsigned 8/16-bit integers in `left` and `right`.
+            "pavg.b" | "pavg.w" => {
                 let [left, right] =
                     this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
 
@@ -50,46 +51,41 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 assert_eq!(dest_len, right_len);
 
                 for i in 0..dest_len {
-                    let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u8()?;
-                    let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
+                    let left = this.read_immediate(&this.project_index(&left, i)?)?;
+                    let right = this.read_immediate(&this.project_index(&right, i)?)?;
                     let dest = this.project_index(&dest, i)?;
 
-                    // Values are expanded from u8 to u16, so adds cannot overflow.
-                    let res = u16::from(left)
-                        .checked_add(u16::from(right))
-                        .unwrap()
-                        .checked_add(1)
-                        .unwrap()
-                        / 2;
-                    this.write_scalar(Scalar::from_u8(res.try_into().unwrap()), &dest)?;
-                }
-            }
-            // Used to implement the _mm_avg_epu16 function.
-            // Averages packed unsigned 16-bit integers in `left` and `right`.
-            "pavg.w" => {
-                let [left, right] =
-                    this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
-
-                let (left, left_len) = this.operand_to_simd(left)?;
-                let (right, right_len) = this.operand_to_simd(right)?;
-                let (dest, dest_len) = this.place_to_simd(dest)?;
-
-                assert_eq!(dest_len, left_len);
-                assert_eq!(dest_len, right_len);
+                    // Widen the operands to avoid overflow
+                    let twice_wide_ty = this.get_twice_wide_int_ty(left.layout.ty);
+                    let twice_wide_layout = this.layout_of(twice_wide_ty)?;
+                    let left = this.int_to_int_or_float(&left, twice_wide_ty)?;
+                    let right = this.int_to_int_or_float(&right, twice_wide_ty)?;
+
+                    // Calculate left + right + 1
+                    let (added, _overflow, _ty) = this.overflowing_binary_op(
+                        mir::BinOp::Add,
+                        &ImmTy::from_immediate(left, twice_wide_layout),
+                        &ImmTy::from_immediate(right, twice_wide_layout),
+                    )?;
+                    let (added, _overflow, _ty) = this.overflowing_binary_op(
+                        mir::BinOp::Add,
+                        &ImmTy::from_scalar(added, twice_wide_layout),
+                        &ImmTy::from_uint(1u32, twice_wide_layout),
+                    )?;
 
-                for i in 0..dest_len {
-                    let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u16()?;
-                    let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u16()?;
-                    let dest = this.project_index(&dest, i)?;
+                    // Calculate (left + right + 1) / 2
+                    let (divided, _overflow, _ty) = this.overflowing_binary_op(
+                        mir::BinOp::Div,
+                        &ImmTy::from_scalar(added, twice_wide_layout),
+                        &ImmTy::from_uint(2u32, twice_wide_layout),
+                    )?;
 
-                    // Values are expanded from u16 to u32, so adds cannot overflow.
-                    let res = u32::from(left)
-                        .checked_add(u32::from(right))
-                        .unwrap()
-                        .checked_add(1)
-                        .unwrap()
-                        / 2;
-                    this.write_scalar(Scalar::from_u16(res.try_into().unwrap()), &dest)?;
+                    // Narrow back to the original type
+                    let res = this.int_to_int_or_float(
+                        &ImmTy::from_scalar(divided, twice_wide_layout),
+                        dest.layout.ty,
+                    )?;
+                    this.write_immediate(res, &dest)?;
                 }
             }
             // Used to implement the _mm_mulhi_epi16 function.