about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-12-09 06:34:35 +0000
committerbors <bors@rust-lang.org>2023-12-09 06:34:35 +0000
commitaa7afff8c619fc9c04515b7cef57d72dc3c7e51e (patch)
tree068bfacc8fa23c7b63860366ab60617980a0e98e
parent57935c35aac3f389eb406ac5d33b7953daeba7d1 (diff)
parentd57125632e78474e2623697fd77b5762df5b5454 (diff)
downloadrust-aa7afff8c619fc9c04515b7cef57d72dc3c7e51e.tar.gz
rust-aa7afff8c619fc9c04515b7cef57d72dc3c7e51e.zip
Auto merge of #3216 - eduardosm:fix-ptestnzc, r=RalfJung
Fix x86 SSE4.1 ptestnzc

Fixes ptestnzc by bringing back the original implementation of https://github.com/rust-lang/miri/pull/3214.

`(op & mask) != 0 && (op & mask) != mask` needs to be calculated for the whole vector. It cannot be calculated for each element and then folded.

For example, given
* `op = [0b100, 0b010]`
* `mask = [0b100, 0b110]`

The correct result would be:
* `op & mask = [0b100, 0b010]`
Comparisons are done on the vector as a whole:
* `all_zero = (op & mask) == [0, 0] = false`
* `masked_set = (op & mask) == mask = false`
* `!all_zero && !masked_set = true` (the correct result)

The previous method:
* `op & mask = [0b100, 0b010]`
Comparisons are done element-wise:
* `all_zero = (op & mask) == [0, 0] = [false, false]`
* `masked_set = (op & mask) == mask = [true, false]`
* `!all_zero && !masked_set = [false, true]`
 After folding with AND, the final result would be `false`, which is incorrect.
-rw-r--r--src/tools/miri/src/shims/x86/mod.rs49
-rw-r--r--src/tools/miri/src/shims/x86/sse41.rs23
-rw-r--r--src/tools/miri/tests/pass/intrinsics-x86-sse41.rs5
3 files changed, 41 insertions, 36 deletions
diff --git a/src/tools/miri/src/shims/x86/mod.rs b/src/tools/miri/src/shims/x86/mod.rs
index d8c3b4826a9..1aaf820f460 100644
--- a/src/tools/miri/src/shims/x86/mod.rs
+++ b/src/tools/miri/src/shims/x86/mod.rs
@@ -666,30 +666,33 @@ fn conditional_dot_product<'tcx>(
     Ok(())
 }
 
-/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`.
-fn bin_op_folded<'tcx, T>(
+/// Calculates two booleans.
+///
+/// The first is true when all the bits of `op & mask` are zero.
+/// The second is true when `(op & mask) == mask`
+fn test_bits_masked<'tcx>(
     this: &crate::MiriInterpCx<'_, 'tcx>,
-    lhs: &OpTy<'tcx, Provenance>,
-    rhs: &OpTy<'tcx, Provenance>,
-    init: T,
-    mut f: impl FnMut(T, ImmTy<'tcx, Provenance>, ImmTy<'tcx, Provenance>) -> InterpResult<'tcx, T>,
-) -> InterpResult<'tcx, T> {
-    assert_eq!(lhs.layout, rhs.layout);
-
-    let (lhs, lhs_len) = this.operand_to_simd(lhs)?;
-    let (rhs, rhs_len) = this.operand_to_simd(rhs)?;
-
-    assert_eq!(lhs_len, rhs_len);
-
-    let mut acc = init;
-    for i in 0..lhs_len {
-        let lhs = this.project_index(&lhs, i)?;
-        let rhs = this.project_index(&rhs, i)?;
-
-        let lhs = this.read_immediate(&lhs)?;
-        let rhs = this.read_immediate(&rhs)?;
-        acc = f(acc, lhs, rhs)?;
+    op: &OpTy<'tcx, Provenance>,
+    mask: &OpTy<'tcx, Provenance>,
+) -> InterpResult<'tcx, (bool, bool)> {
+    assert_eq!(op.layout, mask.layout);
+
+    let (op, op_len) = this.operand_to_simd(op)?;
+    let (mask, mask_len) = this.operand_to_simd(mask)?;
+
+    assert_eq!(op_len, mask_len);
+
+    let mut all_zero = true;
+    let mut masked_set = true;
+    for i in 0..op_len {
+        let op = this.project_index(&op, i)?;
+        let mask = this.project_index(&mask, i)?;
+
+        let op = this.read_scalar(&op)?.to_uint(op.layout.size)?;
+        let mask = this.read_scalar(&mask)?.to_uint(mask.layout.size)?;
+        all_zero &= (op & mask) == 0;
+        masked_set &= (op & mask) == mask;
     }
 
-    Ok(acc)
+    Ok((all_zero, masked_set))
 }
diff --git a/src/tools/miri/src/shims/x86/sse41.rs b/src/tools/miri/src/shims/x86/sse41.rs
index 08e3404a224..67bb63f0a3d 100644
--- a/src/tools/miri/src/shims/x86/sse41.rs
+++ b/src/tools/miri/src/shims/x86/sse41.rs
@@ -1,7 +1,7 @@
 use rustc_span::Symbol;
 use rustc_target::spec::abi::Abi;
 
-use super::{bin_op_folded, conditional_dot_product, round_all, round_first};
+use super::{conditional_dot_product, round_all, round_first, test_bits_masked};
 use crate::*;
 use shims::foreign_items::EmulateForeignItemResult;
 
@@ -217,21 +217,18 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
             }
             // Used to implement the _mm_testz_si128, _mm_testc_si128
             // and _mm_testnzc_si128 functions.
-            // Tests `op & mask == 0`, `op & mask == mask` or
-            // `op & mask != 0 && op & mask != mask`
+            // Tests `(op & mask) == 0`, `(op & mask) == mask` or
+            // `(op & mask) != 0 && (op & mask) != mask`
             "ptestz" | "ptestc" | "ptestnzc" => {
                 let [op, mask] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
 
-                let res = bin_op_folded(this, op, mask, true, |acc, op, mask| {
-                    let op = op.to_scalar().to_uint(op.layout.size)?;
-                    let mask = mask.to_scalar().to_uint(mask.layout.size)?;
-                    Ok(match unprefixed_name {
-                        "ptestz" => acc && (op & mask) == 0,
-                        "ptestc" => acc && (op & mask) == mask,
-                        "ptestnzc" => acc && (op & mask) != 0 && (op & mask) != mask,
-                        _ => unreachable!(),
-                    })
-                })?;
+                let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
+                let res = match unprefixed_name {
+                    "ptestz" => all_zero,
+                    "ptestc" => masked_set,
+                    "ptestnzc" => !all_zero && !masked_set,
+                    _ => unreachable!(),
+                };
 
                 this.write_scalar(Scalar::from_i32(res.into()), dest)?;
             }
diff --git a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs
index 13856d29d3f..06607f3fd59 100644
--- a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs
+++ b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs
@@ -515,6 +515,11 @@ unsafe fn test_sse41() {
         let mask = _mm_set1_epi8(0b101);
         let r = _mm_testnzc_si128(a, mask);
         assert_eq!(r, 0);
+
+        let a = _mm_setr_epi32(0b100, 0, 0, 0b010);
+        let mask = _mm_setr_epi32(0b100, 0, 0, 0b110);
+        let r = _mm_testnzc_si128(a, mask);
+        assert_eq!(r, 1);
     }
     test_mm_testnzc_si128();
 }