about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCaleb Zulawski <caleb.zulawski@gmail.com>2021-04-19 04:31:43 +0000
committerCaleb Zulawski <caleb.zulawski@gmail.com>2021-04-28 21:56:11 +0000
commiteec42808aa024d354bb40ec890612c37ba4a496c (patch)
tree85836dfa0c89687cde4351a1acd5562c6dc60577
parentda42aa5403659a6f1f6f4bc4b65d177f13fb6536 (diff)
downloadrust-eec42808aa024d354bb40ec890612c37ba4a496c.tar.gz
rust-eec42808aa024d354bb40ec890612c37ba4a496c.zip
Update bitmask API
-rw-r--r--crates/core_simd/src/intrinsics.rs3
-rw-r--r--crates/core_simd/src/lanes_at_most_32.rs38
-rw-r--r--crates/core_simd/src/masks/bitmask.rs146
-rw-r--r--crates/core_simd/src/masks/full_masks.rs59
-rw-r--r--crates/core_simd/src/masks/mod.rs93
-rw-r--r--crates/core_simd/tests/masks.rs17
6 files changed, 196 insertions, 160 deletions
diff --git a/crates/core_simd/src/intrinsics.rs b/crates/core_simd/src/intrinsics.rs
index 665dc1a51d7..1812a9c624d 100644
--- a/crates/core_simd/src/intrinsics.rs
+++ b/crates/core_simd/src/intrinsics.rs
@@ -76,6 +76,9 @@ extern "platform-intrinsic" {
     pub(crate) fn simd_reduce_and<T, U>(x: T) -> U;
     pub(crate) fn simd_reduce_or<T, U>(x: T) -> U;
     pub(crate) fn simd_reduce_xor<T, U>(x: T) -> U;
+
+    // truncate integer vector to bitmask
+    pub(crate) fn simd_bitmask<T, U>(x: T) -> U;
 }
 
 #[cfg(feature = "std")]
diff --git a/crates/core_simd/src/lanes_at_most_32.rs b/crates/core_simd/src/lanes_at_most_32.rs
index 2fee9ca9189..2d84b1306ea 100644
--- a/crates/core_simd/src/lanes_at_most_32.rs
+++ b/crates/core_simd/src/lanes_at_most_32.rs
@@ -1,14 +1,38 @@
 /// Implemented for vectors that are supported by the implementation.
-pub trait LanesAtMost32 {}
+pub trait LanesAtMost32: sealed::Sealed {
+    #[doc(hidden)]
+    type BitMask: Into<u64>;
+}
+
+mod sealed {
+    pub trait Sealed {}
+}
 
 macro_rules! impl_for {
     { $name:ident } => {
-        impl LanesAtMost32 for $name<1> {}
-        impl LanesAtMost32 for $name<2> {}
-        impl LanesAtMost32 for $name<4> {}
-        impl LanesAtMost32 for $name<8> {}
-        impl LanesAtMost32 for $name<16> {}
-        impl LanesAtMost32 for $name<32> {}
+        impl<const LANES: usize> sealed::Sealed for $name<LANES>
+        where
+            $name<LANES>: LanesAtMost32,
+        {}
+
+        impl LanesAtMost32 for $name<1> {
+            type BitMask = u8;
+        }
+        impl LanesAtMost32 for $name<2> {
+            type BitMask = u8;
+        }
+        impl LanesAtMost32 for $name<4> {
+            type BitMask = u8;
+        }
+        impl LanesAtMost32 for $name<8> {
+            type BitMask = u8;
+        }
+        impl LanesAtMost32 for $name<16> {
+            type BitMask = u16;
+        }
+        impl LanesAtMost32 for $name<32> {
+            type BitMask = u32;
+        }
     }
 }
 
diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs
index 32e2ffb8615..bf7c70c5a3a 100644
--- a/crates/core_simd/src/masks/bitmask.rs
+++ b/crates/core_simd/src/masks/bitmask.rs
@@ -1,13 +1,9 @@
-use crate::LanesAtMost32;
-
 /// A mask where each lane is represented by a single bit.
 #[derive(Copy, Clone, Debug, PartialOrd, PartialEq, Ord, Eq, Hash)]
 #[repr(transparent)]
-pub struct BitMask<const LANES: usize>(u64)
+pub struct BitMask<const LANES: usize>(u64);
 
 impl<const LANES: usize> BitMask<LANES>
