about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/core_simd/src/masks.rs88
1 files changed, 80 insertions, 8 deletions
diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs
index 63731342423..7af4517226a 100644
--- a/crates/core_simd/src/masks.rs
+++ b/crates/core_simd/src/masks.rs
@@ -13,7 +13,7 @@
 mod mask_impl;
 
 use crate::simd::{
-    cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount,
+    cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount,
 };
 use core::cmp::Ordering;
 use core::{fmt, mem};
@@ -35,6 +35,10 @@ mod sealed {
 
         fn eq(self, other: Self) -> bool;
 
+        fn as_usize(self) -> usize;
+
+        type Unsigned: SimdElement;
+
         const TRUE: Self;
 
         const FALSE: Self;
@@ -46,10 +50,10 @@ use sealed::Sealed;
 ///
 /// # Safety
 /// Type must be a signed integer.
-pub unsafe trait MaskElement: SimdElement + Sealed {}
+pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
 
 macro_rules! impl_element {
-    { $ty:ty } => {
+    { $ty:ty, $unsigned:ty } => {
         impl Sealed for $ty {
             #[inline]
             fn valid<const N: usize>(value: Simd<Self, N>) -> bool
@@ -62,6 +66,13 @@ macro_rules! impl_element {
             #[inline]
             fn eq(self, other: Self) -> bool { self == other }
 
+            #[inline]
+            fn as_usize(self) -> usize {
+                self as usize
+            }
+
+            type Unsigned = $unsigned;
+
             const TRUE: Self = -1;
             const FALSE: Self = 0;
         }
@@ -71,11 +82,11 @@ macro_rules! impl_element {
     }
 }
 
-impl_element! { i8 }
-impl_element! { i16 }
-impl_element! { i32 }
-impl_element! { i64 }
-impl_element! { isize }
+impl_element! { i8, u8 }
+impl_element! { i16, u16 }
+impl_element! { i32, u32 }
+impl_element! { i64, u64 }
+impl_element! { isize, usize }
 
 /// A SIMD vector mask for `N` elements of width specified by `Element`.
 ///
@@ -298,6 +309,67 @@ where
     pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
         Self(mask_impl::Mask::from_bitmask_vector(bitmask))
     }
+
+    /// Find the index of the first set element.
+    ///
+    /// ```
+    /// # #![feature(portable_simd)]
+    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
+    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
+    /// # use simd::mask32x8;
+    /// assert_eq!(mask32x8::splat(false).first_set(), None);
+    /// assert_eq!(mask32x8::splat(true).first_set(), Some(0));
+    ///
+    /// let mask = mask32x8::from_array([false, true, false, false, true, false, false, true]);
+    /// assert_eq!(mask.first_set(), Some(1));
+    /// ```
+    #[inline]
+    #[must_use = "method returns the index and does not mutate the original value"]
+    pub fn first_set(self) -> Option<usize> {
+        // If bitmasks are efficient, using them is better
+        if cfg!(target_feature = "sse") && N <= 64 {
+            let tz = self.to_bitmask().trailing_zeros();
+            return if tz == 64 { None } else { Some(tz as usize) };
+        }
+
+        // To find the first set index:
+        // * create a vector 0..N
+        // * replace unset mask elements in that vector with -1
+        // * perform _unsigned_ reduce-min
+        // * check if the result is -1 or an index
+
+        let index = Simd::from_array(
+            const {
+                let mut index = [0; N];
+                let mut i = 0;
+                while i < N {
+                    index[i] = i;
+                    i += 1;
+                }
+                index
+            },
+        );
+
+        // Safety: the input and output are integer vectors
+        let index: Simd<T, N> = unsafe { intrinsics::simd_cast(index) };
+
+        let masked_index = self.select(index, Self::splat(true).to_int());
+
+        // Safety: the input and output are integer vectors
+        let masked_index: Simd<T::Unsigned, N> = unsafe { intrinsics::simd_cast(masked_index) };
+
+        // Safety: the input is an integer vector
+        let min_index: T::Unsigned = unsafe { intrinsics::simd_reduce_min(masked_index) };
+
+        // Safety: the return value is the unsigned version of T
+        let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };
+
+        if min_index.eq(T::TRUE) {
+            None
+        } else {
+            Some(min_index.as_usize())
+        }
+    }
 }
 
 // vector/array conversion