about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCaleb Zulawski <caleb.zulawski@gmail.com>2023-11-18 22:05:02 -0500
committerGitHub <noreply@github.com>2023-11-18 22:05:02 -0500
commit7e5c03a33db5548fbd78a22656a60574dbf0788f (patch)
treee183c29c2019f2a457692826c80f1096867cb67b
parent8d9bcda64cfe5f4dd172620d5d0eacadbdb13751 (diff)
parent0ad68db91a3149885bc62ae11d2d83e7d401fc25 (diff)
downloadrust-7e5c03a33db5548fbd78a22656a60574dbf0788f.tar.gz
rust-7e5c03a33db5548fbd78a22656a60574dbf0788f.zip
Merge pull request #375 from rust-lang/bitmask
Simplify bitmasks
-rw-r--r--crates/core_simd/src/masks.rs42
-rw-r--r--crates/core_simd/src/masks/bitmask.rs46
-rw-r--r--crates/core_simd/src/masks/full_masks.rs140
-rw-r--r--crates/core_simd/src/masks/to_bitmask.rs111
-rw-r--r--crates/core_simd/src/swizzle.rs35
-rw-r--r--crates/core_simd/tests/masks.rs13
6 files changed, 202 insertions, 185 deletions
diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs
index 1199153a5bd..63731342423 100644
--- a/crates/core_simd/src/masks.rs
+++ b/crates/core_simd/src/masks.rs
@@ -12,9 +12,6 @@
 )]
 mod mask_impl;
 
-mod to_bitmask;
-pub use to_bitmask::{ToBitMask, ToBitMaskArray};
-
 use crate::simd::{
     cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount,
 };
@@ -262,6 +259,45 @@ where
     pub fn all(self) -> bool {
         self.0.all()
     }
+
+    /// Create a bitmask from a mask.
+    ///
+    /// Each bit is set if the corresponding element in the mask is `true`.
+    /// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
+    #[inline]
+    #[must_use = "method returns a new integer and does not mutate the original value"]
+    pub fn to_bitmask(self) -> u64 {
+        self.0.to_bitmask_integer()
+    }
+
+    /// Create a mask from a bitmask.
+    ///
+    /// For each bit, if it is set, the corresponding element in the mask is set to `true`.
+    /// If the mask contains more than 64 elements, the remainder are set to `false`.
+    #[inline]
+    #[must_use = "method returns a new mask and does not mutate the original value"]
+    pub fn from_bitmask(bitmask: u64) -> Self {
+        Self(mask_impl::Mask::from_bitmask_integer(bitmask))
+    }
+
+    /// Create a bitmask vector from a mask.
+    ///
+    /// Each bit is set if the corresponding element in the mask is `true`.
+    /// The remaining bits are unset.
+    #[inline]
+    #[must_use = "method returns a new integer and does not mutate the original value"]
+    pub fn to_bitmask_vector(self) -> Simd<u8, N> {
+        self.0.to_bitmask_vector()
+    }
+
+    /// Create a mask from a bitmask vector.
+    ///
+    /// For each bit, if it is set, the corresponding element in the mask is set to `true`.
+    #[inline]
+    #[must_use = "method returns a new mask and does not mutate the original value"]
+    pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
+        Self(mask_impl::Mask::from_bitmask_vector(bitmask))
+    }
 }
 
 // vector/array conversion
diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs
index aaae28a07be..6ddff07fea2 100644
--- a/crates/core_simd/src/masks/bitmask.rs
+++ b/crates/core_simd/src/masks/bitmask.rs
@@ -1,7 +1,7 @@
 #![allow(unused_imports)]
 use super::MaskElement;
 use crate::simd::intrinsics;
-use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
+use crate::simd::{LaneCount, Simd, SupportedLaneCount};
 use core::marker::PhantomData;
 
 /// A mask where each lane is represented by a single bit.
@@ -120,39 +120,37 @@ where
     }
 
     #[inline]
