about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJubilee Young <workingjubilee@gmail.com>2021-12-21 18:28:57 -0800
committerJubilee Young <workingjubilee@gmail.com>2021-12-22 15:37:05 -0800
commitbc326a2bbccdccb321328e7a1cde3ad3734a5953 (patch)
tree22936fbe0123f0a67a5303de8a91c3f191b45854
parent5dcd397f47a17aec3b049af2d7541530b859e47b (diff)
downloadrust-bc326a2bbccdccb321328e7a1cde3ad3734a5953.tar.gz
rust-bc326a2bbccdccb321328e7a1cde3ad3734a5953.zip
Refactor ops.rs with a recursive macro
This approaches reducing macro nesting in a slightly different way.
Instead of just flattening details, make one macro apply another.
This allows specifying all details up-front in the first macro
invocation, making it easier to audit and refactor in the future.
-rw-r--r--crates/core_simd/src/ops.rs508
1 file changed, 147 insertions, 361 deletions
diff --git a/crates/core_simd/src/ops.rs b/crates/core_simd/src/ops.rs
index e6d7e695391..6cfc8f80b53 100644
--- a/crates/core_simd/src/ops.rs
+++ b/crates/core_simd/src/ops.rs
@@ -31,27 +31,10 @@ where
     }
 }
 
