about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/core/src/num/f32.rs55
-rw-r--r--library/core/tests/num/mod.rs29
2 files changed, 62 insertions, 22 deletions
diff --git a/library/core/src/num/f32.rs b/library/core/src/num/f32.rs
index 19fc4489618..22b24937cbc 100644
--- a/library/core/src/num/f32.rs
+++ b/library/core/src/num/f32.rs
@@ -1030,25 +1030,42 @@ impl f32 {
     /// ```
     #[unstable(feature = "num_midpoint", issue = "110840")]
     pub fn midpoint(self, other: f32) -> f32 {
-        const LO: f32 = f32::MIN_POSITIVE * 2.;
-        const HI: f32 = f32::MAX / 2.;
-
-        let (a, b) = (self, other);
-        let abs_a = a.abs_private();
-        let abs_b = b.abs_private();
-
-        if abs_a <= HI && abs_b <= HI {
-            // Overflow is impossible
-            (a + b) / 2.
-        } else if abs_a < LO {
-            // Not safe to halve a
-            a + (b / 2.)
-        } else if abs_b < LO {
-            // Not safe to halve b
-            (a / 2.) + b
-        } else {
-            // Not safe to halve a and b
-            (a / 2.) + (b / 2.)
+        cfg_if! {
+            if #[cfg(any(
+                    target_arch = "x86_64",
+                    target_arch = "aarch64",
+                    all(any(target_arch="riscv32", target_arch= "riscv64"), target_feature="d"),
+                    all(target_arch = "arm", target_feature="vfp2"),
+                    target_arch = "wasm32",
+                    target_arch = "wasm64",
+                ))] {
+                // whitelist the faster implementation to targets that have known good 64-bit float
+                // implementations. Falling back to the branchy code on targets that don't have
+                // 64-bit hardware floats or buggy implementations.
+                // see: https://github.com/rust-lang/rust/pull/121062#issuecomment-2123408114
+                ((f64::from(self) + f64::from(other)) / 2.0) as f32
+            } else {
+                const LO: f32 = f32::MIN_POSITIVE * 2.;
+                const HI: f32 = f32::MAX / 2.;
+
+                let (a, b) = (self, other);
+                let abs_a = a.abs_private();
+                let abs_b = b.abs_private();
+
+                if abs_a <= HI && abs_b <= HI {
+                    // Overflow is impossible
+                    (a + b) / 2.
+                } else if abs_a < LO {
+                    // Not safe to halve a
+                    a + (b / 2.)
+                } else if abs_b < LO {
+                    // Not safe to halve b
+                    (a / 2.) + b
+                } else {
+                    // Not safe to halve a and b
+                    (a / 2.) + (b / 2.)
+                }
+            }
         }
     }
 
diff --git a/library/core/tests/num/mod.rs b/library/core/tests/num/mod.rs
index 0fed854318d..9d2912c4b22 100644
--- a/library/core/tests/num/mod.rs
+++ b/library/core/tests/num/mod.rs
@@ -729,7 +729,7 @@ assume_usize_width! {
 }
 
 macro_rules! test_float {
-    ($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr) => {
+    ($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr, $max_exp:expr) => {
         mod $modname {
             #[test]
             fn min() {
@@ -880,6 +880,27 @@ macro_rules! test_float {
                 assert!(($nan as $fty).midpoint(1.0).is_nan());
                 assert!((1.0 as $fty).midpoint($nan).is_nan());
                 assert!(($nan as $fty).midpoint($nan).is_nan());
+
+                // test if large differences in magnitude are still correctly computed.
+                // NOTE: that because of how small x and y are, x + y can never overflow
+                // so (x + y) / 2.0 is always correct
+                // in particular, `2.pow(i)` will  never be at the max exponent, so it could
+                // be safely doubled, while j is significantly smaller.
+                for i in $max_exp.saturating_sub(64)..$max_exp {
+                    for j in 0..64u8 {
+                        let large = <$fty>::from(2.0f32).powi(i);
+                        // a much smaller number, such that there is no chance of overflow to test
+                        // potential double rounding in midpoint's implementation.
+                        let small = <$fty>::from(2.0f32).powi($max_exp - 1)
+                            * <$fty>::EPSILON
+                            * <$fty>::from(j);
+
+                        let naive = (large + small) / 2.0;
+                        let midpoint = large.midpoint(small);
+
+                        assert_eq!(naive, midpoint);
+                    }
+                }
             }
             #[test]
             fn rem_euclid() {
@@ -912,7 +933,8 @@ test_float!(
     f32::NAN,
     f32::MIN,
     f32::MAX,
-    f32::MIN_POSITIVE
+    f32::MIN_POSITIVE,
+    f32::MAX_EXP
 );
 test_float!(
     f64,
@@ -922,5 +944,6 @@ test_float!(
     f64::NAN,
     f64::MIN,
     f64::MAX,
-    f64::MIN_POSITIVE
+    f64::MIN_POSITIVE,
+    f64::MAX_EXP
 );