about summary refs log tree commit diff
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2024-06-09 09:48:26 +0000
committerbors <bors@rust-lang.org>2024-06-09 09:48:26 +0000
commitde822dc602c1e9da9572b91864bd7aa0101337e0 (patch)
tree337b74253153faa89e9ca09f93551c0c2d7d551e
parent773415d28fc0e710557d7d392aed85fd62ec7418 (diff)
parentba45198c9bff9e7ab7d36171595e78a0f5f636d1 (diff)
downloadrust-de822dc602c1e9da9572b91864bd7aa0101337e0.tar.gz
rust-de822dc602c1e9da9572b91864bd7aa0101337e0.zip
Auto merge of #3662 - RalfJung:simd-bitmask, r=RalfJung
simd_bitmask: work correctly for sizes like 24
-rw-r--r--src/tools/miri/src/helpers.rs3
-rw-r--r--src/tools/miri/src/intrinsics/mod.rs2
-rw-r--r--src/tools/miri/src/intrinsics/simd.rs105
-rw-r--r--src/tools/miri/tests/pass/intrinsics/portable-simd.rs71
4 files changed, 130 insertions, 51 deletions
diff --git a/src/tools/miri/src/helpers.rs b/src/tools/miri/src/helpers.rs
index e6ef2f5dc60..87b53d87ac8 100644
--- a/src/tools/miri/src/helpers.rs
+++ b/src/tools/miri/src/helpers.rs
@@ -374,7 +374,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
         let val = if dest.layout().abi.is_signed() {
             Scalar::from_int(i, dest.layout().size)
         } else {
-            Scalar::from_uint(u64::try_from(i.into()).unwrap(), dest.layout().size)
+            // `unwrap` can only fail here if `i` is negative
+            Scalar::from_uint(u128::try_from(i.into()).unwrap(), dest.layout().size)
         };
         self.eval_context_mut().write_scalar(val, dest)
     }
diff --git a/src/tools/miri/src/intrinsics/mod.rs b/src/tools/miri/src/intrinsics/mod.rs
index 9f7e2baaaf9..74e39ffd933 100644
--- a/src/tools/miri/src/intrinsics/mod.rs
+++ b/src/tools/miri/src/intrinsics/mod.rs
@@ -1,3 +1,5 @@
+#![warn(clippy::arithmetic_side_effects)]
+
 mod atomic;
 mod simd;
 
diff --git a/src/tools/miri/src/intrinsics/simd.rs b/src/tools/miri/src/intrinsics/simd.rs
index df52f223db3..b2b81df55bb 100644
--- a/src/tools/miri/src/intrinsics/simd.rs
+++ b/src/tools/miri/src/intrinsics/simd.rs
@@ -458,26 +458,48 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                     );
                 }
 
-                // The mask must be an integer or an array.
-                assert!(
-                    mask.layout.ty.is_integral()
-                        || matches!(mask.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
-                );
-                assert_eq!(bitmask_len, mask.layout.size.bits());
                 assert_eq!(dest_len, yes_len);
                 assert_eq!(dest_len, no_len);
-                let dest_len = u32::try_from(dest_len).unwrap();
-                let bitmask_len = u32::try_from(bitmask_len).unwrap();
 
-                // To read the mask, we transmute it to an integer.
-                // That does the right thing wrt endianness.
-                let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
-                let mask = mask.transmute(mask_ty, this)?;
-                let mask: u64 = this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap();
+                // Read the mask, either as an integer or as an array.
+                let mask: u64 = match mask.layout.ty.kind() {
+                    ty::Uint(_) => {
+                        // Any larger integer type is fine.
+                        assert!(mask.layout.size.bits() >= bitmask_len);
+                        this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap()
+                    }
+                    ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
+                        // The array must have exactly the right size.
+                        assert_eq!(mask.layout.size.bits(), bitmask_len);
+                        // Read the raw bytes.
+                        let mask = mask.assert_mem_place(); // arrays cannot be immediate
+                        let mask_bytes =
+                            this.read_bytes_ptr_strip_provenance(mask.ptr(), mask.layout.size)?;
+                        // Turn them into a `u64` in the right way.
+                        let mask_size = mask.layout.size.bytes_usize();
+                        let mut mask_arr = [0u8; 8];
+                        match this.data_layout().endian {
+                            Endian::Little => {
+                                // Fill the first N bytes.
+                                mask_arr[..mask_size].copy_from_slice(mask_bytes);
+                                u64::from_le_bytes(mask_arr)
+                            }
+                            Endian::Big => {
+                                // Fill the last N bytes.
+                                let i = mask_arr.len().strict_sub(mask_size);
+                                mask_arr[i..].copy_from_slice(mask_bytes);
+                                u64::from_be_bytes(mask_arr)
+                            }
+                        }
+                    }
+                    _ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
+                };
 
