about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-12-03 15:21:45 +0000
committerbors <bors@rust-lang.org>2023-12-03 15:21:45 +0000
commit6da09594e6db274f8d0d18f42c4067c610b6520a (patch)
tree954d167b9ad02d732849eebb835a4315c8f68cf0
parent28f9fe326253593442d175a847307a4b1a1063d2 (diff)
parent6e74d2ad500476007642982cbc6f2db6b6655300 (diff)
downloadrust-6da09594e6db274f8d0d18f42c4067c610b6520a.tar.gz
rust-6da09594e6db274f8d0d18f42c4067c610b6520a.zip
Auto merge of #3205 - RalfJung:simd-bitmask, r=RalfJung
also test simd_select_bitmask on arrays for less than 8 elements
-rw-r--r--src/tools/miri/src/shims/intrinsics/simd.rs88
-rw-r--r--src/tools/miri/tests/pass/portable-simd.rs54
2 files changed, 94 insertions, 48 deletions
diff --git a/src/tools/miri/src/shims/intrinsics/simd.rs b/src/tools/miri/src/shims/intrinsics/simd.rs
index e16d116f621..4d30eebb1af 100644
--- a/src/tools/miri/src/shims/intrinsics/simd.rs
+++ b/src/tools/miri/src/shims/intrinsics/simd.rs
@@ -405,6 +405,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     this.write_immediate(*val, &dest)?;
                 }
             }
+            // Variant of `select` that takes a bitmask rather than a "vector of bool".
             "select_bitmask" => {
                 let [mask, yes, no] = check_arg_count(args)?;
                 let (yes, yes_len) = this.operand_to_simd(yes)?;
@@ -412,6 +413,11 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 let (dest, dest_len) = this.place_to_simd(dest)?;
                 let bitmask_len = dest_len.max(8);
 
+                // The mask must be an integer or an array.
+                assert!(
+                    mask.layout.ty.is_integral()
+                        || matches!(mask.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
+                );
                 assert!(bitmask_len <= 64);
                 assert_eq!(bitmask_len, mask.layout.size.bits());
                 assert_eq!(dest_len, yes_len);
@@ -419,23 +425,15 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 let dest_len = u32::try_from(dest_len).unwrap();
                 let bitmask_len = u32::try_from(bitmask_len).unwrap();
 
-                // The mask can be a single integer or an array.
-                let mask: u64 = match mask.layout.ty.kind() {
-                    ty::Int(..) | ty::Uint(..) =>
-                        this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap(),
-                    ty::Array(elem, _) if matches!(elem.kind(), ty::Uint(ty::UintTy::U8)) => {
-                        let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
-                        let mask = mask.transmute(mask_ty, this)?;
-                        this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap()
-                    }
-                    _ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
-                };
+                // To read the mask, we transmute it to an integer.
+                // That does the right thing wrt endianess.
+                let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
+                let mask = mask.transmute(mask_ty, this)?;
+                let mask: u64 = this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap();
 
                 for i in 0..dest_len {
-                    let mask = mask
-                        & 1u64
-                            .checked_shl(simd_bitmask_index(i, dest_len, this.data_layout().endian))
-                            .unwrap();
+                    let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian);
+                    let mask = mask & 1u64.checked_shl(bit_i).unwrap();
                     let yes = this.read_immediate(&this.project_index(&yes, i.into())?)?;
                     let no = this.read_immediate(&this.project_index(&no, i.into())?)?;
                     let dest = this.project_index(&dest, i.into())?;
@@ -445,6 +443,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 }
                 for i in dest_len..bitmask_len {
                     // If the mask is "padded", ensure that padding is all-zero.
+                    // This deliberately does not use `simd_bitmask_index`; these bits are outside
+                    // the bitmask. It does not matter in which order we check them.
                     let mask = mask & 1u64.checked_shl(i).unwrap();
                     if mask != 0 {
                         throw_ub_format!(
@@ -453,6 +453,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     }
                 }
             }
+            // Converts a "vector of bool" into a bitmask.
+            "bitmask" => {
+                let [op] = check_arg_count(args)?;
+                let (op, op_len) = this.operand_to_simd(op)?;
+                let bitmask_len = op_len.max(8);
+
+                // Returns either an unsigned integer or array of `u8`.
+                assert!(
+                    dest.layout.ty.is_integral()
+                        || matches!(dest.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
+                );
+                assert!(bitmask_len <= 64);
+                assert_eq!(bitmask_len, dest.layout.size.bits());
+                let op_len = u32::try_from(op_len).unwrap();
+
+                let mut res = 0u64;
+                for i in 0..op_len {
+                    let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
+                    if simd_element_to_bool(op)? {
+                        res |= 1u64
+                            .checked_shl(simd_bitmask_index(i, op_len, this.data_layout().endian))
+                            .unwrap();
+                    }
+                }
+                // We have to change the type of the place to be able to write `res` into it. This
+                // transmutes the integer to an array, which does the right thing wrt endianess.
+                let dest =
+                    dest.transmute(this.machine.layouts.uint(dest.layout.size).unwrap(), this)?;
+                this.write_int(res, &dest)?;
+            }
             "cast" | "as" | "cast_ptr" | "expose_addr" | "from_exposed_addr" => {
                 let [op] = check_arg_count(args)?;
                 let (op, op_len) = this.operand_to_simd(op)?;
@@ -635,34 +665,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     }
                 }
             }
