about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2023-12-03 11:28:40 +0100
committerRalf Jung <post@ralfj.de>2023-12-03 11:28:40 +0100
commit11db9de7289d0405f40cfd3cb0f9acda7e5b9564 (patch)
treea828d091fda3e8e626df2c5302baf52b0ab8e29e /src
parent6318e9dd89c6719330b4237db1a70daaacb8c527 (diff)
downloadrust-11db9de7289d0405f40cfd3cb0f9acda7e5b9564.tar.gz
rust-11db9de7289d0405f40cfd3cb0f9acda7e5b9564.zip
simd_select_bitmask: support passing the mask as an array
Diffstat (limited to 'src')
-rw-r--r--src/tools/miri/src/shims/intrinsics/simd.rs15
-rw-r--r--src/tools/miri/tests/pass/portable-simd.rs16
2 files changed, 28 insertions, 3 deletions
diff --git a/src/tools/miri/src/shims/intrinsics/simd.rs b/src/tools/miri/src/shims/intrinsics/simd.rs
index d0a293d5f81..63af0814c8b 100644
--- a/src/tools/miri/src/shims/intrinsics/simd.rs
+++ b/src/tools/miri/src/shims/intrinsics/simd.rs
@@ -386,7 +386,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 let (dest, dest_len) = this.place_to_simd(dest)?;
                 let bitmask_len = dest_len.max(8);
 
-                assert!(mask.layout.ty.is_integral());
                 assert!(bitmask_len <= 64);
                 assert_eq!(bitmask_len, mask.layout.size.bits());
                 assert_eq!(dest_len, yes_len);
@@ -394,8 +393,18 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 let dest_len = u32::try_from(dest_len).unwrap();
                 let bitmask_len = u32::try_from(bitmask_len).unwrap();
 
-                let mask: u64 =
-                    this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap();
+                // The mask can be a single integer or an array.
+                let mask: u64 = match mask.layout.ty.kind() {
+                    ty::Int(..) | ty::Uint(..) =>
+                        this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap(),
+                    ty::Array(elem, _) if matches!(elem.kind(), ty::Uint(ty::UintTy::U8)) => {
+                        let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
+                        let mask = mask.transmute(mask_ty, this)?;
+                        this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap()
+                    }
+                    _ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
+                };
+
                 for i in 0..dest_len {
                     let mask = mask
                         & 1u64
diff --git a/src/tools/miri/tests/pass/portable-simd.rs b/src/tools/miri/tests/pass/portable-simd.rs
index 2179bcf1c38..1ef9d8f38c0 100644
--- a/src/tools/miri/tests/pass/portable-simd.rs
+++ b/src/tools/miri/tests/pass/portable-simd.rs
@@ -247,6 +247,22 @@ fn simd_mask() {
             assert_eq!(bitmask2, [0b0001]);
         }
     }
+
+    // This used to cause an ICE.
+    let bitmask = u8x8::from_array([0b01000101, 0, 0, 0, 0, 0, 0, 0]);
+    assert_eq!(
+        mask32x8::from_bitmask_vector(bitmask),
+        mask32x8::from_array([true, false, true, false, false, false, true, false]),
+    );
+    let bitmask =
+        u8x16::from_array([0b01000101, 0b11110000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
+    assert_eq!(
+        mask32x16::from_bitmask_vector(bitmask),
+        mask32x16::from_array([
+            true, false, true, false, false, false, true, false, false, false, false, false, true,
+            true, true, true,
+        ]),
+    );
 }
 
 fn simd_cast() {