+                let dest_len = u32::try_from(dest_len).unwrap();
+                let bitmask_len = u32::try_from(bitmask_len).unwrap();
                 for i in 0..dest_len {
                     let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian);
-                    let mask = mask & 1u64.checked_shl(bit_i).unwrap();
+                    let mask = mask & 1u64.strict_shl(bit_i);
                     let yes = this.read_immediate(&this.project_index(&yes, i.into())?)?;
                     let no = this.read_immediate(&this.project_index(&no, i.into())?)?;
                     let dest = this.project_index(&dest, i.into())?;
@@ -489,7 +511,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                     // If the mask is "padded", ensure that padding is all-zero.
                     // This deliberately does not use `simd_bitmask_index`; these bits are outside
                     // the bitmask. It does not matter in which order we check them.
-                    let mask = mask & 1u64.checked_shl(i).unwrap();
+                    let mask = mask & 1u64.strict_shl(i);
                     if mask != 0 {
                         throw_ub_format!(
                             "a SIMD bitmask less than 8 bits long must be filled with 0s for the remaining bits"
@@ -508,28 +530,43 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                     );
                 }
 
-                // Returns either an unsigned integer or array of `u8`.
-                assert!(
-                    dest.layout.ty.is_integral()
-                        || matches!(dest.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
-                );
-                assert_eq!(bitmask_len, dest.layout.size.bits());
                 let op_len = u32::try_from(op_len).unwrap();
-
                 let mut res = 0u64;
                 for i in 0..op_len {
                     let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
                     if simd_element_to_bool(op)? {
-                        res |= 1u64
-                            .checked_shl(simd_bitmask_index(i, op_len, this.data_layout().endian))
-                            .unwrap();
+                        let bit_i = simd_bitmask_index(i, op_len, this.data_layout().endian);
+                        res |= 1u64.strict_shl(bit_i);
+                    }
+                }
+                // Write the result, depending on the `dest` type.
+                // Returns either an unsigned integer or array of `u8`.
+                match dest.layout.ty.kind() {
+                    ty::Uint(_) => {
+                        // Any larger integer type is fine, it will be zero-extended.
+                        assert!(dest.layout.size.bits() >= bitmask_len);
+                        this.write_int(res, dest)?;
+                    }
+                    ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
+                        // The array must have exactly the right size.
+                        assert_eq!(dest.layout.size.bits(), bitmask_len);
+                        // We have to write the result byte-for-byte.
+                        let res_size = dest.layout.size.bytes_usize();
+                        let res_bytes;
+                        let res_bytes_slice = match this.data_layout().endian {
+                            Endian::Little => {
+                                res_bytes = res.to_le_bytes();
+                                &res_bytes[..res_size] // take the first N bytes
+                            }
+                            Endian::Big => {
+                                res_bytes = res.to_be_bytes();
+                                &res_bytes[res_bytes.len().strict_sub(res_size)..] // take the last N bytes
+                            }
+                        };
+                        this.write_bytes_ptr(dest.ptr(), res_bytes_slice.iter().cloned())?;
                     }
+                    _ => bug!("simd_bitmask: invalid return type {}", dest.layout.ty),
                 }
-                // We have to change the type of the place to be able to write `res` into it. This
-                // transmutes the integer to an array, which does the right thing wrt endianness.
-                let dest =
-                    dest.transmute(this.machine.layouts.uint(dest.layout.size).unwrap(), this)?;
-                this.write_int(res, &dest)?;
             }
             "cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
                 let [op] = check_arg_count(args)?;
