about summary refs log tree commit diff
diff options
context:
space:
mode:
authorCaleb Zulawski <caleb.zulawski@gmail.com>2021-06-13 19:59:17 +0000
committerCaleb Zulawski <caleb.zulawski@gmail.com>2021-06-13 19:59:17 +0000
commitf102de7c8b2f59bcdc8f27dfe42a94725c91fd36 (patch)
treecea2ddd8fbc42b50ef126ec0705e22870b5e3c1b
parent74e6262ce4ad8efb8d0addd461fdf9d25bea9538 (diff)
downloadrust-f102de7c8b2f59bcdc8f27dfe42a94725c91fd36.tar.gz
rust-f102de7c8b2f59bcdc8f27dfe42a94725c91fd36.zip
Add mul_add
-rw-r--r--crates/core_simd/src/intrinsics.rs3
-rw-r--r--crates/core_simd/src/vector/float.rs12
-rw-r--r--crates/core_simd/tests/ops_macros.rs8
-rw-r--r--crates/test_helpers/src/lib.rs41
4 files changed, 64 insertions, 0 deletions
diff --git a/crates/core_simd/src/intrinsics.rs b/crates/core_simd/src/intrinsics.rs
index 7adf4c24e10..3983beb82ec 100644
--- a/crates/core_simd/src/intrinsics.rs
+++ b/crates/core_simd/src/intrinsics.rs
@@ -49,6 +49,9 @@ extern "platform-intrinsic" {
     /// fsqrt
     pub(crate) fn simd_fsqrt<T>(x: T) -> T;
 
+    /// fma
+    pub(crate) fn simd_fma<T>(x: T, y: T, z: T) -> T;
+
     pub(crate) fn simd_eq<T, U>(x: T, y: T) -> U;
     pub(crate) fn simd_ne<T, U>(x: T, y: T) -> U;
     pub(crate) fn simd_lt<T, U>(x: T, y: T) -> U;
diff --git a/crates/core_simd/src/vector/float.rs b/crates/core_simd/src/vector/float.rs
index 7061b9b0674..4f0888f29f9 100644
--- a/crates/core_simd/src/vector/float.rs
+++ b/crates/core_simd/src/vector/float.rs
@@ -36,6 +36,18 @@ macro_rules! impl_float_vector {
                 unsafe { crate::intrinsics::simd_fabs(self) }
             }
 
+            /// Fused multiply-add.  Computes `(self * a) + b` with only one rounding error,
+            /// yielding a more accurate result than an unfused multiply-add.
+            ///
+            /// Using `mul_add` *may* be more performant than an unfused multiply-add if the target
+            /// architecture has a dedicated `fma` CPU instruction.  However, this is not always
+            /// true, and will be heavily dependent on designing algorithms with specific target
+            /// hardware in mind.
+            #[inline]
+            pub fn mul_add(self, a: Self, b: Self) -> Self {
+                unsafe { crate::intrinsics::simd_fma(self, a, b) }
+            }
+
             /// Produces a vector where every lane has the square root value
             /// of the equivalently-indexed lane in `self`
             #[inline]
diff --git a/crates/core_simd/tests/ops_macros.rs b/crates/core_simd/tests/ops_macros.rs
index 8ef2edc8370..4057f33d447 100644
--- a/crates/core_simd/tests/ops_macros.rs
+++ b/crates/core_simd/tests/ops_macros.rs
@@ -435,6 +435,14 @@ macro_rules! impl_float_tests {
                     )
                 }
 
+                fn mul_add<const LANES: usize>() {
+                    test_helpers::test_ternary_elementwise(
+                        &Vector::<LANES>::mul_add,
+                        &Scalar::mul_add,
+                        &|_, _, _| true,
+                    )
+                }
+
                 fn sqrt<const LANES: usize>() {
                     test_helpers::test_unary_elementwise(
                         &Vector::<LANES>::sqrt,
diff --git a/crates/test_helpers/src/lib.rs b/crates/test_helpers/src/lib.rs
index ff6d30a1afb..4f2380b8e5b 100644
--- a/crates/test_helpers/src/lib.rs
+++ b/crates/test_helpers/src/lib.rs
@@ -278,6 +278,47 @@ pub fn test_binary_scalar_lhs_elementwise<
     });
 }
 
+/// Test a ternary vector function against a ternary scalar function, applied elementwise.
+#[inline(never)]
+pub fn test_ternary_elementwise<
+    Scalar1,
+    Scalar2,
+    Scalar3,
+    ScalarResult,
+    Vector1,
+    Vector2,
+    Vector3,
+    VectorResult,
+    const LANES: usize,
+>(
+    fv: &dyn Fn(Vector1, Vector2, Vector3) -> VectorResult,
+    fs: &dyn Fn(Scalar1, Scalar2, Scalar3) -> ScalarResult,
+    check: &dyn Fn([Scalar1; LANES], [Scalar2; LANES], [Scalar3; LANES]) -> bool,
+) where
+    Scalar1: Copy + Default + core::fmt::Debug + DefaultStrategy,
+    Scalar2: Copy + Default + core::fmt::Debug + DefaultStrategy,
+    Scalar3: Copy + Default + core::fmt::Debug + DefaultStrategy,
+    ScalarResult: Copy + Default + biteq::BitEq + core::fmt::Debug + DefaultStrategy,
+    Vector1: Into<[Scalar1; LANES]> + From<[Scalar1; LANES]> + Copy,
+    Vector2: Into<[Scalar2; LANES]> + From<[Scalar2; LANES]> + Copy,
+    Vector3: Into<[Scalar3; LANES]> + From<[Scalar3; LANES]> + Copy,
+    VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy,
+{
+    test_3(&|x: [Scalar1; LANES], y: [Scalar2; LANES], z: [Scalar3; LANES]| {
+        proptest::prop_assume!(check(x, y, z));
+        let result_1: [ScalarResult; LANES] = fv(x.into(), y.into(), z.into()).into();
+        let result_2: [ScalarResult; LANES] = {
+            let mut result = [ScalarResult::default(); LANES];
+            for ((i1, (i2, i3)), o) in x.iter().zip(y.iter().zip(z.iter())).zip(result.iter_mut()) {
+                *o = fs(*i1, *i2, *i3);
+            }
+            result
+        };
+        crate::prop_assert_biteq!(result_1, result_2);
+        Ok(())
+    });
+}
+
 /// Expand a const-generic test into separate tests for each possible lane count.
 #[macro_export]
 macro_rules! test_lanes {