diff options
| author | Caleb Zulawski <caleb.zulawski@gmail.com> | 2021-04-19 04:31:43 +0000 |
|---|---|---|
| committer | Caleb Zulawski <caleb.zulawski@gmail.com> | 2021-04-28 21:56:11 +0000 |
| commit | eec42808aa024d354bb40ec890612c37ba4a496c (patch) | |
| tree | 85836dfa0c89687cde4351a1acd5562c6dc60577 | |
| parent | da42aa5403659a6f1f6f4bc4b65d177f13fb6536 (diff) | |
| download | rust-eec42808aa024d354bb40ec890612c37ba4a496c.tar.gz rust-eec42808aa024d354bb40ec890612c37ba4a496c.zip | |
Update bitmask API
| -rw-r--r-- | crates/core_simd/src/intrinsics.rs | 3 | ||||
| -rw-r--r-- | crates/core_simd/src/lanes_at_most_32.rs | 38 | ||||
| -rw-r--r-- | crates/core_simd/src/masks/bitmask.rs | 146 | ||||
| -rw-r--r-- | crates/core_simd/src/masks/full_masks.rs | 59 | ||||
| -rw-r--r-- | crates/core_simd/src/masks/mod.rs | 93 | ||||
| -rw-r--r-- | crates/core_simd/tests/masks.rs | 17 |
6 files changed, 196 insertions, 160 deletions
diff --git a/crates/core_simd/src/intrinsics.rs b/crates/core_simd/src/intrinsics.rs index 665dc1a51d7..1812a9c624d 100644 --- a/crates/core_simd/src/intrinsics.rs +++ b/crates/core_simd/src/intrinsics.rs @@ -76,6 +76,9 @@ extern "platform-intrinsic" { pub(crate) fn simd_reduce_and<T, U>(x: T) -> U; pub(crate) fn simd_reduce_or<T, U>(x: T) -> U; pub(crate) fn simd_reduce_xor<T, U>(x: T) -> U; + + // truncate integer vector to bitmask + pub(crate) fn simd_bitmask<T, U>(x: T) -> U; } #[cfg(feature = "std")] diff --git a/crates/core_simd/src/lanes_at_most_32.rs b/crates/core_simd/src/lanes_at_most_32.rs index 2fee9ca9189..2d84b1306ea 100644 --- a/crates/core_simd/src/lanes_at_most_32.rs +++ b/crates/core_simd/src/lanes_at_most_32.rs @@ -1,14 +1,38 @@ /// Implemented for vectors that are supported by the implementation. -pub trait LanesAtMost32 {} +pub trait LanesAtMost32: sealed::Sealed { + #[doc(hidden)] + type BitMask: Into<u64>; +} + +mod sealed { + pub trait Sealed {} +} macro_rules! impl_for { { $name:ident } => { - impl LanesAtMost32 for $name<1> {} - impl LanesAtMost32 for $name<2> {} - impl LanesAtMost32 for $name<4> {} - impl LanesAtMost32 for $name<8> {} - impl LanesAtMost32 for $name<16> {} - impl LanesAtMost32 for $name<32> {} + impl<const LANES: usize> sealed::Sealed for $name<LANES> + where + $name<LANES>: LanesAtMost32, + {} + + impl LanesAtMost32 for $name<1> { + type BitMask = u8; + } + impl LanesAtMost32 for $name<2> { + type BitMask = u8; + } + impl LanesAtMost32 for $name<4> { + type BitMask = u8; + } + impl LanesAtMost32 for $name<8> { + type BitMask = u8; + } + impl LanesAtMost32 for $name<16> { + type BitMask = u16; + } + impl LanesAtMost32 for $name<32> { + type BitMask = u32; + } } } diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs index 32e2ffb8615..bf7c70c5a3a 100644 --- a/crates/core_simd/src/masks/bitmask.rs +++ b/crates/core_simd/src/masks/bitmask.rs @@ -1,13 +1,9 @@ -use crate::LanesAtMost32; - /// A mask where each lane is represented by a single bit. #[derive(Copy, Clone, Debug, PartialOrd, PartialEq, Ord, Eq, Hash)] #[repr(transparent)] -pub struct BitMask<const LANES: usize>(u64) +pub struct BitMask<const LANES: usize>(u64); impl<const LANES: usize> BitMask<LANES> -where - Self: LanesAtMost32, { #[inline] pub fn splat(value: bool) -> Self { @@ -25,13 +21,50 @@ where #[inline] pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) { - self.0 ^= ((value ^ self.test(lane)) as u64) << lane + self.0 ^= ((value ^ self.test_unchecked(lane)) as u64) << lane + } + + #[inline] + pub fn to_int<V, T>(self) -> V + where + V: Default + AsMut<[T; LANES]>, + T: From<i8>, + { + // TODO this should be an intrinsic sign-extension + let mut v = V::default(); + for i in 0..LANES { + let lane = unsafe { self.test_unchecked(i) }; + v.as_mut()[i] = (-(lane as i8)).into(); + } + v + } + + #[inline] + pub unsafe fn from_int_unchecked<V>(value: V) -> Self + where + V: crate::LanesAtMost32, + { + let mask: V::BitMask = crate::intrinsics::simd_bitmask(value); + Self(mask.into()) + } + + #[inline] + pub fn to_bitmask(self) -> u64 { + self.0 + } + + #[inline] + pub fn any(self) -> bool { + self != Self::splat(false) + } + + #[inline] + pub fn all(self) -> bool { + self == Self::splat(true) } } impl<const LANES: usize> core::ops::BitAnd for BitMask<LANES> -where - Self: LanesAtMost32, { type Output = Self; #[inline] @@ -41,8 +74,6 @@ where } impl<const LANES: usize> core::ops::BitAnd<bool> for BitMask<LANES> -where - Self: LanesAtMost32, { type Output = Self; #[inline] @@ -52,8 +83,6 @@ where } impl<const LANES: usize> core::ops::BitAnd<BitMask<LANES>> for bool -where - BitMask<LANES>: LanesAtMost32, { type Output = BitMask<LANES>; #[inline] @@ -63,8 +92,6 @@ where } impl<const LANES: usize> core::ops::BitOr for BitMask<LANES> -where - Self: LanesAtMost32, { type Output = Self; #[inline] @@ -73,31 +100,7 @@ where } } -impl<const LANES: usize> core::ops::BitOr<bool> for BitMask<LANES> -where - Self: LanesAtMost32, -{ - type Output = Self; - #[inline] - fn bitor(self, rhs: bool) -> Self { - self | Self::splat(rhs) - } -} - -impl<const LANES: usize> core::ops::BitOr<BitMask<LANES>> for bool -where - BitMask<LANES>: LanesAtMost32, -{ - type Output = BitMask<LANES>; - #[inline] - fn bitor(self, rhs: BitMask<LANES>) -> BitMask<LANES> { - BitMask::<LANES>::splat(self) | rhs - } -} - impl<const LANES: usize> core::ops::BitXor for BitMask<LANES> -where - Self: LanesAtMost32, { type Output = Self; #[inline] @@ -106,42 +109,16 @@ where } } -impl<const LANES: usize> core::ops::BitXor<bool> for BitMask<LANES> -where - Self: LanesAtMost32, -{ - type Output = Self; - #[inline] - fn bitxor(self, rhs: bool) -> Self::Output { - self ^ Self::splat(rhs) - } -} - -impl<const LANES: usize> core::ops::BitXor<BitMask<LANES>> for bool -where - BitMask<LANES>: LanesAtMost32, -{ - type Output = BitMask<LANES>; - #[inline] - fn bitxor(self, rhs: BitMask<LANES>) -> Self::Output { - BitMask::<LANES>::splat(self) ^ rhs - } -} - impl<const LANES: usize> core::ops::Not for BitMask<LANES> -where - Self: LanesAtMost32, { type Output = BitMask<LANES>; #[inline] fn not(self) -> Self::Output { - Self(!self.0) + Self(!self.0) & Self::splat(true) } } impl<const LANES: usize> core::ops::BitAndAssign for BitMask<LANES> -where - Self: LanesAtMost32, { #[inline] fn bitand_assign(&mut self, rhs: Self) { @@ -149,19 +126,7 @@ where } } -impl<const LANES: usize> core::ops::BitAndAssign<bool> for BitMask<LANES> -where - Self: LanesAtMost32, -{ - #[inline] - fn bitand_assign(&mut self, rhs: bool) { - *self &= Self::splat(rhs); - } -} - impl<const LANES: usize> core::ops::BitOrAssign for BitMask<LANES> -where - Self: LanesAtMost32, { #[inline] fn bitor_assign(&mut self, rhs: Self) { @@ -169,19 +134,7 @@ where } } -impl<const LANES: usize> core::ops::BitOrAssign<bool> for BitMask<LANES> -where - Self: LanesAtMost32, -{ - #[inline] - fn bitor_assign(&mut self, rhs: bool) { - *self |= Self::splat(rhs); - } -} - impl<const LANES: usize> core::ops::BitXorAssign for BitMask<LANES> -where - Self: LanesAtMost32, { #[inline] fn bitxor_assign(&mut self, rhs: Self) { @@ -189,12 +142,9 @@ where } } -impl<const LANES: usize> core::ops::BitXorAssign<bool> for BitMask<LANES> -where - Self: LanesAtMost32, -{ - #[inline] - fn bitxor_assign(&mut self, rhs: bool) { - *self ^= Self::splat(rhs); - } -} +pub type Mask8<const LANES: usize> = BitMask<LANES>; +pub type Mask16<const LANES: usize> = BitMask<LANES>; +pub type Mask32<const LANES: usize> = BitMask<LANES>; +pub type Mask64<const LANES: usize> = BitMask<LANES>; +pub type Mask128<const LANES: usize> = BitMask<LANES>; +pub type MaskSize<const LANES: usize> = BitMask<LANES>; diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs index 6972a4216b6..2d1ddd6dc30 100644 --- a/crates/core_simd/src/masks/full_masks.rs +++ b/crates/core_simd/src/masks/full_masks.rs @@ -46,14 +46,12 @@ macro_rules! define_mask { } #[inline] - pub fn test(&self, lane: usize) -> bool { - assert!(lane < LANES, "lane index out of range"); + pub unsafe fn test_unchecked(&self, lane: usize) -> bool { self.0[lane] == -1 } #[inline] - pub fn set(&mut self, lane: usize, value: bool) { - assert!(lane < LANES, "lane index out of range"); + pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) { self.0[lane] = if value { -1 } else { @@ -70,6 +68,12 @@ macro_rules! define_mask { pub unsafe fn from_int_unchecked(value: crate::$type<LANES>) -> Self { Self(value) } + + #[inline] + pub fn to_bitmask(self) -> u64 { + let mask: <crate::$type<LANES> as crate::LanesAtMost32>::BitMask = unsafe { crate::intrinsics::simd_bitmask(self.0) }; + mask.into() + } } impl<const LANES: usize> core::convert::From<$name<LANES>> for crate::$type<LANES> @@ -81,53 +85,6 @@ macro_rules! define_mask { } } - impl<const LANES: usize> core::fmt::Debug for $name<LANES> - where - crate::$type<LANES>: crate::LanesAtMost32, - { - fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { - f.debug_list() - .entries((0..LANES).map(|lane| self.test(lane))) - .finish() - } - } - - impl<const LANES: usize> core::fmt::Binary for $name<LANES> - where - crate::$type<LANES>: crate::LanesAtMost32, - { - fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { - core::fmt::Binary::fmt(&self.0, f) - } - } - - impl<const LANES: usize> core::fmt::Octal for $name<LANES> - where - crate::$type<LANES>: crate::LanesAtMost32, - { - fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { - core::fmt::Octal::fmt(&self.0, f) - } - } - - impl<const LANES: usize> core::fmt::LowerHex for $name<LANES> - where - crate::$type<LANES>: crate::LanesAtMost32, - { - fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { - core::fmt::LowerHex::fmt(&self.0, f) - } - } - - impl<const LANES: usize> core::fmt::UpperHex for $name<LANES> - where - crate::$type<LANES>: crate::LanesAtMost32, - { - fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { - core::fmt::UpperHex::fmt(&self.0, f) - } - } - impl<const LANES: usize> core::ops::BitAnd for $name<LANES> where crate::$type<LANES>: crate::LanesAtMost32, diff --git a/crates/core_simd/src/masks/mod.rs b/crates/core_simd/src/masks/mod.rs index fbb934b9642..e5352ef4d1a 100644 --- a/crates/core_simd/src/masks/mod.rs +++ b/crates/core_simd/src/masks/mod.rs @@ -8,6 +8,12 @@ mod mask_impl; use crate::{LanesAtMost32, SimdI16, SimdI32, SimdI64, SimdI8, SimdIsize}; +/// Converts masks to bitmasks, with one bit set for each lane. +pub trait ToBitMask { + /// Converts this mask to a bitmask. + fn to_bitmask(self) -> u64; +} + macro_rules! define_opaque_mask { { $(#[$attr:meta])* @@ -61,13 +67,53 @@ macro_rules! define_opaque_mask { Self(<$inner_ty>::from_int_unchecked(value)) } + /// Converts a vector of integers to a mask, where 0 represents `false` and -1 + /// represents `true`. + /// + /// # Panics + /// Panics if any lane is not 0 or -1. + #[inline] + pub fn from_int(value: $bits_ty<LANES>) -> Self { + assert!( + (value.lanes_eq($bits_ty::splat(0)) | value.lanes_eq($bits_ty::splat(-1))).all(), + "all values must be either 0 or -1", + ); + unsafe { Self::from_int_unchecked(value) } + } + + /// Converts the mask to a vector of integers, where 0 represents `false` and -1 + /// represents `true`. + #[inline] + pub fn to_int(self) -> $bits_ty<LANES> { + self.0.to_int() + } + + /// Tests the value of the specified lane. + /// + /// # Safety + /// `lane` must be less than `LANES`. + #[inline] + pub unsafe fn test_unchecked(&self, lane: usize) -> bool { + self.0.test_unchecked(lane) + } + /// Tests the value of the specified lane. /// /// # Panics /// Panics if `lane` is greater than or equal to the number of lanes in the vector. #[inline] pub fn test(&self, lane: usize) -> bool { - self.0.test(lane) + assert!(lane < LANES, "lane index out of range"); + unsafe { self.test_unchecked(lane) } + } + + /// Sets the value of the specified lane. + /// + /// # Safety + /// `lane` must be less than `LANES`. + #[inline] + pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) { + self.0.set_unchecked(lane, value); } /// Sets the value of the specified lane. @@ -76,7 +122,44 @@ macro_rules! define_opaque_mask { /// Panics if `lane` is greater than or equal to the number of lanes in the vector. #[inline] pub fn set(&mut self, lane: usize, value: bool) { - self.0.set(lane, value); + assert!(lane < LANES, "lane index out of range"); + unsafe { self.set_unchecked(lane, value); } + } + } + + impl ToBitMask for $name<1> { + fn to_bitmask(self) -> u64 { + self.0.to_bitmask() + } + } + + impl ToBitMask for $name<2> { + fn to_bitmask(self) -> u64 { + self.0.to_bitmask() + } + } + + impl ToBitMask for $name<4> { + fn to_bitmask(self) -> u64 { + self.0.to_bitmask() + } + } + + impl ToBitMask for $name<8> { + fn to_bitmask(self) -> u64 { + self.0.to_bitmask() + } + } + + impl ToBitMask for $name<16> { + fn to_bitmask(self) -> u64 { + self.0.to_bitmask() + } + } + + impl ToBitMask for $name<32> { + fn to_bitmask(self) -> u64 { + self.0.to_bitmask() } } @@ -147,10 +230,12 @@ macro_rules! define_opaque_mask { impl<const LANES: usize> core::fmt::Debug for $name<LANES> where - $bits_ty<LANES>: LanesAtMost32, + $bits_ty<LANES>: crate::LanesAtMost32, { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { - core::fmt::Debug::fmt(&self.0, f) + f.debug_list() + .entries((0..LANES).map(|lane| self.test(lane))) + .finish() } } diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs index 6c3993e39a9..be83f4c2ec7 100644 --- a/crates/core_simd/tests/masks.rs +++ b/crates/core_simd/tests/masks.rs @@ -56,6 +56,23 @@ macro_rules! test_mask_api { v.set(2, true); assert!(!v.all()); } + + #[test] + fn roundtrip_int_conversion() { + let values = [true, false, false, true, false, false, true, false]; + let mask = core_simd::$name::<8>::from_array(values); + let int = mask.to_int(); + assert_eq!(int.to_array(), [-1, 0, 0, -1, 0, 0, -1, 0]); + assert_eq!(core_simd::$name::<8>::from_int(int), mask); + } + + #[test] + fn to_bitmask() { + use core_simd::ToBitMask; + let values = [true, false, false, true, false, false, true, false]; + let mask = core_simd::$name::<8>::from_array(values); + assert_eq!(mask.to_bitmask(), 0b01001001); + } } } } |