-    #[must_use = "method returns a new array and does not mutate the original value"]
-    pub fn to_bitmask_array<const M: usize>(self) -> [u8; M] {
-        assert!(core::mem::size_of::<Self>() == M);
-
-        // Safety: converting an integer to an array of bytes of the same size is safe
-        unsafe { core::mem::transmute_copy(&self.0) }
+    #[must_use = "method returns a new vector and does not mutate the original value"]
+    pub fn to_bitmask_vector(self) -> Simd<u8, N> {
+        let mut bitmask = Simd::splat(0);
+        bitmask.as_mut_array()[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
+        bitmask
     }
 
     #[inline]
     #[must_use = "method returns a new mask and does not mutate the original value"]
-    pub fn from_bitmask_array<const M: usize>(bitmask: [u8; M]) -> Self {
-        assert!(core::mem::size_of::<Self>() == M);
-
-        // Safety: converting an array of bytes to an integer of the same size is safe
-        Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
+    pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
+        let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
+        let len = bytes.as_ref().len();
+        bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);
+        Self(bytes, PhantomData)
     }
 
     #[inline]
-    pub fn to_bitmask_integer<U>(self) -> U
-    where
-        super::Mask<T, N>: ToBitMask<BitMask = U>,
-    {
-        // Safety: these are the same types
-        unsafe { core::mem::transmute_copy(&self.0) }
+    pub fn to_bitmask_integer(self) -> u64 {
+        let mut bitmask = [0u8; 8];
+        bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
+        u64::from_ne_bytes(bitmask)
     }
 
     #[inline]
-    pub fn from_bitmask_integer<U>(bitmask: U) -> Self
-    where
-        super::Mask<T, N>: ToBitMask<BitMask = U>,
-    {
-        // Safety: these are the same types
-        unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
+    pub fn from_bitmask_integer(bitmask: u64) -> Self {
+        let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
+        let len = bytes.as_mut().len();
+        bytes
+            .as_mut()
+            .copy_from_slice(&bitmask.to_ne_bytes()[..len]);
+        Self(bytes, PhantomData)
     }
 
     #[inline]
diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs
index 2aa9272ab46..0d17e90c128 100644
--- a/crates/core_simd/src/masks/full_masks.rs
+++ b/crates/core_simd/src/masks/full_masks.rs
@@ -1,8 +1,7 @@
 //! Masks that take up full SIMD vector registers.
 
-use super::{to_bitmask::ToBitMaskArray, MaskElement};
 use crate::simd::intrinsics;
-use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
+use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};
 
 #[repr(transparent)]
 pub struct Mask<T, const N: usize>(Simd<T, N>)
@@ -143,53 +142,49 @@ where
     }
 
     #[inline]
-    #[must_use = "method returns a new array and does not mutate the original value"]
-    pub fn to_bitmask_array<const M: usize>(self) -> [u8; M]
-    where
-        super::Mask<T, N>: ToBitMaskArray,
-    {
+    #[must_use = "method returns a new vector and does not mutate the original value"]
+    pub fn to_bitmask_vector(self) -> Simd<u8, N> {
+        let mut bitmask = Simd::splat(0);
+
         // Safety: Bytes is the right size array
         unsafe {
             // Compute the bitmask
-            let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
+            let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
                 intrinsics::simd_bitmask(self.0);
 
-            // Transmute to the return type
-            let mut bitmask: [u8; M] = core::mem::transmute_copy(&bitmask);
-
             // LLVM assumes bit order should match endianness
             if cfg!(target_endian = "big") {
-                for x in bitmask.as_mut() {
-                    *x = x.reverse_bits();
+                for x in bytes.as_mut() {
+                    *x = x.reverse_bits()
                 }
-            };
+            }
 
-            bitmask
+            bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
         }
+
+        bitmask
     }
 
     #[inline]
     #[must_use = "method returns a new mask and does not mutate the original value"]
-    pub fn from_bitmask_array<const M: usize>(mut bitmask: [u8; M]) -> Self
-    where
-        super::Mask<T, N>: ToBitMaskArray,
-    {
+    pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
+        let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
+
         // Safety: Bytes is the right size array
         unsafe {
+            let len = bytes.as_ref().len();
+            bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);
+
             // LLVM assumes bit order should match endianness
             if cfg!(target_endian = "big") {
-                for x in bitmask.as_mut() {
+                for x in bytes.as_mut() {
                     *x = x.reverse_bits();
                 }
             }
 
-            // Transmute to the bitmask
-            let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
-                core::mem::transmute_copy(&bitmask);
-
             // Compute the regular mask
             Self::from_int_unchecked(intrinsics::simd_select_bitmask(
-                bitmask,
+                bytes,
                 Self::splat(true).to_int(),
                 Self::splat(false).to_int(),
             ))
@@ -197,40 +192,107 @@ where
     }
 
     #[inline]
-    pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
+    unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
     where
-        super::Mask<T, N>: ToBitMask<BitMask = U>,
+        LaneCount<M>: SupportedLaneCount,
     {
-        // Safety: U is required to be the appropriate bitmask type
-        let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
+        let resized = self.to_int().resize::<M>(T::FALSE);
+
+        // Safety: `resized` is an integer vector with length M, which must match T
+        let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };
 
         // LLVM assumes bit order should match endianness
         if cfg!(target_endian = "big") {
-            bitmask.reverse_bits(N)
+            bitmask.reverse_bits(M)
         } else {
             bitmask
         }
     }
 
     #[inline]
-    pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
+    unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
     where
-        super::Mask<T, N>: ToBitMask<BitMask = U>,
+        LaneCount<M>: SupportedLaneCount,
     {
         // LLVM assumes bit order should match endianness
         let bitmask = if cfg!(target_endian = "big") {
-            bitmask.reverse_bits(N)
+            bitmask.reverse_bits(M)
         } else {
             bitmask
         };
 
-        // Safety: U is required to be the appropriate bitmask type
-        unsafe {
-            Self::from_int_unchecked(intrinsics::simd_select_bitmask(
+        // SAFETY: `mask` is the correct bitmask type for a u64 bitmask
+        let mask: Simd<T, M> = unsafe {
+            intrinsics::simd_select_bitmask(
                 bitmask,
-                Self::splat(true).to_int(),
-                Self::splat(false).to_int(),
-            ))
+                Simd::<T, M>::splat(T::TRUE),
+                Simd::<T, M>::splat(T::FALSE),
+            )
+        };
+
+        // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
+        unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
+    }
+
+    #[inline]
+    pub(crate) fn to_bitmask_integer(self) -> u64 {
+        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
+        macro_rules! bitmask {
+            { $($ty:ty: $($len:literal),*;)* } => {
+                match N {
+                    $($(
+                    // Safety: bitmask matches length
+                    $len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 },
+                    )*)*
+                    // Safety: bitmask matches length
+                    _ => unsafe { self.to_bitmask_impl::<u64, 64>() },
+                }
+            }
+        }
+        #[cfg(all_lane_counts)]
+        bitmask! {
+            u8: 1, 2, 3, 4, 5, 6, 7, 8;
+            u16: 9, 10, 11, 12, 13, 14, 15, 16;
+            u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
+            u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
+        }
+        #[cfg(not(all_lane_counts))]
+        bitmask! {
+            u8: 1, 2, 4, 8;
+            u16: 16;
+            u32: 32;
+            u64: 64;
+        }
+    }
+
+    #[inline]
+    pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
+        // TODO modify simd_bitmask_select to truncate input, making this unnecessary
+        macro_rules! bitmask {
+            { $($ty:ty: $($len:literal),*;)* } => {
+                match N {
+                    $($(
+                    // Safety: bitmask matches length
+                    $len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) },
+                    )*)*
+                    // Safety: bitmask matches length
+                    _ => unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) },
+                }
+            }
+        }
+        #[cfg(all_lane_counts)]
+        bitmask! {
+            u8: 1, 2, 3, 4, 5, 6, 7, 8;
+            u16: 9, 10, 11, 12, 13, 14, 15, 16;
+            u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
+            u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
+        }
+        #[cfg(not(all_lane_counts))]
+        bitmask! {
+            u8: 1, 2, 4, 8;
+            u16: 16;
+            u32: 32;
+            u64: 64;
         }
     }
 
