about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/miri/src/shims/x86/sse41.rs80
-rw-r--r--src/tools/miri/tests/pass/intrinsics-x86-sse41.rs60
2 files changed, 124 insertions, 16 deletions
diff --git a/src/tools/miri/src/shims/x86/sse41.rs b/src/tools/miri/src/shims/x86/sse41.rs
index cfa06ded6e6..523f3bfc26f 100644
--- a/src/tools/miri/src/shims/x86/sse41.rs
+++ b/src/tools/miri/src/shims/x86/sse41.rs
@@ -148,6 +148,14 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
 
                 round_first::<rustc_apfloat::ieee::Single>(this, left, right, rounding, dest)?;
             }
+            // Used to implement the _mm_floor_ps, _mm_ceil_ps and _mm_round_ps
+            // functions. Rounds the elements of `op` according to `rounding`.
+            "round.ps" => {
+                let [op, rounding] =
+                    this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
+
+                round_all::<rustc_apfloat::ieee::Single>(this, op, rounding, dest)?;
+            }
             // Used to implement the _mm_floor_sd, _mm_ceil_sd and _mm_round_sd
             // functions. Rounds the first element of `right` according to `rounding`
             // and copies the remaining elements from `left`.
@@ -157,6 +165,14 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
 
                 round_first::<rustc_apfloat::ieee::Double>(this, left, right, rounding, dest)?;
             }
+            // Used to implement the _mm_floor_pd, _mm_ceil_pd and _mm_round_pd
+            // functions. Rounds the elements of `op` according to `rounding`.
+            "round.pd" => {
+                let [op, rounding] =
+                    this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
+
+                round_all::<rustc_apfloat::ieee::Double>(this, op, rounding, dest)?;
+            }
             // Used to implement the _mm_minpos_epu16 function.
             // Find the minimum unsinged 16-bit integer in `op` and
             // returns its value and position.
@@ -283,22 +299,7 @@ fn round_first<'tcx, F: rustc_apfloat::Float>(
     assert_eq!(dest_len, left_len);
     assert_eq!(dest_len, right_len);
 
-    // The fourth bit of `rounding` only affects the SSE status
-    // register, which cannot be accessed from Miri (or from Rust,
-    // for that matter), so we can ignore it.
-    let rounding = match this.read_scalar(rounding)?.to_i32()? & !0b1000 {
-        // When the third bit is 0, the rounding mode is determined by the
-        // first two bits.
-        0b000 => rustc_apfloat::Round::NearestTiesToEven,
-        0b001 => rustc_apfloat::Round::TowardNegative,
-        0b010 => rustc_apfloat::Round::TowardPositive,
-        0b011 => rustc_apfloat::Round::TowardZero,
-        // When the third bit is 1, the rounding mode is determined by the
-        // SSE status register. Since we do not support modifying it from
-        // Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
-        0b100..=0b111 => rustc_apfloat::Round::NearestTiesToEven,
-        rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
-    };
+    let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;
 
     let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?;
     let res = op0.round_to_integral(rounding).value;
@@ -317,3 +318,50 @@ fn round_first<'tcx, F: rustc_apfloat::Float>(
 
     Ok(())
 }
