about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/core_simd/src/masks/full_masks.rs90
-rw-r--r--crates/core_simd/src/swizzle.rs16
-rw-r--r--crates/core_simd/tests/masks.rs9
3 files changed, 90 insertions, 25 deletions
diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs
index 73a0d898700..a529490f3a2 100644
--- a/crates/core_simd/src/masks/full_masks.rs
+++ b/crates/core_simd/src/masks/full_masks.rs
@@ -207,40 +207,108 @@ where
     }
 
     #[inline]
-    pub(crate) fn to_bitmask_integer(self) -> u64 {
-        let resized = self.to_int().extend::<64>(T::FALSE);
+    unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
+    where
+        LaneCount<M>: SupportedLaneCount,
+    {
+        let resized = self.to_int().resize::<M>(T::FALSE);
 
-        // SAFETY: `resized` is an integer vector with length 64
-        let bitmask: u64 = unsafe { intrinsics::simd_bitmask(resized) };
+        // Safety: `resized` is an integer vector with length M, which must match T
+        let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };
 
         // LLVM assumes bit order should match endianness
         if cfg!(target_endian = "big") {
-            bitmask.reverse_bits()
+            bitmask.reverse_bits(M)
         } else {
             bitmask
         }
     }
 
     #[inline]
-    pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
+    unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
+    where
+        LaneCount<M>: SupportedLaneCount,
+    {
         // LLVM assumes bit order should match endianness
         let bitmask = if cfg!(target_endian = "big") {
-            bitmask.reverse_bits()
+            bitmask.reverse_bits(M)
         } else {
             bitmask
         };
 
         // SAFETY: `mask` is the correct bitmask type for a u64 bitmask
-        let mask: Simd<T, 64> = unsafe {
+        let mask: Simd<T, M> = unsafe {
             intrinsics::simd_select_bitmask(
                 bitmask,
-                Simd::<T, 64>::splat(T::TRUE),
-                Simd::<T, 64>::splat(T::FALSE),
+                Simd::<T, M>::splat(T::TRUE),
+                Simd::<T, M>::splat(T::FALSE),
             )
         };
 
         // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
-        unsafe { Self::from_int_unchecked(mask.extend::<N>(T::FALSE)) }
+        unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
+    }
+
+    #[inline]
+    pub(crate) fn to_bitmask_integer(self) -> u64 {
+        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
+        macro_rules! bitmask {
+            { $($ty:ty: $($len:literal),*;)* } => {
+                match N {
+                    $($(
+                    // Safety: bitmask matches length
+                    $len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 },
+                    )*)*
+                    // Safety: bitmask matches length
+                    _ => unsafe { self.to_bitmask_impl::<u64, 64>() },
+                }
+            }
+        }
+        #[cfg(all_lane_counts)]
+        bitmask! {
+            u8: 1, 2, 3, 4, 5, 6, 7, 8;
+            u16: 9, 10, 11, 12, 13, 14, 15, 16;
+            u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
+            u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
+        }
+        #[cfg(not(all_lane_counts))]
+        bitmask! {
+            u8: 1, 2, 4, 8;
+            u16: 16;
+            u32: 32;
+            u64: 64;
+        }
+    }
+
+    #[inline]
+    pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
+        // TODO modify simd_bitmask_select to truncate input, making this unnecessary
+        macro_rules! bitmask {
+            { $($ty:ty: $($len:literal),*;)* } => {
+                match N {
+                    $($(
+                    // Safety: bitmask matches length
+                    $len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) },
+                    )*)*
+                    // Safety: bitmask matches length
+                    _ => unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) },
+                }
+            }
+        }
+        #[cfg(all_lane_counts)]
+        bitmask! {
+            u8: 1, 2, 3, 4, 5, 6, 7, 8;
+            u16: 9, 10, 11, 12, 13, 14, 15, 16;
+            u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
+            u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
+        }
+        #[cfg(not(all_lane_counts))]
+        bitmask! {
+            u8: 1, 2, 4, 8;
+            u16: 16;
+            u32: 32;
+            u64: 64;
+        }
     }
 
     #[inline]
diff --git a/crates/core_simd/src/swizzle.rs b/crates/core_simd/src/swizzle.rs
index e5b3d4444d8..ec8548d5574 100644
--- a/crates/core_simd/src/swizzle.rs
+++ b/crates/core_simd/src/swizzle.rs
@@ -350,9 +350,9 @@ where
         )
     }
 
-    /// Extend a vector.
+    /// Resize a vector.
     ///
-    /// Extends the length of a vector, setting the new elements to `value`.
+    /// If `M` > `N`, extends the length of a vector, setting the new elements to `value`.
     /// If `M` < `N`, truncates the vector to the first `M` elements.
     ///
     /// ```
@@ -361,17 +361,17 @@ where
     /// # #[cfg(not(feature = "as_crate"))] use core::simd;
     /// # use simd::u32x4;
     /// let x = u32x4::from_array([0, 1, 2, 3]);
-    /// assert_eq!(x.extend::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]);
-    /// assert_eq!(x.extend::<2>(9).to_array(), [0, 1]);
+    /// assert_eq!(x.resize::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]);
+    /// assert_eq!(x.resize::<2>(9).to_array(), [0, 1]);
     /// ```
     #[inline]
     #[must_use = "method returns a new vector and does not mutate the original inputs"]
-    pub fn extend<const M: usize>(self, value: T) -> Simd<T, M>
+    pub fn resize<const M: usize>(self, value: T) -> Simd<T, M>
     where
         LaneCount<M>: SupportedLaneCount,
     {
-        struct Extend<const N: usize>;
-        impl<const N: usize, const M: usize> Swizzle<M> for Extend<N> {
+        struct Resize<const N: usize>;
+        impl<const N: usize, const M: usize> Swizzle<M> for Resize<N> {
             const INDEX: [usize; M] = const {
                 let mut index = [0; M];
                 let mut i = 0;
@@ -382,6 +382,6 @@ where
                 index
             };
         }
-        Extend::<N>::concat_swizzle(self, Simd::splat(value))
+        Resize::<N>::concat_swizzle(self, Simd::splat(value))
     }
 }
diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs
index 92ee53b3e55..00fc2a24e27 100644
--- a/crates/core_simd/tests/masks.rs
+++ b/crates/core_simd/tests/masks.rs
@@ -13,7 +13,7 @@ macro_rules! test_mask_api {
             #[cfg(target_arch = "wasm32")]
             use wasm_bindgen_test::*;
 
-            use core_simd::simd::{Mask, Simd};
+            use core_simd::simd::Mask;
 
             #[test]
             #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
@@ -124,17 +124,14 @@ macro_rules! test_mask_api {
 
             #[test]
             fn roundtrip_bitmask_vector_conversion() {
+                use core_simd::simd::ToBytes;
                 let values = [
                     true, false, false, true, false, false, true, false,
                     true, true, false, false, false, false, false, true,
                 ];
                 let mask = Mask::<$type, 16>::from_array(values);
                 let bitmask = mask.to_bitmask_vector();
-                if core::mem::size_of::<$type>() == 1 {
-                    assert_eq!(bitmask, Simd::from_array([0b01001001 as _, 0b10000011 as _, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]));
-                } else {
-                    assert_eq!(bitmask, Simd::from_array([0b1000001101001001 as _, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]));
-                }
+                assert_eq!(bitmask.resize::<2>(0).to_ne_bytes()[..2], [0b01001001, 0b10000011]);
                 assert_eq!(Mask::<$type, 16>::from_bitmask_vector(bitmask), mask);
             }
         }