diff --git a/crates/core_simd/src/masks/to_bitmask.rs b/crates/core_simd/src/masks/to_bitmask.rs
deleted file mode 100644
index 06f09c65aca..00000000000
--- a/crates/core_simd/src/masks/to_bitmask.rs
+++ /dev/null
@@ -1,111 +0,0 @@
-use super::{mask_impl, Mask, MaskElement};
-use crate::simd::{LaneCount, SupportedLaneCount};
-use core::borrow::{Borrow, BorrowMut};
-
-mod sealed {
-    pub trait Sealed {}
-}
-pub use sealed::Sealed;
-
-impl<T, const N: usize> Sealed for Mask<T, N>
-where
-    T: MaskElement,
-    LaneCount<N>: SupportedLaneCount,
-{
-}
-
-/// Converts masks to and from integer bitmasks.
-///
-/// Each bit of the bitmask corresponds to a mask element, starting with the LSB.
-pub trait ToBitMask: Sealed {
-    /// The integer bitmask type.
-    type BitMask;
-
-    /// Converts a mask to a bitmask.
-    fn to_bitmask(self) -> Self::BitMask;
-
-    /// Converts a bitmask to a mask.
-    fn from_bitmask(bitmask: Self::BitMask) -> Self;
-}
-
-/// Converts masks to and from byte array bitmasks.
-///
-/// Each bit of the bitmask corresponds to a mask element, starting with the LSB of the first byte.
-pub trait ToBitMaskArray: Sealed {
-    /// The bitmask array.
-    type BitMaskArray: Copy
-        + Unpin
-        + Send
-        + Sync
-        + AsRef<[u8]>
-        + AsMut<[u8]>
-        + Borrow<[u8]>
-        + BorrowMut<[u8]>
-        + 'static;
-
-    /// Converts a mask to a bitmask.
-    fn to_bitmask_array(self) -> Self::BitMaskArray;
-
-    /// Converts a bitmask to a mask.
-    fn from_bitmask_array(bitmask: Self::BitMaskArray) -> Self;
-}
-
-macro_rules! impl_integer {
-    { $(impl ToBitMask<BitMask=$int:ty> for Mask<_, $lanes:literal>)* } => {
-        $(
-        impl<T: MaskElement> ToBitMask for Mask<T, $lanes> {
-            type BitMask = $int;
-
-            #[inline]
-            fn to_bitmask(self) -> $int {
-                self.0.to_bitmask_integer()
-            }
-
-            #[inline]
-            fn from_bitmask(bitmask: $int) -> Self {
-                Self(mask_impl::Mask::from_bitmask_integer(bitmask))
-            }
-        }
-        )*
-    }
-}
-
-macro_rules! impl_array {
-    { $(impl ToBitMaskArray<Bytes=$int:literal> for Mask<_, $lanes:literal>)* } => {
-        $(
-        impl<T: MaskElement> ToBitMaskArray for Mask<T, $lanes> {
-            type BitMaskArray = [u8; $int];
-
-            #[inline]
-            fn to_bitmask_array(self) -> Self::BitMaskArray {
-                self.0.to_bitmask_array()
-            }
-
-            #[inline]
-            fn from_bitmask_array(bitmask: Self::BitMaskArray) -> Self {
-                Self(mask_impl::Mask::from_bitmask_array(bitmask))
-            }
-        }
-        )*
-    }
-}
-
-impl_integer! {
-    impl ToBitMask<BitMask=u8> for Mask<_, 1>
-    impl ToBitMask<BitMask=u8> for Mask<_, 2>
-    impl ToBitMask<BitMask=u8> for Mask<_, 4>
-    impl ToBitMask<BitMask=u8> for Mask<_, 8>
-    impl ToBitMask<BitMask=u16> for Mask<_, 16>
-    impl ToBitMask<BitMask=u32> for Mask<_, 32>
-    impl ToBitMask<BitMask=u64> for Mask<_, 64>
-}
-
-impl_array! {
-    impl ToBitMaskArray<Bytes=1> for Mask<_, 1>
-    impl ToBitMaskArray<Bytes=1> for Mask<_, 2>
-    impl ToBitMaskArray<Bytes=1> for Mask<_, 4>
-    impl ToBitMaskArray<Bytes=1> for Mask<_, 8>
-    impl ToBitMaskArray<Bytes=2> for Mask<_, 16>
-    impl ToBitMaskArray<Bytes=4> for Mask<_, 32>
-    impl ToBitMaskArray<Bytes=8> for Mask<_, 64>
-}
diff --git a/crates/core_simd/src/swizzle.rs b/crates/core_simd/src/swizzle.rs
index 6af882c0a0e..ec8548d5574 100644
--- a/crates/core_simd/src/swizzle.rs
+++ b/crates/core_simd/src/swizzle.rs
@@ -349,4 +349,39 @@ where
             Odd::concat_swizzle(self, other),
         )
     }
