diff options
| author | Jubilee Young <workingjubilee@gmail.com> | 2021-12-21 18:28:57 -0800 |
|---|---|---|
| committer | Jubilee Young <workingjubilee@gmail.com> | 2021-12-22 15:37:05 -0800 |
| commit | bc326a2bbccdccb321328e7a1cde3ad3734a5953 (patch) | |
| tree | 22936fbe0123f0a67a5303de8a91c3f191b45854 | |
| parent | 5dcd397f47a17aec3b049af2d7541530b859e47b (diff) | |
| download | rust-bc326a2bbccdccb321328e7a1cde3ad3734a5953.tar.gz rust-bc326a2bbccdccb321328e7a1cde3ad3734a5953.zip | |
Refactor ops.rs with a recursive macro
This approaches reducing macro nesting in a slightly different way. Instead of just flattening details, make one macro apply another. This allows specifying all details up-front in the first macro invocation, making it easier to audit and refactor in the future.
| -rw-r--r-- | crates/core_simd/src/ops.rs | 508 |
1 file changed, 147 insertions, 361 deletions
diff --git a/crates/core_simd/src/ops.rs b/crates/core_simd/src/ops.rs index e6d7e695391..6cfc8f80b53 100644 --- a/crates/core_simd/src/ops.rs +++ b/crates/core_simd/src/ops.rs @@ -31,27 +31,10 @@ where } } -macro_rules! unsafe_base_op { - ($(impl<const LANES: usize> $op:ident for Simd<$scalar:ty, LANES> { - fn $call:ident(self, rhs: Self) -> Self::Output { - unsafe{ $simd_call:ident } - } - })*) => { - $(impl<const LANES: usize> $op for Simd<$scalar, LANES> - where - $scalar: SimdElement, - LaneCount<LANES>: SupportedLaneCount, - { - type Output = Self; - - #[inline] - #[must_use = "operator returns a new vector without mutating the inputs"] - fn $call(self, rhs: Self) -> Self::Output { - unsafe { $crate::intrinsics::$simd_call(self, rhs) } - } - } - )* - } +macro_rules! unsafe_base { + ($lhs:ident, $rhs:ident, {$simd_call:ident}, $($_:tt)*) => { + unsafe { $crate::intrinsics::$simd_call($lhs, $rhs) } + }; } /// SAFETY: This macro should not be used for anything except Shl or Shr, and passed the appropriate shift intrinsic. @@ -64,388 +47,191 @@ macro_rules! unsafe_base_op { // FIXME: Consider implementing this in cg_llvm instead? // cg_clif defaults to this, and scalar MIR shifts also default to wrapping macro_rules! 
wrap_bitshift { - ($(impl<const LANES: usize> $op:ident for Simd<$int:ty, LANES> { - fn $call:ident(self, rhs: Self) -> Self::Output { - unsafe { $simd_call:ident } + ($lhs:ident, $rhs:ident, {$simd_call:ident}, $int:ident) => { + unsafe { + $crate::intrinsics::$simd_call($lhs, $rhs.bitand(Simd::splat(<$int>::BITS as $int - 1))) } - })*) => { - $(impl<const LANES: usize> $op for Simd<$int, LANES> - where - $int: SimdElement, - LaneCount<LANES>: SupportedLaneCount, - { - type Output = Self; - - #[inline] - #[must_use = "operator returns a new vector without mutating the inputs"] - fn $call(self, rhs: Self) -> Self::Output { - unsafe { - $crate::intrinsics::$simd_call(self, rhs.bitand(Simd::splat(<$int>::BITS as $int - 1))) - } - } - })* }; } -macro_rules! bitops { - ($(impl<const LANES: usize> BitOps for Simd<$int:ty, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - })*) => { - $( - unsafe_base_op!{ - impl<const LANES: usize> BitAnd for Simd<$int, LANES> { - fn bitand(self, rhs: Self) -> Self::Output { - unsafe { simd_and } - } - } - - impl<const LANES: usize> BitOr for Simd<$int, LANES> { - fn bitor(self, rhs: Self) -> Self::Output { - unsafe { simd_or } - } - } - - impl<const LANES: usize> BitXor for Simd<$int, LANES> { - fn bitxor(self, rhs: Self) -> Self::Output { - unsafe { simd_xor } - } - } - } - wrap_bitshift! { - impl<const LANES: usize> Shl for Simd<$int, LANES> { - fn shl(self, rhs: Self) -> Self::Output { - unsafe { simd_shl } - } - } - - impl<const LANES: usize> Shr for Simd<$int, LANES> { - fn shr(self, rhs: Self) -> Self::Output { - // This automatically monomorphizes to lshr or ashr, depending, - // so it's fine to use it for both UInts and SInts. - unsafe { simd_shr } - } - } - } - )* +// Division by zero is poison, according to LLVM. 
+// So is dividing the MIN value of a signed integer by -1, +// since that would return MAX + 1. +// FIXME: Rust allows <SInt>::MIN / -1, +// so we should probably figure out how to make that safe. +macro_rules! int_divrem_guard { + ( $lhs:ident, + $rhs:ident, + { const PANIC_ZERO: &'static str = $zero:literal; + const PANIC_OVERFLOW: &'static str = $overflow:literal; + $simd_call:ident + }, + $int:ident ) => { + if $rhs.lanes_eq(Simd::splat(0)).any() { + panic!($zero); + } else if <$int>::MIN != 0 + && $lhs.lanes_eq(Simd::splat(<$int>::MIN)) & $rhs.lanes_eq(Simd::splat(-1 as _)) + != Mask::splat(false) + { + panic!($overflow); + } else { + unsafe { $crate::intrinsics::$simd_call($lhs, $rhs) } + } }; } -// Integers can always accept bitand, bitor, and bitxor. -// The only question is how to handle shifts >= <Int>::BITS? -// Our current solution uses wrapping logic. -bitops! { - impl<const LANES: usize> BitOps for Simd<i8, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; +macro_rules! 
for_base_types { + ( T = ($($scalar:ident),*); + type Lhs = Simd<T, N>; + type Rhs = Simd<T, N>; + type Output = $out:ty; + + impl $op:ident::$call:ident { + $macro_impl:ident $inner:tt + }) => { + $( + impl<const N: usize> $op<Self> for Simd<$scalar, N> + where + $scalar: SimdElement, + LaneCount<N>: SupportedLaneCount, + { + type Output = $out; + + #[inline] + #[must_use = "operator returns a new vector without mutating the inputs"] + fn $call(self, rhs: Self) -> Self::Output { + $macro_impl!(self, rhs, $inner, $scalar) + } + })* } +} - impl<const LANES: usize> BitOps for Simd<i16, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; +// A "TokenTree muncher": takes a set of scalar types `T = {};` +// type parameters for the ops it implements, `Op::fn` names, +// and a macro that expands into an expr, substituting in an intrinsic. +// It passes that to for_base_types, which expands an impl for the types, +// using the expanded expr in the function, and recurses with itself. +// +// tl;dr impls a set of ops::{Traits} for a set of types +macro_rules! for_base_ops { + ( + T = $types:tt; + type Lhs = Simd<T, N>; + type Rhs = Simd<T, N>; + type Output = $out:ident; + impl $op:ident::$call:ident + $inner:tt + $($rest:tt)* + ) => { + for_base_types! { + T = $types; + type Lhs = Simd<T, N>; + type Rhs = Simd<T, N>; + type Output = $out; + impl $op::$call + $inner + } + for_base_ops! { + T = $types; + type Lhs = Simd<T, N>; + type Rhs = Simd<T, N>; + type Output = $out; + $($rest)* + } + }; + ($($done:tt)*) => { + // Done. 
} +} - impl<const LANES: usize> BitOps for Simd<i32, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; - } +// Integers can always accept add, mul, sub, bitand, bitor, and bitxor. +// For all of these operations, simd_* intrinsics apply wrapping logic. +for_base_ops! { + T = (i8, i16, i32, i64, isize, u8, u16, u32, u64, usize); + type Lhs = Simd<T, N>; + type Rhs = Simd<T, N>; + type Output = Self; - impl<const LANES: usize> BitOps for Simd<i64, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; + impl Add::add { + unsafe_base { simd_add } } - impl<const LANES: usize> BitOps for Simd<isize, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; + impl Mul::mul { + unsafe_base { simd_mul } } - impl<const LANES: usize> BitOps for Simd<u8, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; + impl Sub::sub { + unsafe_base { simd_sub } } - impl<const LANES: usize> BitOps for Simd<u16, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; + impl BitAnd::bitand { + unsafe_base { simd_and } } - impl<const LANES: usize> BitOps for Simd<u32, LANES> { - fn bitand(self, rhs: Self) -> 
Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; + impl BitOr::bitor { + unsafe_base { simd_or } } - impl<const LANES: usize> BitOps for Simd<u64, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; + impl BitXor::bitxor { + unsafe_base { simd_xor } } - impl<const LANES: usize> BitOps for Simd<usize, LANES> { - fn bitand(self, rhs: Self) -> Self::Output; - fn bitor(self, rhs: Self) -> Self::Output; - fn bitxor(self, rhs: Self) -> Self::Output; - fn shl(self, rhs: Self) -> Self::Output; - fn shr(self, rhs: Self) -> Self::Output; + impl Div::div { + int_divrem_guard { + const PANIC_ZERO: &'static str = "attempt to divide by zero"; + const PANIC_OVERFLOW: &'static str = "attempt to divide with overflow"; + simd_div + } } -} - -macro_rules! 
float_arith { - ($(impl<const LANES: usize> FloatArith for Simd<$float:ty, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - })*) => { - $( - unsafe_base_op!{ - impl<const LANES: usize> Add for Simd<$float, LANES> { - fn add(self, rhs: Self) -> Self::Output { - unsafe { simd_add } - } - } - - impl<const LANES: usize> Mul for Simd<$float, LANES> { - fn mul(self, rhs: Self) -> Self::Output { - unsafe { simd_mul } - } - } - - impl<const LANES: usize> Sub for Simd<$float, LANES> { - fn sub(self, rhs: Self) -> Self::Output { - unsafe { simd_sub } - } - } - impl<const LANES: usize> Div for Simd<$float, LANES> { - fn div(self, rhs: Self) -> Self::Output { - unsafe { simd_div } - } - } - - impl<const LANES: usize> Rem for Simd<$float, LANES> { - fn rem(self, rhs: Self) -> Self::Output { - unsafe { simd_rem } - } - } - } - )* - }; -} - -// We don't need any special precautions here: -// Floats always accept arithmetic ops, but may become NaN. -float_arith! { - impl<const LANES: usize> FloatArith for Simd<f32, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Rem::rem { + int_divrem_guard { + const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero"; + const PANIC_OVERFLOW: &'static str = "attempt to calculate the remainder with overflow"; + simd_rem + } } - impl<const LANES: usize> FloatArith for Simd<f64, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + // The only question is how to handle shifts >= <Int>::BITS? 
+ // Our current solution uses wrapping logic. + impl Shl::shl { + wrap_bitshift { simd_shl } } -} - -// Division by zero is poison, according to LLVM. -// So is dividing the MIN value of a signed integer by -1, -// since that would return MAX + 1. -// FIXME: Rust allows <SInt>::MIN / -1, -// so we should probably figure out how to make that safe. -macro_rules! int_divrem_guard { - ($(impl<const LANES: usize> $op:ident for Simd<$sint:ty, LANES> { - const PANIC_ZERO: &'static str = $zero:literal; - const PANIC_OVERFLOW: &'static str = $overflow:literal; - fn $call:ident { - unsafe { $simd_call:ident } - } - })*) => { - $(impl<const LANES: usize> $op for Simd<$sint, LANES> - where - $sint: SimdElement, - LaneCount<LANES>: SupportedLaneCount, - { - type Output = Self; - #[inline] - #[must_use = "operator returns a new vector without mutating the inputs"] - fn $call(self, rhs: Self) -> Self::Output { - if rhs.lanes_eq(Simd::splat(0)).any() { - panic!("attempt to calculate the remainder with a divisor of zero"); - } else if <$sint>::MIN != 0 && self.lanes_eq(Simd::splat(<$sint>::MIN)) & rhs.lanes_eq(Simd::splat(-1 as _)) - != Mask::splat(false) - { - panic!("attempt to calculate the remainder with overflow"); - } else { - unsafe { $crate::intrinsics::$simd_call(self, rhs) } - } - } - })* - }; -} - -macro_rules! 
int_arith { - ($(impl<const LANES: usize> IntArith for Simd<$sint:ty, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - })*) => { - $( - unsafe_base_op!{ - impl<const LANES: usize> Add for Simd<$sint, LANES> { - fn add(self, rhs: Self) -> Self::Output { - unsafe { simd_add } - } - } - - impl<const LANES: usize> Mul for Simd<$sint, LANES> { - fn mul(self, rhs: Self) -> Self::Output { - unsafe { simd_mul } - } - } - impl<const LANES: usize> Sub for Simd<$sint, LANES> { - fn sub(self, rhs: Self) -> Self::Output { - unsafe { simd_sub } - } - } + impl Shr::shr { + wrap_bitshift { + // This automatically monomorphizes to lshr or ashr, depending, + // so it's fine to use it for both UInts and SInts. + simd_shr } - - int_divrem_guard!{ - impl<const LANES: usize> Div for Simd<$sint, LANES> { - const PANIC_ZERO: &'static str = "attempt to divide by zero"; - const PANIC_OVERFLOW: &'static str = "attempt to divide with overflow"; - fn div { - unsafe { simd_div } - } - } - - impl<const LANES: usize> Rem for Simd<$sint, LANES> { - const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero"; - const PANIC_OVERFLOW: &'static str = "attempt to calculate the remainder with overflow"; - fn rem { - unsafe { simd_rem } - } - } - })* } } -int_arith! 
{ - impl<const LANES: usize> IntArith for Simd<i8, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - } - - impl<const LANES: usize> IntArith for Simd<i16, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - } - - impl<const LANES: usize> IntArith for Simd<i32, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - } - - impl<const LANES: usize> IntArith for Simd<i64, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - } - - impl<const LANES: usize> IntArith for Simd<isize, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; - } +// We don't need any special precautions here: +// Floats always accept arithmetic ops, but may become NaN. +for_base_ops! 
{ + T = (f32, f64); + type Lhs = Simd<T, N>; + type Rhs = Simd<T, N>; + type Output = Self; - impl<const LANES: usize> IntArith for Simd<u8, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Add::add { + unsafe_base { simd_add } } - impl<const LANES: usize> IntArith for Simd<u16, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Mul::mul { + unsafe_base { simd_mul } } - impl<const LANES: usize> IntArith for Simd<u32, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Sub::sub { + unsafe_base { simd_sub } } - impl<const LANES: usize> IntArith for Simd<u64, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Div::div { + unsafe_base { simd_div } } - impl<const LANES: usize> IntArith for Simd<usize, LANES> { - fn add(self, rhs: Self) -> Self::Output; - fn mul(self, rhs: Self) -> Self::Output; - fn sub(self, rhs: Self) -> Self::Output; - fn div(self, rhs: Self) -> Self::Output; - fn rem(self, rhs: Self) -> Self::Output; + impl Rem::rem { + unsafe_base { simd_rem } } } |