+
+// Rounds all elements of `op` according to `rounding`.
+fn round_all<'tcx, F: rustc_apfloat::Float>(
+    this: &mut crate::MiriInterpCx<'_, 'tcx>,
+    op: &OpTy<'tcx, Provenance>,
+    rounding: &OpTy<'tcx, Provenance>,
+    dest: &PlaceTy<'tcx, Provenance>,
+) -> InterpResult<'tcx, ()> {
+    let (op, op_len) = this.operand_to_simd(op)?;
+    let (dest, dest_len) = this.place_to_simd(dest)?;
+
+    assert_eq!(dest_len, op_len);
+
+    let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;
+
+    for i in 0..dest_len {
+        let op: F = this.read_scalar(&this.project_index(&op, i)?)?.to_float()?;
+        let res = op.round_to_integral(rounding).value;
+        this.write_scalar(
+            Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
+            &this.project_index(&dest, i)?,
+        )?;
+    }
+
+    Ok(())
+}
+
+/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
+/// `round.{ss,sd,ps,pd}` intrinsics.
+fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
+    // The fourth bit of `rounding` only affects the SSE status
+    // register, which cannot be accessed from Miri (or from Rust,
+    // for that matter), so we can ignore it.
+    match rounding & !0b1000 {
+        // When the third bit is 0, the rounding mode is determined by the
+        // first two bits.
+        0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven),
+        0b001 => Ok(rustc_apfloat::Round::TowardNegative),
+        0b010 => Ok(rustc_apfloat::Round::TowardPositive),
+        0b011 => Ok(rustc_apfloat::Round::TowardZero),
+        // When the third bit is 1, the rounding mode is determined by the
+        // SSE status register. Since we do not support modifying it from
+        // Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
+        0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven),
+        rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
+    }
+}
diff --git a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs
index d5489ffaf4b..db106bb9833 100644
--- a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs
+++ b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs
@@ -148,6 +148,36 @@ unsafe fn test_sse41() {
     test_mm_round_sd();
 
     #[target_feature(enable = "sse4.1")]
+    unsafe fn test_mm_round_pd() {
+        let a = _mm_setr_pd(-1.75, -4.25);
+        let r = _mm_round_pd::<_MM_FROUND_TO_NEAREST_INT>(a);
+        let e = _mm_setr_pd(-2.0, -4.0);
+        assert_eq_m128d(r, e);
+
+        let a = _mm_setr_pd(-1.75, -4.25);
+        let r = _mm_round_pd::<_MM_FROUND_TO_NEG_INF>(a);
+        let e = _mm_setr_pd(-2.0, -5.0);
+        assert_eq_m128d(r, e);
+
+        let a = _mm_setr_pd(-1.75, -4.25);
+        let r = _mm_round_pd::<_MM_FROUND_TO_POS_INF>(a);
+        let e = _mm_setr_pd(-1.0, -4.0);
+        assert_eq_m128d(r, e);
+
+        let a = _mm_setr_pd(-1.75, -4.25);
+        let r = _mm_round_pd::<_MM_FROUND_TO_ZERO>(a);
+        let e = _mm_setr_pd(-1.0, -4.0);
+        assert_eq_m128d(r, e);
+
+        // Assume round-to-nearest by default
+        let a = _mm_setr_pd(-1.75, -4.25);
+        let r = _mm_round_pd::<_MM_FROUND_CUR_DIRECTION>(a);
+        let e = _mm_setr_pd(-2.0, -4.0);
+        assert_eq_m128d(r, e);
+    }
+    test_mm_round_pd();
+
+    #[target_feature(enable = "sse4.1")]
     unsafe fn test_mm_round_ss() {
         let a = _mm_setr_ps(1.5, 3.5, 7.5, 15.5);
         let b = _mm_setr_ps(-1.75, -4.5, -8.5, -16.5);
@@ -183,6 +213,36 @@ unsafe fn test_sse41() {
     test_mm_round_ss();
 
     #[target_feature(enable = "sse4.1")]
+    unsafe fn test_mm_round_ps() {
+        let a = _mm_setr_ps(-1.75, -4.25, -8.5, -16.5);
+        let r = _mm_round_ps::<_MM_FROUND_TO_NEAREST_INT>(a);
+        let e = _mm_setr_ps(-2.0, -4.0, -8.0, -16.0);
+        assert_eq_m128(r, e);
+
+        let a = _mm_setr_ps(-1.75, -4.25, -8.5, -16.5);
+        let r = _mm_round_ps::<_MM_FROUND_TO_NEG_INF>(a);
+        let e = _mm_setr_ps(-2.0, -5.0, -9.0, -17.0);
+        assert_eq_m128(r, e);
+
+        let a = _mm_setr_ps(-1.75, -4.25, -8.5, -16.5);
+        let r = _mm_round_ps::<_MM_FROUND_TO_POS_INF>(a);
+        let e = _mm_setr_ps(-1.0, -4.0, -8.0, -16.0);
+        assert_eq_m128(r, e);
+
+        let a = _mm_setr_ps(-1.75, -4.25, -8.5, -16.5);
+        let r = _mm_round_ps::<_MM_FROUND_TO_ZERO>(a);
+        let e = _mm_setr_ps(-1.0, -4.0, -8.0, -16.0);
+        assert_eq_m128(r, e);
+
+        // Assume round-to-nearest by default
+        let a = _mm_setr_ps(-1.75, -4.25, -8.5, -16.5);
+        let r = _mm_round_ps::<_MM_FROUND_CUR_DIRECTION>(a);
+        let e = _mm_setr_ps(-2.0, -4.0, -8.0, -16.0);
+        assert_eq_m128(r, e);
+    }
+    test_mm_round_ps();
+
+    #[target_feature(enable = "sse4.1")]
     unsafe fn test_mm_minpos_epu16() {
         let a = _mm_setr_epi16(23, 18, 44, 97, 50, 13, 67, 66);
         let r = _mm_minpos_epu16(a);