+
+    /// Resize a vector.
+    ///
+    /// If `M` > `N`, extends the length of a vector, setting the new elements to `value`.
+    /// If `M` < `N`, truncates the vector to the first `M` elements.
+    ///
+    /// ```
+    /// # #![feature(portable_simd)]
+    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
+    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
+    /// # use simd::u32x4;
+    /// let x = u32x4::from_array([0, 1, 2, 3]);
+    /// assert_eq!(x.resize::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]);
+    /// assert_eq!(x.resize::<2>(9).to_array(), [0, 1]);
+    /// ```
+    #[inline]
+    #[must_use = "method returns a new vector and does not mutate the original inputs"]
+    pub fn resize<const M: usize>(self, value: T) -> Simd<T, M>
+    where
+        LaneCount<M>: SupportedLaneCount,
+    {
+        struct Resize<const N: usize>;
+        impl<const N: usize, const M: usize> Swizzle<M> for Resize<N> {
+            const INDEX: [usize; M] = const {
+                let mut index = [0; M];
+                let mut i = 0;
+                while i < M {
+                    index[i] = if i < N { i } else { N };
+                    i += 1;
+                }
+                index
+            };
+        }
+        Resize::<N>::concat_swizzle(self, Simd::splat(value))
+    }
 }
diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs
index 7c1d4c7dd3f..00fc2a24e27 100644
--- a/crates/core_simd/tests/masks.rs
+++ b/crates/core_simd/tests/masks.rs
@@ -72,7 +72,6 @@ macro_rules! test_mask_api {
 
             #[test]
             fn roundtrip_bitmask_conversion() {
-                use core_simd::simd::ToBitMask;
                 let values = [
                     true, false, false, true, false, false, true, false,
                     true, true, false, false, false, false, false, true,
@@ -85,8 +84,6 @@ macro_rules! test_mask_api {
 
             #[test]
             fn roundtrip_bitmask_conversion_short() {
-                use core_simd::simd::ToBitMask;
-
                 let values = [
                     false, false, false, true,
                 ];
@@ -126,16 +123,16 @@ macro_rules! test_mask_api {
             }
 
             #[test]
-            fn roundtrip_bitmask_array_conversion() {
-                use core_simd::simd::ToBitMaskArray;
+            fn roundtrip_bitmask_vector_conversion() {
+                use core_simd::simd::ToBytes;
                 let values = [
                     true, false, false, true, false, false, true, false,
                     true, true, false, false, false, false, false, true,
                 ];
                 let mask = Mask::<$type, 16>::from_array(values);
-                let bitmask = mask.to_bitmask_array();
-                assert_eq!(bitmask, [0b01001001, 0b10000011]);
-                assert_eq!(Mask::<$type, 16>::from_bitmask_array(bitmask), mask);
+                let bitmask = mask.to_bitmask_vector();
+                assert_eq!(bitmask.resize::<2>(0).to_ne_bytes()[..2], [0b01001001, 0b10000011]);
+                assert_eq!(Mask::<$type, 16>::from_bitmask_vector(bitmask), mask);
             }
         }
     }