-macro_rules! unsafe_base_op {
-    ($(impl<const LANES: usize> $op:ident for Simd<$scalar:ty, LANES> {
-        fn $call:ident(self, rhs: Self) -> Self::Output {
-            unsafe{ $simd_call:ident }
-        }
-    })*) => {
-        $(impl<const LANES: usize> $op for Simd<$scalar, LANES>
-            where
-                $scalar: SimdElement,
-                LaneCount<LANES>: SupportedLaneCount,
-            {
-                type Output = Self;
-
-                #[inline]
-                #[must_use = "operator returns a new vector without mutating the inputs"]
-                fn $call(self, rhs: Self) -> Self::Output {
-                    unsafe { $crate::intrinsics::$simd_call(self, rhs) }
-                }
-            }
-        )*
-    }
+macro_rules! unsafe_base {
+    ($lhs:ident, $rhs:ident, {$simd_call:ident}, $($_:tt)*) => {
+        unsafe { $crate::intrinsics::$simd_call($lhs, $rhs) }
+    };
 }
 
 /// SAFETY: This macro should not be used for anything except Shl or Shr, and passed the appropriate shift intrinsic.
@@ -64,388 +47,191 @@ macro_rules! unsafe_base_op {
 // FIXME: Consider implementing this in cg_llvm instead?
 // cg_clif defaults to this, and scalar MIR shifts also default to wrapping
 macro_rules! wrap_bitshift {
-    ($(impl<const LANES: usize> $op:ident for Simd<$int:ty, LANES> {
-        fn $call:ident(self, rhs: Self) -> Self::Output {
-            unsafe { $simd_call:ident }
+    ($lhs:ident, $rhs:ident, {$simd_call:ident}, $int:ident) => {
+        unsafe {
+            $crate::intrinsics::$simd_call($lhs, $rhs.bitand(Simd::splat(<$int>::BITS as $int - 1)))
         }
-    })*) => {
-        $(impl<const LANES: usize> $op for Simd<$int, LANES>
-        where
-            $int: SimdElement,
-            LaneCount<LANES>: SupportedLaneCount,
-        {
-            type Output = Self;
-
-            #[inline]
-            #[must_use = "operator returns a new vector without mutating the inputs"]
-            fn $call(self, rhs: Self) -> Self::Output {
-                unsafe {
-                    $crate::intrinsics::$simd_call(self, rhs.bitand(Simd::splat(<$int>::BITS as $int - 1)))
-                }
-            }
-        })*
     };
 }
 
-macro_rules! bitops {
-    ($(impl<const LANES: usize> BitOps for Simd<$int:ty, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
-     })*) => {
-        $(
-            unsafe_base_op!{
-                impl<const LANES: usize> BitAnd for Simd<$int, LANES> {
-                    fn bitand(self, rhs: Self) -> Self::Output {
-                        unsafe { simd_and }
-                    }
-                }
-
-                impl<const LANES: usize> BitOr for Simd<$int, LANES> {
-                    fn bitor(self, rhs: Self) -> Self::Output {
-                        unsafe { simd_or }
-                    }
-                }
-
-                impl<const LANES: usize> BitXor for Simd<$int, LANES> {
-                    fn bitxor(self, rhs: Self) -> Self::Output {
-                        unsafe { simd_xor }
-                    }
-                }
-            }
-            wrap_bitshift! {
-                impl<const LANES: usize> Shl for Simd<$int, LANES> {
-                    fn shl(self, rhs: Self) -> Self::Output {
-                        unsafe { simd_shl }
-                    }
-                }
-
-                impl<const LANES: usize> Shr for Simd<$int, LANES> {
-                    fn shr(self, rhs: Self) -> Self::Output {
-                        // This automatically monomorphizes to lshr or ashr, depending,
-                        // so it's fine to use it for both UInts and SInts.
-                        unsafe { simd_shr }
-                    }
-                }
-            }
-        )*
+// Division by zero is poison, according to LLVM.
+// So is dividing the MIN value of a signed integer by -1,
+// since that would return MAX + 1.
+// FIXME: Rust allows <SInt>::MIN / -1,
+// so we should probably figure out how to make that safe.
+macro_rules! int_divrem_guard {
+    (   $lhs:ident,
+        $rhs:ident,
+        {   const PANIC_ZERO: &'static str = $zero:literal;
+            const PANIC_OVERFLOW: &'static str = $overflow:literal;
+            $simd_call:ident
+        },
+        $int:ident ) => {
+        if $rhs.lanes_eq(Simd::splat(0)).any() {
+            panic!($zero);
+        } else if <$int>::MIN != 0
+            && $lhs.lanes_eq(Simd::splat(<$int>::MIN)) & $rhs.lanes_eq(Simd::splat(-1 as _))
+                != Mask::splat(false)
+        {
+            panic!($overflow);
+        } else {
+            unsafe { $crate::intrinsics::$simd_call($lhs, $rhs) }
+        }
     };
 }
 
-// Integers can always accept bitand, bitor, and bitxor.
-// The only question is how to handle shifts >= <Int>::BITS?
-// Our current solution uses wrapping logic.
-bitops! {
-    impl<const LANES: usize> BitOps for Simd<i8, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
+macro_rules! for_base_types {
+    (   T = ($($scalar:ident),*);
+        type Lhs = Simd<T, N>;
+        type Rhs = Simd<T, N>;
+        type Output = $out:ty;
+
+        impl $op:ident::$call:ident {
+            $macro_impl:ident $inner:tt
+        }) => {
+            $(
+                impl<const N: usize> $op<Self> for Simd<$scalar, N>
+                where
+                    $scalar: SimdElement,
+                    LaneCount<N>: SupportedLaneCount,
+                {
+                    type Output = $out;
+
+                    #[inline]
+                    #[must_use = "operator returns a new vector without mutating the inputs"]
+                    fn $call(self, rhs: Self) -> Self::Output {
+                        $macro_impl!(self, rhs, $inner, $scalar)
+                    }
+                })*
     }
+}
 
-    impl<const LANES: usize> BitOps for Simd<i16, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
+// A "TokenTree muncher": takes a set of scalar types `T = {};`
+// type parameters for the ops it implements, `Op::fn` names,
+// and a macro that expands into an expr, substituting in an intrinsic.
+// It passes that to for_base_types, which expands an impl for the types,
+// using the expanded expr in the function, and recurses with itself.
+//
+// tl;dr impls a set of ops::{Traits} for a set of types
+macro_rules! for_base_ops {
+    (
+        T = $types:tt;
+        type Lhs = Simd<T, N>;
+        type Rhs = Simd<T, N>;
+        type Output = $out:ident;
+        impl $op:ident::$call:ident
+            $inner:tt
+        $($rest:tt)*
+    ) => {
+        for_base_types! {
+            T = $types;
+            type Lhs = Simd<T, N>;
+            type Rhs = Simd<T, N>;
+            type Output = $out;
+            impl $op::$call
+                $inner
+        }
+        for_base_ops! {
+            T = $types;
+            type Lhs = Simd<T, N>;
+            type Rhs = Simd<T, N>;
+            type Output = $out;
+            $($rest)*
+        }
+    };
+    ($($done:tt)*) => {
+        // Done.
     }
+}
 
-    impl<const LANES: usize> BitOps for Simd<i32, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
-    }
+// Integers can always accept add, mul, sub, bitand, bitor, and bitxor.
+// For all of these operations, simd_* intrinsics apply wrapping logic.
+for_base_ops! {
+    T = (i8, i16, i32, i64, isize, u8, u16, u32, u64, usize);
+    type Lhs = Simd<T, N>;
+    type Rhs = Simd<T, N>;
+    type Output = Self;
 
-    impl<const LANES: usize> BitOps for Simd<i64, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
+    impl Add::add {
+        unsafe_base { simd_add }
     }
 
-    impl<const LANES: usize> BitOps for Simd<isize, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
+    impl Mul::mul {
+        unsafe_base { simd_mul }
     }
 
-    impl<const LANES: usize> BitOps for Simd<u8, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
+    impl Sub::sub {
+        unsafe_base { simd_sub }
     }
 
-    impl<const LANES: usize> BitOps for Simd<u16, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
+    impl BitAnd::bitand {
+        unsafe_base { simd_and }
     }
 
-    impl<const LANES: usize> BitOps for Simd<u32, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
+    impl BitOr::bitor {
+        unsafe_base { simd_or }
     }
 
-    impl<const LANES: usize> BitOps for Simd<u64, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
+    impl BitXor::bitxor {
+        unsafe_base { simd_xor }
     }
 
-    impl<const LANES: usize> BitOps for Simd<usize, LANES> {
-        fn bitand(self, rhs: Self) -> Self::Output;
-        fn bitor(self, rhs: Self) -> Self::Output;
-        fn bitxor(self, rhs: Self) -> Self::Output;
-        fn shl(self, rhs: Self) -> Self::Output;
-        fn shr(self, rhs: Self) -> Self::Output;
+    impl Div::div {
+        int_divrem_guard {
+            const PANIC_ZERO: &'static str = "attempt to divide by zero";
+            const PANIC_OVERFLOW: &'static str = "attempt to divide with overflow";
+            simd_div
+        }
     }
-}
-
-macro_rules! float_arith {
-    ($(impl<const LANES: usize> FloatArith for Simd<$float:ty, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
-     })*) => {
-        $(
-            unsafe_base_op!{
-                impl<const LANES: usize> Add for Simd<$float, LANES> {
-                    fn add(self, rhs: Self) -> Self::Output {
-                        unsafe { simd_add }
-                    }
-                }
-
-                impl<const LANES: usize> Mul for Simd<$float, LANES> {
-                    fn mul(self, rhs: Self) -> Self::Output {
-                        unsafe { simd_mul }
-                    }
-                }
-
-                impl<const LANES: usize> Sub for Simd<$float, LANES> {
-                    fn sub(self, rhs: Self) -> Self::Output {
-                        unsafe { simd_sub }
-                    }
-                }
 
-                impl<const LANES: usize> Div for Simd<$float, LANES> {
-                    fn div(self, rhs: Self) -> Self::Output {
-                        unsafe { simd_div }
-                    }
-                }
-
-                impl<const LANES: usize> Rem for Simd<$float, LANES> {
-                    fn rem(self, rhs: Self) -> Self::Output {
-                        unsafe { simd_rem }
-                    }
-                }
-            }
-        )*
-    };
-}
-
-// We don't need any special precautions here:
-// Floats always accept arithmetic ops, but may become NaN.
-float_arith! {
-    impl<const LANES: usize> FloatArith for Simd<f32, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
+    impl Rem::rem {
+        int_divrem_guard {
+            const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero";
+            const PANIC_OVERFLOW: &'static str = "attempt to calculate the remainder with overflow";
+            simd_rem
+        }
     }
 
-    impl<const LANES: usize> FloatArith for Simd<f64, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
+    // The only question is how to handle shifts >= <Int>::BITS?
+    // Our current solution uses wrapping logic.
+    impl Shl::shl {
+        wrap_bitshift { simd_shl }
     }
-}
-
-// Division by zero is poison, according to LLVM.
-// So is dividing the MIN value of a signed integer by -1,
-// since that would return MAX + 1.
-// FIXME: Rust allows <SInt>::MIN / -1,
-// so we should probably figure out how to make that safe.
-macro_rules! int_divrem_guard {
-    ($(impl<const LANES: usize> $op:ident for Simd<$sint:ty, LANES> {
-        const PANIC_ZERO: &'static str = $zero:literal;
-        const PANIC_OVERFLOW: &'static str = $overflow:literal;
-        fn $call:ident {
-            unsafe { $simd_call:ident }
-        }
-    })*) => {
-        $(impl<const LANES: usize> $op for Simd<$sint, LANES>
-        where
-            $sint: SimdElement,
-            LaneCount<LANES>: SupportedLaneCount,
-        {
-            type Output = Self;
-            #[inline]
-            #[must_use = "operator returns a new vector without mutating the inputs"]
-            fn $call(self, rhs: Self) -> Self::Output {
-                if rhs.lanes_eq(Simd::splat(0)).any() {
-                    panic!("attempt to calculate the remainder with a divisor of zero");
-                } else if <$sint>::MIN != 0 && self.lanes_eq(Simd::splat(<$sint>::MIN)) & rhs.lanes_eq(Simd::splat(-1 as _))
-                    != Mask::splat(false)
-                 {
-                    panic!("attempt to calculate the remainder with overflow");
-                } else {
-                    unsafe { $crate::intrinsics::$simd_call(self, rhs) }
-                 }
-             }
-        })*
-    };
-}
-
-macro_rules! int_arith {
-    ($(impl<const LANES: usize> IntArith for Simd<$sint:ty, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
-    })*) => {
-        $(
-        unsafe_base_op!{
-            impl<const LANES: usize> Add for Simd<$sint, LANES> {
-                fn add(self, rhs: Self) -> Self::Output {
-                    unsafe { simd_add }
-                }
-            }
-
-            impl<const LANES: usize> Mul for Simd<$sint, LANES> {
-                fn mul(self, rhs: Self) -> Self::Output {
-                    unsafe { simd_mul }
-                }
-            }
 
-            impl<const LANES: usize> Sub for Simd<$sint, LANES> {
-                fn sub(self, rhs: Self) -> Self::Output {
-                    unsafe { simd_sub }
-                }
-            }
+    impl Shr::shr {
+        wrap_bitshift {
+            // This automatically monomorphizes to lshr or ashr, depending,
+            // so it's fine to use it for both UInts and SInts.
+            simd_shr
         }
-
-        int_divrem_guard!{
-            impl<const LANES: usize> Div for Simd<$sint, LANES> {
-                const PANIC_ZERO: &'static str = "attempt to divide by zero";
-                const PANIC_OVERFLOW: &'static str = "attempt to divide with overflow";
-                fn div {
-                    unsafe { simd_div }
-                }
-            }
-
-            impl<const LANES: usize> Rem for Simd<$sint, LANES> {
-                const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero";
-                const PANIC_OVERFLOW: &'static str = "attempt to calculate the remainder with overflow";
-                fn rem {
-                    unsafe { simd_rem }
-                }
-            }
-        })*
     }
 }
 
-int_arith! {
-    impl<const LANES: usize> IntArith for Simd<i8, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
-    }
-
-    impl<const LANES: usize> IntArith for Simd<i16, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
-    }
-
-    impl<const LANES: usize> IntArith for Simd<i32, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
-    }
-
-    impl<const LANES: usize> IntArith for Simd<i64, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
-    }
-
-    impl<const LANES: usize> IntArith for Simd<isize, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
-    }
+// We don't need any special precautions here:
+// Floats always accept arithmetic ops, but may become NaN.
+for_base_ops! {
+    T = (f32, f64);
+    type Lhs = Simd<T, N>;
+    type Rhs = Simd<T, N>;
+    type Output = Self;
 
-    impl<const LANES: usize> IntArith for Simd<u8, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
+    impl Add::add {
+        unsafe_base { simd_add }
     }
 
-    impl<const LANES: usize> IntArith for Simd<u16, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
+    impl Mul::mul {
+        unsafe_base { simd_mul }
     }
 
-    impl<const LANES: usize> IntArith for Simd<u32, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
+    impl Sub::sub {
+        unsafe_base { simd_sub }
     }
 
-    impl<const LANES: usize> IntArith for Simd<u64, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
+    impl Div::div {
+        unsafe_base { simd_div }
     }
 
-    impl<const LANES: usize> IntArith for Simd<usize, LANES> {
-        fn add(self, rhs: Self) -> Self::Output;
-        fn mul(self, rhs: Self) -> Self::Output;
-        fn sub(self, rhs: Self) -> Self::Output;
-        fn div(self, rhs: Self) -> Self::Output;
-        fn rem(self, rhs: Self) -> Self::Output;
+    impl Rem::rem {
+        unsafe_base { simd_rem }
     }
 }