about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/core_simd/src/ops.rs1
-rw-r--r--crates/core_simd/src/ops/shift_scalar.rs58
2 files changed, 59 insertions, 0 deletions
diff --git a/crates/core_simd/src/ops.rs b/crates/core_simd/src/ops.rs
index b007456cf2c..63a96106283 100644
--- a/crates/core_simd/src/ops.rs
+++ b/crates/core_simd/src/ops.rs
@@ -6,6 +6,7 @@ use core::ops::{Shl, Shr};
 
 mod assign;
 mod deref;
+mod shift_scalar;
 mod unary;
 
 impl<I, T, const LANES: usize> core::ops::Index<I> for Simd<T, LANES>
diff --git a/crates/core_simd/src/ops/shift_scalar.rs b/crates/core_simd/src/ops/shift_scalar.rs
new file mode 100644
index 00000000000..77aac656395
--- /dev/null
+++ b/crates/core_simd/src/ops/shift_scalar.rs
@@ -0,0 +1,58 @@
+// Shift operations uniquely typically only have a scalar on the right-hand side.
+// Here, we implement shifts for scalar RHS arguments.
+
+use crate::simd::{LaneCount, Simd, SupportedLaneCount};
+
+macro_rules! impl_splatted_shifts {
+    { impl $trait:ident :: $trait_fn:ident for $ty:ty } => {
+        impl<const N: usize> core::ops::$trait<$ty> for Simd<$ty, N>
+        where
+            LaneCount<N>: SupportedLaneCount,
+        {
+            type Output = Self;
+            fn $trait_fn(self, rhs: $ty) -> Self::Output {
+                self.$trait_fn(Simd::splat(rhs))
+            }
+        }
+
+        impl<const N: usize> core::ops::$trait<&$ty> for Simd<$ty, N>
+        where
+            LaneCount<N>: SupportedLaneCount,
+        {
+            type Output = Self;
+            fn $trait_fn(self, rhs: &$ty) -> Self::Output {
+                self.$trait_fn(Simd::splat(*rhs))
+            }
+        }
+
+        impl<'lhs, const N: usize> core::ops::$trait<$ty> for &'lhs Simd<$ty, N>
+        where
+            LaneCount<N>: SupportedLaneCount,
+        {
+            type Output = Simd<$ty, N>;
+            fn $trait_fn(self, rhs: $ty) -> Self::Output {
+                self.$trait_fn(Simd::splat(rhs))
+            }
+        }
+
+        impl<'lhs, const N: usize> core::ops::$trait<&$ty> for &'lhs Simd<$ty, N>
+        where
+            LaneCount<N>: SupportedLaneCount,
+        {
+            type Output = Simd<$ty, N>;
+            fn $trait_fn(self, rhs: &$ty) -> Self::Output {
+                self.$trait_fn(Simd::splat(*rhs))
+            }
+        }
+    };
+    { $($ty:ty),* } => {
+        $(
+        impl_splatted_shifts! { impl Shl::shl for $ty }
+        impl_splatted_shifts! { impl Shr::shr for $ty }
+        )*
+    }
+}
+
+// In the past there were inference issues when generically splatting arguments.
+// Enumerate them instead.
+impl_splatted_shifts! { i8, i16, i32, i64, isize, u8, u16, u32, u64, usize }