diff options
| author | Caleb Zulawski <caleb.zulawski@gmail.com> | 2022-06-23 01:21:58 -0400 |
|---|---|---|
| committer | Caleb Zulawski <caleb.zulawski@gmail.com> | 2022-10-29 11:55:01 -0400 |
| commit | 4076ba8a77326c70645f6c4a4351b0d84c5c898f (patch) | |
| tree | d523bd93ee272e435a9badfc6a27c5106145ac33 | |
| parent | 7e96f5dbea3fd2291f0e835a21ed0c41f6ef086e (diff) | |
| download | rust-4076ba8a77326c70645f6c4a4351b0d84c5c898f.tar.gz rust-4076ba8a77326c70645f6c4a4351b0d84c5c898f.zip | |
Implement scatter/gather with new pointer vector and add tests
| -rw-r--r-- | crates/core_simd/src/cast.rs | 132 | ||||
| -rw-r--r-- | crates/core_simd/src/elements/const_ptr.rs | 30 | ||||
| -rw-r--r-- | crates/core_simd/src/elements/mut_ptr.rs | 30 | ||||
| -rw-r--r-- | crates/core_simd/src/eq.rs | 20 | ||||
| -rw-r--r-- | crates/core_simd/src/ord.rs | 36 | ||||
| -rw-r--r-- | crates/core_simd/src/vector.rs | 13 | ||||
| -rw-r--r-- | crates/core_simd/src/vector/ptr.rs | 51 | ||||
| -rw-r--r-- | crates/core_simd/tests/pointers.rs | 43 | ||||
| -rw-r--r-- | crates/test_helpers/src/biteq.rs | 20 | ||||
| -rw-r--r-- | crates/test_helpers/src/lib.rs | 63 |
10 files changed, 277 insertions, 161 deletions
diff --git a/crates/core_simd/src/cast.rs b/crates/core_simd/src/cast.rs index e04a9042b1b..d62d3f6635d 100644 --- a/crates/core_simd/src/cast.rs +++ b/crates/core_simd/src/cast.rs @@ -1,25 +1,41 @@ -use crate::simd::SimdElement; +use crate::simd::{intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount}; /// Supporting trait for `Simd::cast`. Typically doesn't need to be used directly. -pub trait SimdCast<Target: SimdElement>: SimdElement {} +pub trait SimdCast<Target: SimdElement>: SimdElement { + #[doc(hidden)] + fn cast<const LANES: usize>(x: Simd<Self, LANES>) -> Simd<Target, LANES> + where + LaneCount<LANES>: SupportedLaneCount; +} macro_rules! into_number { + { $from:ty, $to:ty } => { + impl SimdCast<$to> for $from { + fn cast<const LANES: usize>(x: Simd<Self, LANES>) -> Simd<$to, LANES> + where + LaneCount<LANES>: SupportedLaneCount, + { + // Safety: simd_as can handle numeric conversions + unsafe { intrinsics::simd_as(x) } + } + } + }; { $($type:ty),* } => { $( - impl SimdCast<i8> for $type {} - impl SimdCast<i16> for $type {} - impl SimdCast<i32> for $type {} - impl SimdCast<i64> for $type {} - impl SimdCast<isize> for $type {} - - impl SimdCast<u8> for $type {} - impl SimdCast<u16> for $type {} - impl SimdCast<u32> for $type {} - impl SimdCast<u64> for $type {} - impl SimdCast<usize> for $type {} - - impl SimdCast<f32> for $type {} - impl SimdCast<f64> for $type {} + into_number! { $type, i8 } + into_number! { $type, i16 } + into_number! { $type, i32 } + into_number! { $type, i64 } + into_number! { $type, isize } + + into_number! { $type, u8 } + into_number! { $type, u16 } + into_number! { $type, u32 } + into_number! { $type, u64 } + into_number! { $type, usize } + + into_number! { $type, f32 } + into_number! { $type, f64 } )* } } @@ -29,17 +45,85 @@ into_number! { i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, f32, f64 } macro_rules! into_pointer { { $($type:ty),* } => { $( - impl<T> SimdCast<$type> for *const T {} - impl<T> SimdCast<$type> for *mut T {} - impl<T> SimdCast<*const T> for $type {} - impl<T> SimdCast<*mut T> for $type {} + impl<T> SimdCast<$type> for *const T { + fn cast<const LANES: usize>(x: Simd<Self, LANES>) -> Simd<$type, LANES> + where + LaneCount<LANES>: SupportedLaneCount, + { + // Safety: transmuting isize to pointers is safe + let x: Simd<isize, LANES> = unsafe { core::mem::transmute_copy(&x) }; + x.cast() + } + } + impl<T> SimdCast<$type> for *mut T { + fn cast<const LANES: usize>(x: Simd<Self, LANES>) -> Simd<$type, LANES> + where + LaneCount<LANES>: SupportedLaneCount, + { + // Safety: transmuting isize to pointers is safe + let x: Simd<isize, LANES> = unsafe { core::mem::transmute_copy(&x) }; + x.cast() + } + } + impl<T> SimdCast<*const T> for $type { + fn cast<const LANES: usize>(x: Simd<$type, LANES>) -> Simd<*const T, LANES> + where + LaneCount<LANES>: SupportedLaneCount, + { + let x: Simd<isize, LANES> = x.cast(); + // Safety: transmuting isize to pointers is safe + unsafe { core::mem::transmute_copy(&x) } + } + } + impl<T> SimdCast<*mut T> for $type { + fn cast<const LANES: usize>(x: Simd<$type, LANES>) -> Simd<*mut T, LANES> + where + LaneCount<LANES>: SupportedLaneCount, + { + let x: Simd<isize, LANES> = x.cast(); + // Safety: transmuting isize to pointers is safe + unsafe { core::mem::transmute_copy(&x) } + } + } )* } } into_pointer! { i8, i16, i32, i64, isize, u8, u16, u32, u64, usize } -impl<T, U> SimdCast<*const T> for *const U {} -impl<T, U> SimdCast<*const T> for *mut U {} -impl<T, U> SimdCast<*mut T> for *const U {} -impl<T, U> SimdCast<*mut T> for *mut U {} +impl<T, U> SimdCast<*const T> for *const U { + fn cast<const LANES: usize>(x: Simd<*const U, LANES>) -> Simd<*const T, LANES> + where + LaneCount<LANES>: SupportedLaneCount, + { + // Safety: transmuting pointers is safe + unsafe { core::mem::transmute_copy(&x) } + } +} +impl<T, U> SimdCast<*const T> for *mut U { + fn cast<const LANES: usize>(x: Simd<*mut U, LANES>) -> Simd<*const T, LANES> + where + LaneCount<LANES>: SupportedLaneCount, + { + // Safety: transmuting pointers is safe + unsafe { core::mem::transmute_copy(&x) } + } +} +impl<T, U> SimdCast<*mut T> for *const U { + fn cast<const LANES: usize>(x: Simd<*const U, LANES>) -> Simd<*mut T, LANES> + where + LaneCount<LANES>: SupportedLaneCount, + { + // Safety: transmuting pointers is safe + unsafe { core::mem::transmute_copy(&x) } + } +} +impl<T, U> SimdCast<*mut T> for *mut U { + fn cast<const LANES: usize>(x: Simd<*mut U, LANES>) -> Simd<*mut T, LANES> + where + LaneCount<LANES>: SupportedLaneCount, + { + // Safety: transmuting pointers is safe + unsafe { core::mem::transmute_copy(&x) } + } +} diff --git a/crates/core_simd/src/elements/const_ptr.rs b/crates/core_simd/src/elements/const_ptr.rs index 62365eace89..c4a254f5ab1 100644 --- a/crates/core_simd/src/elements/const_ptr.rs +++ b/crates/core_simd/src/elements/const_ptr.rs @@ -3,8 +3,8 @@ use crate::simd::{LaneCount, Mask, Simd, SimdPartialEq, SupportedLaneCount}; /// Operations on SIMD vectors of constant pointers. pub trait SimdConstPtr: Copy + Sealed { - /// Vector type representing the pointers as bits. - type Bits; + /// Vector of usize with the same number of lanes. + type Usize; /// Vector of mutable pointers to the same type. type MutPtr; @@ -18,11 +18,15 @@ pub trait SimdConstPtr: Copy + Sealed { /// Changes constness without changing the type. fn as_mut(self) -> Self::MutPtr; - /// Cast pointers to raw bits. - fn to_bits(self) -> Self::Bits; + /// Gets the "address" portion of the pointer. + /// + /// Equivalent to calling [`pointer::addr`] on each lane. + fn addr(self) -> Self::Usize; - /// Cast raw bits to pointers. - fn from_bits(bits: Self::Bits) -> Self; + /// Calculates the offset from a pointer using wrapping arithmetic. + /// + /// Equivalent to calling [`pointer::wrapping_add`] on each lane. + fn wrapping_add(self, count: Self::Usize) -> Self; } impl<T, const LANES: usize> Sealed for Simd<*const T, LANES> where @@ -34,23 +38,29 @@ impl<T, const LANES: usize> SimdConstPtr for Simd<*const T, LANES> where LaneCount<LANES>: SupportedLaneCount, { - type Bits = Simd<usize, LANES>; + type Usize = Simd<usize, LANES>; type MutPtr = Simd<*mut T, LANES>; type Mask = Mask<isize, LANES>; + #[inline] fn is_null(self) -> Self::Mask { Simd::splat(core::ptr::null()).simd_eq(self) } + #[inline] fn as_mut(self) -> Self::MutPtr { self.cast() } - fn to_bits(self) -> Self::Bits { + #[inline] + fn addr(self) -> Self::Usize { self.cast() } - fn from_bits(bits: Self::Bits) -> Self { - bits.cast() + #[inline] + fn wrapping_add(self, count: Self::Usize) -> Self { + let addr = self.addr() + (count * Simd::splat(core::mem::size_of::<T>())); + // Safety: transmuting usize to pointers is safe, even if accessing those pointers isn't. + unsafe { core::mem::transmute_copy(&addr) } } } diff --git a/crates/core_simd/src/elements/mut_ptr.rs b/crates/core_simd/src/elements/mut_ptr.rs index 8c68d42628f..5920960c49c 100644 --- a/crates/core_simd/src/elements/mut_ptr.rs +++ b/crates/core_simd/src/elements/mut_ptr.rs @@ -3,8 +3,8 @@ use crate::simd::{LaneCount, Mask, Simd, SimdPartialEq, SupportedLaneCount}; /// Operations on SIMD vectors of mutable pointers. pub trait SimdMutPtr: Copy + Sealed { - /// Vector type representing the pointers as bits. - type Bits; + /// Vector of usize with the same number of lanes. + type Usize; /// Vector of constant pointers to the same type. type ConstPtr; @@ -18,11 +18,15 @@ pub trait SimdMutPtr: Copy + Sealed { /// Changes constness without changing the type. fn as_const(self) -> Self::ConstPtr; - /// Cast pointers to raw bits. - fn to_bits(self) -> Self::Bits; + /// Gets the "address" portion of the pointer. + /// + /// Equivalent to calling [`pointer::addr`] on each lane. + fn addr(self) -> Self::Usize; - /// Cast raw bits to pointers. - fn from_bits(bits: Self::Bits) -> Self; + /// Calculates the offset from a pointer using wrapping arithmetic. + /// + /// Equivalent to calling [`pointer::wrapping_add`] on each lane. + fn wrapping_add(self, count: Self::Usize) -> Self; } impl<T, const LANES: usize> Sealed for Simd<*mut T, LANES> where LaneCount<LANES>: SupportedLaneCount @@ -32,23 +36,29 @@ impl<T, const LANES: usize> SimdMutPtr for Simd<*mut T, LANES> where LaneCount<LANES>: SupportedLaneCount, { - type Bits = Simd<usize, LANES>; + type Usize = Simd<usize, LANES>; type ConstPtr = Simd<*const T, LANES>; type Mask = Mask<isize, LANES>; + #[inline] fn is_null(self) -> Self::Mask { Simd::splat(core::ptr::null_mut()).simd_eq(self) } + #[inline] fn as_const(self) -> Self::ConstPtr { self.cast() } - fn to_bits(self) -> Self::Bits { + #[inline] + fn addr(self) -> Self::Usize { self.cast() } - fn from_bits(bits: Self::Bits) -> Self { - bits.cast() + #[inline] + fn wrapping_add(self, count: Self::Usize) -> Self { + let addr = self.addr() + (count * Simd::splat(core::mem::size_of::<T>())); + // Safety: transmuting usize to pointers is safe, even if accessing those pointers isn't. + unsafe { core::mem::transmute_copy(&addr) } } } diff --git a/crates/core_simd/src/eq.rs b/crates/core_simd/src/eq.rs index 149380746e7..80763c07272 100644 --- a/crates/core_simd/src/eq.rs +++ b/crates/core_simd/src/eq.rs @@ -1,4 +1,6 @@ -use crate::simd::{intrinsics, LaneCount, Mask, Simd, SimdElement, SupportedLaneCount}; +use crate::simd::{ + intrinsics, LaneCount, Mask, Simd, SimdConstPtr, SimdElement, SimdMutPtr, SupportedLaneCount, +}; /// Parallel `PartialEq`. pub trait SimdPartialEq { @@ -80,16 +82,12 @@ where #[inline] fn simd_eq(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_eq(self, other)) } + self.addr().simd_eq(other.addr()) } #[inline] fn simd_ne(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_ne(self, other)) } + self.addr().simd_ne(other.addr()) } } @@ -101,15 +99,11 @@ where #[inline] fn simd_eq(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_eq(self, other)) } + self.addr().simd_eq(other.addr()) } #[inline] fn simd_ne(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_ne(self, other)) } + self.addr().simd_ne(other.addr()) } } diff --git a/crates/core_simd/src/ord.rs b/crates/core_simd/src/ord.rs index 95a1ecaeeda..1ae9cd061fb 100644 --- a/crates/core_simd/src/ord.rs +++ b/crates/core_simd/src/ord.rs @@ -1,4 +1,6 @@ -use crate::simd::{intrinsics, LaneCount, Mask, Simd, SimdPartialEq, SupportedLaneCount}; +use crate::simd::{ + intrinsics, LaneCount, Mask, Simd, SimdConstPtr, SimdMutPtr, SimdPartialEq, SupportedLaneCount, +}; /// Parallel `PartialOrd`. pub trait SimdPartialOrd: SimdPartialEq { @@ -218,30 +220,22 @@ where { #[inline] fn simd_lt(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_lt(self, other)) } + self.addr().simd_lt(other.addr()) } #[inline] fn simd_le(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_le(self, other)) } + self.addr().simd_le(other.addr()) } #[inline] fn simd_gt(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_gt(self, other)) } + self.addr().simd_gt(other.addr()) } #[inline] fn simd_ge(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_ge(self, other)) } + self.addr().simd_ge(other.addr()) } } @@ -275,30 +269,22 @@ where { #[inline] fn simd_lt(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_lt(self, other)) } + self.addr().simd_lt(other.addr()) } #[inline] fn simd_le(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_le(self, other)) } + self.addr().simd_le(other.addr()) } #[inline] fn simd_gt(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_gt(self, other)) } + self.addr().simd_gt(other.addr()) } #[inline] fn simd_ge(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Mask::from_int_unchecked(intrinsics::simd_ge(self, other)) } + self.addr().simd_ge(other.addr()) } } diff --git a/crates/core_simd/src/vector.rs b/crates/core_simd/src/vector.rs index cbc8ced5a84..145394a519d 100644 --- a/crates/core_simd/src/vector.rs +++ b/crates/core_simd/src/vector.rs @@ -1,8 +1,6 @@ -// Vectors of pointers are not for public use at the current time. -pub(crate) mod ptr; - use crate::simd::{ - intrinsics, LaneCount, Mask, MaskElement, SimdCast, SimdPartialOrd, SupportedLaneCount, Swizzle, + intrinsics, LaneCount, Mask, MaskElement, SimdCast, SimdConstPtr, SimdMutPtr, SimdPartialOrd, + SupportedLaneCount, Swizzle, }; /// A SIMD vector of `LANES` elements of type `T`. `Simd<T, N>` has the same shape as [`[T; N]`](array), but operates like `T`. @@ -215,8 +213,7 @@ where where T: SimdCast<U>, { - // Safety: The input argument is a vector of a valid SIMD element type. - unsafe { intrinsics::simd_as(self) } + SimdCast::cast(self) } /// Rounds toward zero and converts to the same-width integer type, assuming that @@ -352,7 +349,7 @@ where idxs: Simd<usize, LANES>, or: Self, ) -> Self { - let base_ptr = crate::simd::ptr::SimdConstPtr::splat(slice.as_ptr()); + let base_ptr = Simd::<*const T, LANES>::splat(slice.as_ptr()); // Ferris forgive me, I have done pointer arithmetic here. let ptrs = base_ptr.wrapping_add(idxs); // Safety: The ptrs have been bounds-masked to prevent memory-unsafe reads insha'allah @@ -460,7 +457,7 @@ where // 3. &mut [T] which will become our base ptr. unsafe { // Now Entering ☢️ *mut T Zone - let base_ptr = crate::simd::ptr::SimdMutPtr::splat(slice.as_mut_ptr()); + let base_ptr = Simd::<*mut T, LANES>::splat(slice.as_mut_ptr()); // Ferris forgive me, I have done pointer arithmetic here. let ptrs = base_ptr.wrapping_add(idxs); // The ptrs have been bounds-masked to prevent memory-unsafe writes insha'allah diff --git a/crates/core_simd/src/vector/ptr.rs b/crates/core_simd/src/vector/ptr.rs deleted file mode 100644 index fa756344db9..00000000000 --- a/crates/core_simd/src/vector/ptr.rs +++ /dev/null @@ -1,51 +0,0 @@ -//! Private implementation details of public gather/scatter APIs. -use crate::simd::intrinsics; -use crate::simd::{LaneCount, Simd, SupportedLaneCount}; - -/// A vector of *const T. -#[derive(Debug, Copy, Clone)] -#[repr(simd)] -pub(crate) struct SimdConstPtr<T, const LANES: usize>([*const T; LANES]); - -impl<T, const LANES: usize> SimdConstPtr<T, LANES> -where - LaneCount<LANES>: SupportedLaneCount, - T: Sized, -{ - #[inline] - #[must_use] - pub fn splat(ptr: *const T) -> Self { - Self([ptr; LANES]) - } - - #[inline] - #[must_use] - pub fn wrapping_add(self, addend: Simd<usize, LANES>) -> Self { - // Safety: this intrinsic doesn't have a precondition - unsafe { intrinsics::simd_arith_offset(self, addend) } - } -} - -/// A vector of *mut T. Be very careful around potential aliasing. -#[derive(Debug, Copy, Clone)] -#[repr(simd)] -pub(crate) struct SimdMutPtr<T, const LANES: usize>([*mut T; LANES]); - -impl<T, const LANES: usize> SimdMutPtr<T, LANES> -where - LaneCount<LANES>: SupportedLaneCount, - T: Sized, -{ - #[inline] - #[must_use] - pub fn splat(ptr: *mut T) -> Self { - Self([ptr; LANES]) - } - - #[inline] - #[must_use] - pub fn wrapping_add(self, addend: Simd<usize, LANES>) -> Self { - // Safety: this intrinsic doesn't have a precondition - unsafe { intrinsics::simd_arith_offset(self, addend) } - } -} diff --git a/crates/core_simd/tests/pointers.rs b/crates/core_simd/tests/pointers.rs new file mode 100644 index 00000000000..df26c462f93 --- /dev/null +++ b/crates/core_simd/tests/pointers.rs @@ -0,0 +1,43 @@ +#![feature(portable_simd, strict_provenance)] + +use core_simd::{Simd, SimdConstPtr, SimdMutPtr}; + +macro_rules! common_tests { + { $constness:ident } => { + test_helpers::test_lanes! { + fn is_null<const LANES: usize>() { + test_helpers::test_unary_mask_elementwise( + &Simd::<*$constness (), LANES>::is_null, + &<*$constness ()>::is_null, + &|_| true, + ); + } + + fn addr<const LANES: usize>() { + test_helpers::test_unary_elementwise( + &Simd::<*$constness (), LANES>::addr, + &<*$constness ()>::addr, + &|_| true, + ); + } + + fn wrapping_add<const LANES: usize>() { + test_helpers::test_binary_elementwise( + &Simd::<*$constness (), LANES>::wrapping_add, + &<*$constness ()>::wrapping_add, + &|_, _| true, + ); + } + } + } +} + +mod const_ptr { + use super::*; + common_tests! { const } +} + +mod mut_ptr { + use super::*; + common_tests! { mut } +} diff --git a/crates/test_helpers/src/biteq.rs b/crates/test_helpers/src/biteq.rs index 00350e22418..7d91260d838 100644 --- a/crates/test_helpers/src/biteq.rs +++ b/crates/test_helpers/src/biteq.rs @@ -55,6 +55,26 @@ macro_rules! impl_float_biteq { impl_float_biteq! { f32, f64 } +impl<T> BitEq for *const T { + fn biteq(&self, other: &Self) -> bool { + self == other + } + + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl<T> BitEq for *mut T { + fn biteq(&self, other: &Self) -> bool { + self == other + } + + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(f, "{:?}", self) + } +} + impl<T: BitEq, const N: usize> BitEq for [T; N] { fn biteq(&self, other: &Self) -> bool { self.iter() diff --git a/crates/test_helpers/src/lib.rs b/crates/test_helpers/src/lib.rs index 650eadd12bf..5f2a928b5e4 100644 --- a/crates/test_helpers/src/lib.rs +++ b/crates/test_helpers/src/lib.rs @@ -38,6 +38,28 @@ impl_num! { usize } impl_num! { f32 } impl_num! { f64 } +impl<T> DefaultStrategy for *const T { + type Strategy = proptest::strategy::Map<proptest::num::isize::Any, fn(isize) -> *const T>; + fn default_strategy() -> Self::Strategy { + fn map<T>(x: isize) -> *const T { + x as _ + } + use proptest::strategy::Strategy; + proptest::num::isize::ANY.prop_map(map) + } +} + +impl<T> DefaultStrategy for *mut T { + type Strategy = proptest::strategy::Map<proptest::num::isize::Any, fn(isize) -> *mut T>; + fn default_strategy() -> Self::Strategy { + fn map<T>(x: isize) -> *mut T { + x as _ + } + use proptest::strategy::Strategy; + proptest::num::isize::ANY.prop_map(map) + } +} + #[cfg(not(target_arch = "wasm32"))] impl DefaultStrategy for u128 { type Strategy = proptest::num::u128::Any; @@ -135,21 +157,21 @@ pub fn test_unary_elementwise<Scalar, ScalarResult, Vector, VectorResult, const fs: &dyn Fn(Scalar) -> ScalarResult, check: &dyn Fn([Scalar; LANES]) -> bool, ) where - Scalar: Copy + Default + core::fmt::Debug + DefaultStrategy, - ScalarResult: Copy + Default + biteq::BitEq + core::fmt::Debug + DefaultStrategy, + Scalar: Copy + core::fmt::Debug + DefaultStrategy, + ScalarResult: Copy + biteq::BitEq + core::fmt::Debug + DefaultStrategy, Vector: Into<[Scalar; LANES]> + From<[Scalar; LANES]> + Copy, VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy, { test_1(&|x: [Scalar; LANES]| { proptest::prop_assume!(check(x)); let result_1: [ScalarResult; LANES] = fv(x.into()).into(); - let result_2: [ScalarResult; LANES] = { - let mut result = [ScalarResult::default(); LANES]; - for (i, o) in x.iter().zip(result.iter_mut()) { - *o = fs(*i); - } - result - }; + let result_2: [ScalarResult; LANES] = x + .iter() + .copied() + .map(fs) + .collect::<Vec<_>>() + .try_into() + .unwrap(); crate::prop_assert_biteq!(result_1, result_2); Ok(()) }); @@ -162,7 +184,7 @@ pub fn test_unary_mask_elementwise<Scalar, Vector, Mask, const LANES: usize>( fs: &dyn Fn(Scalar) -> bool, check: &dyn Fn([Scalar; LANES]) -> bool, ) where - Scalar: Copy + Default + core::fmt::Debug + DefaultStrategy, + Scalar: Copy + core::fmt::Debug + DefaultStrategy, Vector: Into<[Scalar; LANES]> + From<[Scalar; LANES]> + Copy, Mask: Into<[bool; LANES]> + From<[bool; LANES]> + Copy, { @@ -196,9 +218,9 @@ pub fn test_binary_elementwise< fs: &dyn Fn(Scalar1, Scalar2) -> ScalarResult, check: &dyn Fn([Scalar1; LANES], [Scalar2; LANES]) -> bool, ) where - Scalar1: Copy + Default + core::fmt::Debug + DefaultStrategy, - Scalar2: Copy + Default + core::fmt::Debug + DefaultStrategy, - ScalarResult: Copy + Default + biteq::BitEq + core::fmt::Debug + DefaultStrategy, + Scalar1: Copy + core::fmt::Debug + DefaultStrategy, + Scalar2: Copy + core::fmt::Debug + DefaultStrategy, + ScalarResult: Copy + biteq::BitEq + core::fmt::Debug + DefaultStrategy, Vector1: Into<[Scalar1; LANES]> + From<[Scalar1; LANES]> + Copy, Vector2: Into<[Scalar2; LANES]> + From<[Scalar2; LANES]> + Copy, VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy, @@ -206,13 +228,14 @@ pub fn test_binary_elementwise< test_2(&|x: [Scalar1; LANES], y: [Scalar2; LANES]| { proptest::prop_assume!(check(x, y)); let result_1: [ScalarResult; LANES] = fv(x.into(), y.into()).into(); - let result_2: [ScalarResult; LANES] = { - let mut result = [ScalarResult::default(); LANES]; - for ((i1, i2), o) in x.iter().zip(y.iter()).zip(result.iter_mut()) { - *o = fs(*i1, *i2); - } - result - }; + let result_2: [ScalarResult; LANES] = x + .iter() + .copied() + .zip(y.iter().copied()) + .map(|(x, y)| fs(x, y)) + .collect::<Vec<_>>() + .try_into() + .unwrap(); crate::prop_assert_biteq!(result_1, result_2); Ok(()) }); |