-where
-    Self: LanesAtMost32,
 {
     #[inline]
     pub fn splat(value: bool) -> Self {
@@ -25,13 +21,50 @@ where
 
     #[inline]
     pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
-        self.0 ^= ((value ^ self.test(lane)) as u64) << lane
+        self.0 ^= ((value ^ self.test_unchecked(lane)) as u64) << lane
+    }
+
+    #[inline]
+    pub fn to_int<V, T>(self) -> V
+    where
+        V: Default + AsMut<[T; LANES]>,
+        T: From<i8>,
+    {
+        // TODO this should be an intrinsic sign-extension
+        let mut v = V::default();
+        for i in 0..LANES {
+            let lane = unsafe { self.test_unchecked(i) };
+            v.as_mut()[i] = (-(lane as i8)).into();
+        }
+        v
+    }
+
+    #[inline]
+    pub unsafe fn from_int_unchecked<V>(value: V) -> Self
+    where
+        V: crate::LanesAtMost32,
+    {
+        let mask: V::BitMask = crate::intrinsics::simd_bitmask(value);
+        Self(mask.into())
+    }
+
+    #[inline]
+    pub fn to_bitmask(self) -> u64 {
+        self.0
+    }
+
+    #[inline]
+    pub fn any(self) -> bool {
+        self != Self::splat(false)
+    }
+
+    #[inline]
+    pub fn all(self) -> bool {
+        self == Self::splat(true)
     }
 }
 
 impl<const LANES: usize> core::ops::BitAnd for BitMask<LANES>
-where
-    Self: LanesAtMost32,
 {
     type Output = Self;
     #[inline]
@@ -41,8 +74,6 @@ where
 }
 
 impl<const LANES: usize> core::ops::BitAnd<bool> for BitMask<LANES>
-where
-    Self: LanesAtMost32,
 {
     type Output = Self;
     #[inline]
@@ -52,8 +83,6 @@ where
 }
 
 impl<const LANES: usize> core::ops::BitAnd<BitMask<LANES>> for bool
-where
-    BitMask<LANES>: LanesAtMost32,
 {
     type Output = BitMask<LANES>;
     #[inline]
@@ -63,8 +92,6 @@ where
 }
 
 impl<const LANES: usize> core::ops::BitOr for BitMask<LANES>
-where
-    Self: LanesAtMost32,
 {
     type Output = Self;
     #[inline]
@@ -73,31 +100,7 @@ where
     }
 }
 
-impl<const LANES: usize> core::ops::BitOr<bool> for BitMask<LANES>
-where
-    Self: LanesAtMost32,
-{
-    type Output = Self;
-    #[inline]
-    fn bitor(self, rhs: bool) -> Self {
-        self | Self::splat(rhs)
-    }
-}
-
-impl<const LANES: usize> core::ops::BitOr<BitMask<LANES>> for bool
-where
-    BitMask<LANES>: LanesAtMost32,
-{
-    type Output = BitMask<LANES>;
-    #[inline]
-    fn bitor(self, rhs: BitMask<LANES>) -> BitMask<LANES> {
-        BitMask::<LANES>::splat(self) | rhs
-    }
-}
-
 impl<const LANES: usize> core::ops::BitXor for BitMask<LANES>
-where
-    Self: LanesAtMost32,
 {
     type Output = Self;
     #[inline]
@@ -106,42 +109,16 @@ where
     }
 }
 
-impl<const LANES: usize> core::ops::BitXor<bool> for BitMask<LANES>
-where
-    Self: LanesAtMost32,
-{
-    type Output = Self;
-    #[inline]
-    fn bitxor(self, rhs: bool) -> Self::Output {
-        self ^ Self::splat(rhs)
-    }
-}
-
-impl<const LANES: usize> core::ops::BitXor<BitMask<LANES>> for bool
-where
-    BitMask<LANES>: LanesAtMost32,
-{
-    type Output = BitMask<LANES>;
-    #[inline]
-    fn bitxor(self, rhs: BitMask<LANES>) -> Self::Output {
-        BitMask::<LANES>::splat(self) ^ rhs
-    }
-}
-
 impl<const LANES: usize> core::ops::Not for BitMask<LANES>
-where
-    Self: LanesAtMost32,
 {
     type Output = BitMask<LANES>;
     #[inline]
     fn not(self) -> Self::Output {
-        Self(!self.0)
+        Self(!self.0) & Self::splat(true)
     }
 }
 
 impl<const LANES: usize> core::ops::BitAndAssign for BitMask<LANES>
-where
-    Self: LanesAtMost32,
 {
     #[inline]
     fn bitand_assign(&mut self, rhs: Self) {
@@ -149,19 +126,7 @@ where
     }
 }
 
-impl<const LANES: usize> core::ops::BitAndAssign<bool> for BitMask<LANES>
-where
-    Self: LanesAtMost32,
-{
-    #[inline]
-    fn bitand_assign(&mut self, rhs: bool) {
-        *self &= Self::splat(rhs);
-    }
-}
-
 impl<const LANES: usize> core::ops::BitOrAssign for BitMask<LANES>
-where
-    Self: LanesAtMost32,
 {
     #[inline]
     fn bitor_assign(&mut self, rhs: Self) {
@@ -169,19 +134,7 @@ where
     }
 }
 
