about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2023-12-22 12:25:46 +0100
committerRalf Jung <post@ralfj.de>2023-12-22 12:25:53 +0100
commite8a4bd17f3e3821e31db58b50facdc8ed134852a (patch)
tree65ea2710c12da8bac7011ffaa4c3554ab86c9ae5
parent84304fc00a2b582b1797a52dfaa954ff36b4ff91 (diff)
downloadrust-e8a4bd17f3e3821e31db58b50facdc8ed134852a.tar.gz
rust-e8a4bd17f3e3821e31db58b50facdc8ed134852a.zip
implement and test simd_masked_load and simd_masked_store
-rw-r--r--src/tools/miri/src/shims/intrinsics/simd.rs48
-rw-r--r--src/tools/miri/tests/pass/portable-simd.rs24
2 files changed, 72 insertions, 0 deletions
diff --git a/src/tools/miri/src/shims/intrinsics/simd.rs b/src/tools/miri/src/shims/intrinsics/simd.rs
index e17c06be9b8..2c8493d8aad 100644
--- a/src/tools/miri/src/shims/intrinsics/simd.rs
+++ b/src/tools/miri/src/shims/intrinsics/simd.rs
@@ -656,6 +656,54 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     }
                 }
             }
+            "masked_load" => {
+                let [mask, ptr, default] = check_arg_count(args)?;
+                let (mask, mask_len) = this.operand_to_simd(mask)?;
+                let ptr = this.read_pointer(ptr)?;
+                let (default, default_len) = this.operand_to_simd(default)?;
+                let (dest, dest_len) = this.place_to_simd(dest)?;
+
+                assert_eq!(dest_len, mask_len);
+                assert_eq!(dest_len, default_len);
+
+                for i in 0..dest_len {
+                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
+                    let default = this.read_immediate(&this.project_index(&default, i)?)?;
+                    let dest = this.project_index(&dest, i)?;
+
+                    let val = if simd_element_to_bool(mask)? {
+                        // Size * u64 is implemented as always checked
+                        #[allow(clippy::arithmetic_side_effects)]
+                        let ptr = ptr.wrapping_offset(dest.layout.size * i, this);
+                        let place = this.ptr_to_mplace(ptr, dest.layout);
+                        this.read_immediate(&place)?
+                    } else {
+                        default
+                    };
+                    this.write_immediate(*val, &dest)?;
+                }
+            }
+            "masked_store" => {
+                let [mask, ptr, vals] = check_arg_count(args)?;
+                let (mask, mask_len) = this.operand_to_simd(mask)?;
+                let ptr = this.read_pointer(ptr)?;
+                let (vals, vals_len) = this.operand_to_simd(vals)?;
+
+                assert_eq!(mask_len, vals_len);
+
+                for i in 0..vals_len {
+                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
+                    let val = this.read_immediate(&this.project_index(&vals, i)?)?;
+
+                    if simd_element_to_bool(mask)? {
+                        // Size * u64 is implemented as always checked
+                        #[allow(clippy::arithmetic_side_effects)]
+                        let ptr = ptr.wrapping_offset(val.layout.size * i, this);
+                        let place = this.ptr_to_mplace(ptr, val.layout);
+                        this.write_immediate(*val, &place)?
+                    };
+                }
+            }
 
             name => throw_unsup_format!("unimplemented intrinsic: `simd_{name}`"),
         }
diff --git a/src/tools/miri/tests/pass/portable-simd.rs b/src/tools/miri/tests/pass/portable-simd.rs
index 3d24943293c..57d0b6a87b2 100644
--- a/src/tools/miri/tests/pass/portable-simd.rs
+++ b/src/tools/miri/tests/pass/portable-simd.rs
@@ -536,6 +536,29 @@ fn simd_intrinsics() {
     }
 }
 
+fn simd_masked_loadstore() {
+    // The buffer is deliberarely too short, so reading the last element would be UB.
+    let buf = [3i32; 3];
+    let default = i32x4::splat(0);
+    let mask = i32x4::from_array([!0, !0, !0, 0]);
+    let vals = unsafe { intrinsics::simd_masked_load(mask, buf.as_ptr(), default) };
+    assert_eq!(vals, i32x4::from_array([3, 3, 3, 0]));
+    // Also read in a way that the *first* element is OOB.
+    let mask2 = i32x4::from_array([0, !0, !0, !0]);
+    let vals =
+        unsafe { intrinsics::simd_masked_load(mask2, buf.as_ptr().wrapping_sub(1), default) };
+    assert_eq!(vals, i32x4::from_array([0, 3, 3, 3]));
+
+    // The buffer is deliberarely too short, so writing the last element would be UB.
+    let mut buf = [42i32; 3];
+    let vals = i32x4::from_array([1, 2, 3, 4]);
+    unsafe { intrinsics::simd_masked_store(mask, buf.as_mut_ptr(), vals) };
+    assert_eq!(buf, [1, 2, 3]);
+    // Also write in a way that the *first* element is OOB.
+    unsafe { intrinsics::simd_masked_store(mask2, buf.as_mut_ptr().wrapping_sub(1), vals) };
+    assert_eq!(buf, [2, 3, 4]);
+}
+
 fn main() {
     simd_mask();
     simd_ops_f32();
@@ -546,4 +569,5 @@ fn main() {
     simd_gather_scatter();
     simd_round();
     simd_intrinsics();
+    simd_masked_loadstore();
 }