about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--crates/core_simd/src/vector/float.rs41
-rw-r--r--crates/core_simd/tests/ops_macros.rs70
-rw-r--r--crates/test_helpers/src/lib.rs21
3 files changed, 132 insertions, 0 deletions
diff --git a/crates/core_simd/src/vector/float.rs b/crates/core_simd/src/vector/float.rs
index 5044ac57ec5..7061b9b0674 100644
--- a/crates/core_simd/src/vector/float.rs
+++ b/crates/core_simd/src/vector/float.rs
@@ -136,6 +136,47 @@ macro_rules! impl_float_vector {
                 let magnitude = self.to_bits() & !Self::splat(-0.).to_bits();
                 Self::from_bits(sign_bit | magnitude)
             }
+
+            /// Returns the minimum of each lane.
+            ///
+            /// If one of the values is `NAN`, then the other value is returned.
+            #[inline]
+            pub fn min(self, other: Self) -> Self {
+                // TODO consider using an intrinsic
+                self.is_nan().select(
+                    other,
+                    self.lanes_ge(other).select(other, self)
+                )
+            }
+
+            /// Returns the maximum of each lane.
+            ///
+            /// If one of the values is `NAN`, then the other value is returned.
+            #[inline]
+            pub fn max(self, other: Self) -> Self {
+                // TODO consider using an intrinsic
+                self.is_nan().select(
+                    other,
+                    self.lanes_le(other).select(other, self)
+                )
+            }
+
+            /// Restrict each lane to a certain interval unless it is NaN.
+            /// 
+            /// For each lane in `self`, returns the corresponding lane in `max` if the lane is
+            /// greater than `max`, and the corresponding lane in `min` if the lane is less
+            /// than `min`.  Otherwise returns the lane in `self`.
+            #[inline]
+            pub fn clamp(self, min: Self, max: Self) -> Self {
+                assert!(
+                    min.lanes_le(max).all(),
+                    "each lane in `min` must be less than or equal to the corresponding lane in `max`",
+                );
+                let mut x = self;
+                x = x.lanes_lt(min).select(min, x);
+                x = x.lanes_gt(max).select(max, x);
+                x
+            }
         }
     };
 }
diff --git a/crates/core_simd/tests/ops_macros.rs b/crates/core_simd/tests/ops_macros.rs
index 9ada95e851e..8ef2edc8370 100644
--- a/crates/core_simd/tests/ops_macros.rs
+++ b/crates/core_simd/tests/ops_macros.rs
@@ -483,6 +483,76 @@ macro_rules! impl_float_tests {
                     )
                 }
 
+                fn min<const LANES: usize>() {
+                    // Regular conditions (both values aren't zero)
+                    test_helpers::test_binary_elementwise(
+                        &Vector::<LANES>::min,
+                        &Scalar::min,
+                        // Reject the case where both values are zero with different signs
+                        &|a, b| {
+                            for (a, b) in a.iter().zip(b.iter()) {
+                                if *a == 0. && *b == 0. && a.signum() != b.signum() {
+                                    return false;
+                                }
+                            }
+                            true
+                        }
+                    );
+
+                    // Special case where both values are zero
+                    let p_zero = Vector::<LANES>::splat(0.);
+                    let n_zero = Vector::<LANES>::splat(-0.);
+                    assert!(p_zero.min(n_zero).to_array().iter().all(|x| *x == 0.));
+                    assert!(n_zero.min(p_zero).to_array().iter().all(|x| *x == 0.));
+                }
+
+                fn max<const LANES: usize>() {
+                    // Regular conditions (both values aren't zero)
+                    test_helpers::test_binary_elementwise(
+                        &Vector::<LANES>::max,
+                        &Scalar::max,
+                        // Reject the case where both values are zero with different signs
+                        &|a, b| {
+                            for (a, b) in a.iter().zip(b.iter()) {
+                                if *a == 0. && *b == 0. && a.signum() != b.signum() {
+                                    return false;
+                                }
+                            }
+                            true
+                        }
+                    );
+
+                    // Special case where both values are zero
+                    let p_zero = Vector::<LANES>::splat(0.);
+                    let n_zero = Vector::<LANES>::splat(-0.);
+                    assert!(p_zero.min(n_zero).to_array().iter().all(|x| *x == 0.));
+                    assert!(n_zero.min(p_zero).to_array().iter().all(|x| *x == 0.));
+                }
+
+                fn clamp<const LANES: usize>() {
+                    test_helpers::test_3(&|value: [Scalar; LANES], mut min: [Scalar; LANES], mut max: [Scalar; LANES]| {
+                        for (min, max) in min.iter_mut().zip(max.iter_mut()) {
+                            if max < min {
+                                core::mem::swap(min, max);
+                            }
+                            if min.is_nan() {
+                                *min = Scalar::NEG_INFINITY;
+                            }
+                            if max.is_nan() {
+                                *max = Scalar::INFINITY;
+                            }
+                        }
+
+                        let mut result_scalar = [Scalar::default(); LANES];
+                        for i in 0..LANES {
+                            result_scalar[i] = value[i].clamp(min[i], max[i]);
+                        }
+                        let result_vector = Vector::from_array(value).clamp(min.into(), max.into()).to_array();
+                        test_helpers::prop_assert_biteq!(result_scalar, result_vector);
+                        Ok(())
+                    })
+                }
+
                 fn horizontal_sum<const LANES: usize>() {
                     test_helpers::test_1(&|x| {
                         test_helpers::prop_assert_biteq! (
diff --git a/crates/test_helpers/src/lib.rs b/crates/test_helpers/src/lib.rs
index fffd088f4da..ff6d30a1afb 100644
--- a/crates/test_helpers/src/lib.rs
+++ b/crates/test_helpers/src/lib.rs
@@ -97,6 +97,27 @@ pub fn test_2<A: core::fmt::Debug + DefaultStrategy, B: core::fmt::Debug + Defau
         .unwrap();
 }
 
+/// Test a function that takes two values.
+pub fn test_3<
+    A: core::fmt::Debug + DefaultStrategy,
+    B: core::fmt::Debug + DefaultStrategy,
+    C: core::fmt::Debug + DefaultStrategy,
+>(
+    f: &dyn Fn(A, B, C) -> proptest::test_runner::TestCaseResult,
+) {
+    let mut runner = proptest::test_runner::TestRunner::default();
+    runner
+        .run(
+            &(
+                A::default_strategy(),
+                B::default_strategy(),
+                C::default_strategy(),
+            ),
+            |(a, b, c)| f(a, b, c),
+        )
+        .unwrap();
+}
+
 /// Test a unary vector function against a unary scalar function, applied elementwise.
 #[inline(never)]
 pub fn test_unary_elementwise<Scalar, ScalarResult, Vector, VectorResult, const LANES: usize>(