@@ -615,8 +652,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
 
                     let val = if src_index < left_len {
                         this.read_immediate(&this.project_index(&left, src_index)?)?
-                    } else if src_index < left_len.checked_add(right_len).unwrap() {
-                        let right_idx = src_index.checked_sub(left_len).unwrap();
+                    } else if src_index < left_len.strict_add(right_len) {
+                        let right_idx = src_index.strict_sub(left_len);
                         this.read_immediate(&this.project_index(&right, right_idx)?)?
                     } else {
                         throw_ub_format!(
@@ -655,8 +692,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
 
                     let val = if src_index < left_len {
                         this.read_immediate(&this.project_index(&left, src_index)?)?
-                    } else if src_index < left_len.checked_add(right_len).unwrap() {
-                        let right_idx = src_index.checked_sub(left_len).unwrap();
+                    } else if src_index < left_len.strict_add(right_len) {
+                        let right_idx = src_index.strict_sub(left_len);
                         this.read_immediate(&this.project_index(&right, right_idx)?)?
                     } else {
                         throw_ub_format!(
diff --git a/src/tools/miri/tests/pass/intrinsics/portable-simd.rs b/src/tools/miri/tests/pass/intrinsics/portable-simd.rs
index de87f153b6f..c4ec45e4bbe 100644
--- a/src/tools/miri/tests/pass/intrinsics/portable-simd.rs
+++ b/src/tools/miri/tests/pass/intrinsics/portable-simd.rs
@@ -323,38 +323,77 @@ fn simd_mask() {
     #[repr(simd, packed)]
     #[allow(non_camel_case_types)]
     #[derive(Copy, Clone, Debug, PartialEq)]
-    struct i32x10(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32);
+    struct i32x10([i32; 10]);
     impl i32x10 {
         fn splat(x: i32) -> Self {
-            Self(x, x, x, x, x, x, x, x, x, x)
-        }
-        fn from_array(a: [i32; 10]) -> Self {
-            unsafe { std::mem::transmute(a) }
+            Self([x; 10])
         }
     }
     unsafe {
-        let mask = i32x10::from_array([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0]);
+        let mask = i32x10([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0]);
+        let mask_bits = if cfg!(target_endian = "little") { 0b0101001011 } else { 0b1101001010 };
+        let mask_bytes =
+            if cfg!(target_endian = "little") { [0b01001011, 0b01] } else { [0b11, 0b01001010] };
+
         let bitmask1: u16 = simd_bitmask(mask);
         let bitmask2: [u8; 2] = simd_bitmask(mask);
-        if cfg!(target_endian = "little") {
-            assert_eq!(bitmask1, 0b0101001011);
-            assert_eq!(bitmask2, [0b01001011, 0b01]);
-        } else {
-            assert_eq!(bitmask1, 0b1101001010);
-            assert_eq!(bitmask2, [0b11, 0b01001010]);
-        }
+        assert_eq!(bitmask1, mask_bits);
+        assert_eq!(bitmask2, mask_bytes);
+
         let selected1 = simd_select_bitmask::<u16, _>(
-            if cfg!(target_endian = "little") { 0b0101001011 } else { 0b1101001010 },
+            mask_bits,
             i32x10::splat(!0), // yes
             i32x10::splat(0),  // no
         );
         let selected2 = simd_select_bitmask::<[u8; 2], _>(
-            if cfg!(target_endian = "little") { [0b01001011, 0b01] } else { [0b11, 0b01001010] },
+            mask_bytes,
             i32x10::splat(!0), // yes
             i32x10::splat(0),  // no
         );
         assert_eq!(selected1, mask);
-        assert_eq!(selected2, selected1);
+        assert_eq!(selected2, mask);
+    }
+
+    // Test for a mask where the next multiple of 8 is not a power of two.
+    #[repr(simd, packed)]
+    #[allow(non_camel_case_types)]
+    #[derive(Copy, Clone, Debug, PartialEq)]
+    struct i32x20([i32; 20]);
+    impl i32x20 {
+        fn splat(x: i32) -> Self {
+            Self([x; 20])
+        }
+    }
+    unsafe {
+        let mask = i32x20([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0, 0, 0, 0, !0, !0, !0, !0, !0, !0, !0]);
+        let mask_bits = if cfg!(target_endian = "little") {
+            0b11111110000101001011
+        } else {
+            0b11010010100001111111
+        };
+        let mask_bytes = if cfg!(target_endian = "little") {
+            [0b01001011, 0b11100001, 0b1111]
+        } else {
+            [0b1101, 0b00101000, 0b01111111]
+        };
+
+        let bitmask1: u32 = simd_bitmask(mask);
+        let bitmask2: [u8; 3] = simd_bitmask(mask);
+        assert_eq!(bitmask1, mask_bits);
+        assert_eq!(bitmask2, mask_bytes);
+
+        let selected1 = simd_select_bitmask::<u32, _>(
+            mask_bits,
+            i32x20::splat(!0), // yes
+            i32x20::splat(0),  // no
+        );
+        let selected2 = simd_select_bitmask::<[u8; 3], _>(
+            mask_bytes,
+            i32x20::splat(!0), // yes
+            i32x20::splat(0),  // no
+        );
+        assert_eq!(selected1, mask);
+        assert_eq!(selected2, mask);
     }
 }