-impl<const LANES: usize> core::ops::BitOrAssign<bool> for BitMask<LANES>
-where
-    Self: LanesAtMost32,
-{
-    #[inline]
-    fn bitor_assign(&mut self, rhs: bool) {
-        *self |= Self::splat(rhs);
-    }
-}
-
 impl<const LANES: usize> core::ops::BitXorAssign for BitMask<LANES>
-where
-    Self: LanesAtMost32,
 {
     #[inline]
     fn bitxor_assign(&mut self, rhs: Self) {
@@ -189,12 +142,9 @@ where
     }
 }
 
-impl<const LANES: usize> core::ops::BitXorAssign<bool> for BitMask<LANES>
-where
-    Self: LanesAtMost32,
-{
-    #[inline]
-    fn bitxor_assign(&mut self, rhs: bool) {
-        *self ^= Self::splat(rhs);
-    }
-}
+pub type Mask8<const LANES: usize> = BitMask<LANES>;
+pub type Mask16<const LANES: usize> = BitMask<LANES>;
+pub type Mask32<const LANES: usize> = BitMask<LANES>;
+pub type Mask64<const LANES: usize> = BitMask<LANES>;
+pub type Mask128<const LANES: usize> = BitMask<LANES>;
+pub type MaskSize<const LANES: usize> = BitMask<LANES>;
diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs
index 6972a4216b6..2d1ddd6dc30 100644
--- a/crates/core_simd/src/masks/full_masks.rs
+++ b/crates/core_simd/src/masks/full_masks.rs
@@ -46,14 +46,12 @@ macro_rules! define_mask {
             }
 
             #[inline]
-            pub fn test(&self, lane: usize) -> bool {
-                assert!(lane < LANES, "lane index out of range");
+            pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
                 self.0[lane] == -1
             }
 
             #[inline]
-            pub fn set(&mut self, lane: usize, value: bool) {
-                assert!(lane < LANES, "lane index out of range");
+            pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
                 self.0[lane] = if value {
                     -1
                 } else {
@@ -70,6 +68,12 @@ macro_rules! define_mask {
             pub unsafe fn from_int_unchecked(value: crate::$type<LANES>) -> Self {
                 Self(value)
             }
+
+            #[inline]
+            pub fn to_bitmask(self) -> u64 {
+                let mask: <crate::$type<LANES> as crate::LanesAtMost32>::BitMask = unsafe { crate::intrinsics::simd_bitmask(self.0) };
+                mask.into()
+            }
         }
 
         impl<const LANES: usize> core::convert::From<$name<LANES>> for crate::$type<LANES>
@@ -81,53 +85,6 @@ macro_rules! define_mask {
             }
         }
 
-        impl<const LANES: usize> core::fmt::Debug for $name<LANES>
-        where
-            crate::$type<LANES>: crate::LanesAtMost32,
-        {
-            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
-                f.debug_list()
-                    .entries((0..LANES).map(|lane| self.test(lane)))
-                    .finish()
-            }
-        }
-
-        impl<const LANES: usize> core::fmt::Binary for $name<LANES>
-        where
-            crate::$type<LANES>: crate::LanesAtMost32,
-        {
-            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
-                core::fmt::Binary::fmt(&self.0, f)
-            }
-        }
-
-        impl<const LANES: usize> core::fmt::Octal for $name<LANES>
-        where
-            crate::$type<LANES>: crate::LanesAtMost32,
-        {
-            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
-                core::fmt::Octal::fmt(&self.0, f)
-            }
-        }
-
-        impl<const LANES: usize> core::fmt::LowerHex for $name<LANES>
-        where
-            crate::$type<LANES>: crate::LanesAtMost32,
-        {
-            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
-                core::fmt::LowerHex::fmt(&self.0, f)
-            }
-        }
-
-        impl<const LANES: usize> core::fmt::UpperHex for $name<LANES>
-        where
-            crate::$type<LANES>: crate::LanesAtMost32,
-        {
-            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
-                core::fmt::UpperHex::fmt(&self.0, f)
-            }
-        }
-
         impl<const LANES: usize> core::ops::BitAnd for $name<LANES>
         where
             crate::$type<LANES>: crate::LanesAtMost32,
diff --git a/crates/core_simd/src/masks/mod.rs b/crates/core_simd/src/masks/mod.rs
index fbb934b9642..e5352ef4d1a 100644
--- a/crates/core_simd/src/masks/mod.rs
+++ b/crates/core_simd/src/masks/mod.rs
@@ -8,6 +8,12 @@ mod mask_impl;
 
 use crate::{LanesAtMost32, SimdI16, SimdI32, SimdI64, SimdI8, SimdIsize};
 