-            "bitmask" => {
-                let [op] = check_arg_count(args)?;
-                let (op, op_len) = this.operand_to_simd(op)?;
-                let bitmask_len = op_len.max(8);
-
-                // Returns either an unsigned integer or array of `u8`.
-                assert!(
-                    dest.layout.ty.is_integral()
-                        || matches!(dest.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
-                );
-                assert!(bitmask_len <= 64);
-                assert_eq!(bitmask_len, dest.layout.size.bits());
-                let op_len = u32::try_from(op_len).unwrap();
-
-                let mut res = 0u64;
-                for i in 0..op_len {
-                    let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
-                    if simd_element_to_bool(op)? {
-                        res |= 1u64
-                            .checked_shl(simd_bitmask_index(i, op_len, this.data_layout().endian))
-                            .unwrap();
-                    }
-                }
-                // We have to force the place type to be an int so that we can write `res` into it.
-                let mut dest = this.force_allocation(dest)?;
-                dest.layout = this.machine.layouts.uint(dest.layout.size).unwrap();
-                this.write_int(res, &dest)?;
-            }
 
             name => throw_unsup_format!("unimplemented intrinsic: `simd_{name}`"),
         }
diff --git a/src/tools/miri/tests/pass/portable-simd.rs b/src/tools/miri/tests/pass/portable-simd.rs
index 514e12fffc5..184cc3d22e0 100644
--- a/src/tools/miri/tests/pass/portable-simd.rs
+++ b/src/tools/miri/tests/pass/portable-simd.rs
@@ -3,10 +3,6 @@
 #![allow(incomplete_features, internal_features)]
 use std::simd::{prelude::*, StdFloat};
 
-extern "platform-intrinsic" {
-    pub(crate) fn simd_bitmask<T, U>(x: T) -> U;
-}
-
 fn simd_ops_f32() {
     let a = f32x4::splat(10.0);
     let b = f32x4::from_array([1.0, 2.0, 3.0, -4.0]);
@@ -218,6 +214,11 @@ fn simd_ops_i32() {
 }
 
 fn simd_mask() {
+    extern "platform-intrinsic" {
+        pub(crate) fn simd_bitmask<T, U>(x: T) -> U;
+        pub(crate) fn simd_select_bitmask<M, T>(m: M, yes: T, no: T) -> T;
+    }
+
     let intmask = Mask::from_int(i32x4::from_array([0, -1, 0, 0]));
     assert_eq!(intmask, Mask::from_array([false, true, false, false]));
     assert_eq!(intmask.to_array(), [false, true, false, false]);
@@ -266,7 +267,16 @@ fn simd_mask() {
         }
     }
 
-    // This used to cause an ICE.
+    // This used to cause an ICE. It exercises simd_select_bitmask with an array as input.
+    if cfg!(target_endian = "little") {
+        // FIXME this test currently fails on big-endian:
+        // <https://github.com/rust-lang/portable-simd/issues/379>
+        let bitmask = u8x4::from_array([0b00001101, 0, 0, 0]);
+        assert_eq!(
+            mask32x4::from_bitmask_vector(bitmask),
+            mask32x4::from_array([true, false, true, true]),
+        );
+    }
     let bitmask = u8x8::from_array([0b01000101, 0, 0, 0, 0, 0, 0, 0]);
     assert_eq!(
         mask32x8::from_bitmask_vector(bitmask),
@@ -281,6 +291,40 @@ fn simd_mask() {
             true, true, true,
         ]),
     );
+
+    // Also directly call simd_select_bitmask, to test both kinds of argument types.
+    unsafe {
+        // These masks are exactly the results we got out above in the `simd_bitmask` tests.
+        let selected1 = simd_select_bitmask::<u16, _>(
+            if cfg!(target_endian = "little") { 0b1010001101001001 } else { 0b1001001011000101 },
+            i32x16::splat(1), // yes
+            i32x16::splat(0), // no
+        );
+        let selected2 = simd_select_bitmask::<[u8; 2], _>(
+            if cfg!(target_endian = "little") {
+                [0b01001001, 0b10100011]
+            } else {
+                [0b10010010, 0b11000101]
+            },
+            i32x16::splat(1), // yes
+            i32x16::splat(0), // no
+        );
+        assert_eq!(selected1, i32x16::from_array([1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1]));
+        assert_eq!(selected2, selected1);
+        // Also try masks less than a byte long.
+        let selected1 = simd_select_bitmask::<u8, _>(
+            if cfg!(target_endian = "little") { 0b1000 } else { 0b0001 },
+            i32x4::splat(1), // yes
+            i32x4::splat(0), // no
+        );
+        let selected2 = simd_select_bitmask::<[u8; 1], _>(
+            if cfg!(target_endian = "little") { [0b1000] } else { [0b0001] },
+            i32x4::splat(1), // yes
+            i32x4::splat(0), // no
+        );
+        assert_eq!(selected1, i32x4::from_array([0, 0, 0, 1]));
+        assert_eq!(selected2, selected1);
+    }
 }
 
 fn simd_cast() {