+/// Converts masks to bitmasks, with one bit set for each lane.
+pub trait ToBitMask {
+    /// Converts this mask to a bitmask.
+    fn to_bitmask(self) -> u64;
+}
+
 macro_rules! define_opaque_mask {
     {
         $(#[$attr:meta])*
@@ -61,13 +67,53 @@ macro_rules! define_opaque_mask {
                 Self(<$inner_ty>::from_int_unchecked(value))
             }
 
+            /// Converts a vector of integers to a mask, where 0 represents `false` and -1
+            /// represents `true`.
+            ///
+            /// # Panics
+            /// Panics if any lane is not 0 or -1.
+            #[inline]
+            pub fn from_int(value: $bits_ty<LANES>) -> Self {
+                assert!(
+                    (value.lanes_eq($bits_ty::splat(0)) | value.lanes_eq($bits_ty::splat(-1))).all(),
+                    "all values must be either 0 or -1",
+                );
+                unsafe { Self::from_int_unchecked(value) }
+            }
+
+            /// Converts the mask to a vector of integers, where 0 represents `false` and -1
+            /// represents `true`.
+            #[inline]
+            pub fn to_int(self) -> $bits_ty<LANES> {
+                self.0.to_int()
+            }
+
+            /// Tests the value of the specified lane.
+            ///
+            /// # Safety
+            /// `lane` must be less than `LANES`.
+            #[inline]
+            pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
+                self.0.test_unchecked(lane)
+            }
+
             /// Tests the value of the specified lane.
             ///
             /// # Panics
             /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
             #[inline]
             pub fn test(&self, lane: usize) -> bool {
-                self.0.test(lane)
+                assert!(lane < LANES, "lane index out of range");
+                unsafe { self.test_unchecked(lane) }
+            }
+
+            /// Sets the value of the specified lane.
+            ///
+            /// # Safety
+            /// `lane` must be less than `LANES`.
+            #[inline]
+            pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
+                self.0.set_unchecked(lane, value);
             }
 
             /// Sets the value of the specified lane.
@@ -76,7 +122,44 @@ macro_rules! define_opaque_mask {
             /// Panics if `lane` is greater than or equal to the number of lanes in the vector.
             #[inline]
             pub fn set(&mut self, lane: usize, value: bool) {
-                self.0.set(lane, value);
+                assert!(lane < LANES, "lane index out of range");
+                unsafe { self.set_unchecked(lane, value); }
+            }
+        }
+
+        impl ToBitMask for $name<1> {
+            fn to_bitmask(self) -> u64 {
+                self.0.to_bitmask()
+            }
+        }
+
+        impl ToBitMask for $name<2> {
+            fn to_bitmask(self) -> u64 {
+                self.0.to_bitmask()
+            }
+        }
+
+        impl ToBitMask for $name<4> {
+            fn to_bitmask(self) -> u64 {
+                self.0.to_bitmask()
+            }
+        }
+
+        impl ToBitMask for $name<8> {
+            fn to_bitmask(self) -> u64 {
+                self.0.to_bitmask()
+            }
+        }
+
+        impl ToBitMask for $name<16> {
+            fn to_bitmask(self) -> u64 {
+                self.0.to_bitmask()
+            }
+        }
+
+        impl ToBitMask for $name<32> {
+            fn to_bitmask(self) -> u64 {
+                self.0.to_bitmask()
             }
         }
 
@@ -147,10 +230,12 @@ macro_rules! define_opaque_mask {
 
         impl<const LANES: usize> core::fmt::Debug for $name<LANES>
         where
-            $bits_ty<LANES>: LanesAtMost32,
+            $bits_ty<LANES>: crate::LanesAtMost32,
         {
             fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
-                core::fmt::Debug::fmt(&self.0, f)
+                f.debug_list()
+                    .entries((0..LANES).map(|lane| self.test(lane)))
+                    .finish()
             }
         }
 
diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs
index 6c3993e39a9..be83f4c2ec7 100644
--- a/crates/core_simd/tests/masks.rs
+++ b/crates/core_simd/tests/masks.rs
@@ -56,6 +56,23 @@ macro_rules! test_mask_api {
                 v.set(2, true);
                 assert!(!v.all());
             }
+
+            #[test]
+            fn roundtrip_int_conversion() {
+                let values = [true, false, false, true, false, false, true, false];
+                let mask = core_simd::$name::<8>::from_array(values);
+                let int = mask.to_int();
+                assert_eq!(int.to_array(), [-1, 0, 0, -1, 0, 0, -1, 0]);
+                assert_eq!(core_simd::$name::<8>::from_int(int), mask);
+            }
+
+            #[test]
+            fn to_bitmask() {
+                use core_simd::ToBitMask;
+                let values = [true, false, false, true, false, false, true, false];
+                let mask = core_simd::$name::<8>::from_array(values);
+                assert_eq!(mask.to_bitmask(), 0b01001001);
+            }
         }
     }
 }