about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJubilee <46493976+workingjubilee@users.noreply.github.com>2024-06-02 12:58:07 -0700
committerGitHub <noreply@github.com>2024-06-02 12:58:07 -0700
commit713cdcd8036fe6a8639100efc93a31a58756356e (patch)
treeb273b0cbe9cf00a3f0d8645a4bee2072b09a7f22
parenteda9d7f987de76b9d61c633a6ac328936e1b94f0 (diff)
parent849c5254afdeb96a740cfac989a382fde69b6067 (diff)
downloadrust-713cdcd8036fe6a8639100efc93a31a58756356e.tar.gz
rust-713cdcd8036fe6a8639100efc93a31a58756356e.zip
Rollup merge of #121062 - RustyYato:f32-midpoint, r=the8472
Change f32::midpoint to upcast to f64

This has been verified by kani as a correct optimization

see: https://github.com/rust-lang/rust/issues/110840#issuecomment-1942587398

The new implementation is branchless and only differs in which NaN values are produced (if any are produced at all), which is fine to change. Aside from NaN handling, this implementation produces bitwise identical results to the original implementation.

Question: do we need a codegen test for this? I didn't add one, since the original PR #92048 didn't have any codegen tests.
-rw-r--r--library/core/src/num/f32.rs55
-rw-r--r--library/core/tests/num/mod.rs29
2 files changed, 62 insertions, 22 deletions
diff --git a/library/core/src/num/f32.rs b/library/core/src/num/f32.rs
index 19fc4489618..22b24937cbc 100644
--- a/library/core/src/num/f32.rs
+++ b/library/core/src/num/f32.rs
@@ -1030,25 +1030,42 @@ impl f32 {
     /// ```
     #[unstable(feature = "num_midpoint", issue = "110840")]
     pub fn midpoint(self, other: f32) -> f32 {
-        const LO: f32 = f32::MIN_POSITIVE * 2.;
-        const HI: f32 = f32::MAX / 2.;
-
-        let (a, b) = (self, other);
-        let abs_a = a.abs_private();
-        let abs_b = b.abs_private();
-
-        if abs_a <= HI && abs_b <= HI {
-            // Overflow is impossible
-            (a + b) / 2.
-        } else if abs_a < LO {
-            // Not safe to halve a
-            a + (b / 2.)
-        } else if abs_b < LO {
-            // Not safe to halve b
-            (a / 2.) + b
-        } else {
-            // Not safe to halve a and b
-            (a / 2.) + (b / 2.)
+        cfg_if! {
+            if #[cfg(any(
+                    target_arch = "x86_64",
+                    target_arch = "aarch64",
+                    all(any(target_arch="riscv32", target_arch= "riscv64"), target_feature="d"),
+                    all(target_arch = "arm", target_feature="vfp2"),
+                    target_arch = "wasm32",
+                    target_arch = "wasm64",
+                ))] {
+                // whitelist the faster implementation to targets that have known good 64-bit float
+                // implementations. Falling back to the branchy code on targets that don't have
+                // 64-bit hardware floats or buggy implementations.
+                // see: https://github.com/rust-lang/rust/pull/121062#issuecomment-2123408114
+                ((f64::from(self) + f64::from(other)) / 2.0) as f32
+            } else {
+                const LO: f32 = f32::MIN_POSITIVE * 2.;
+                const HI: f32 = f32::MAX / 2.;
+
+                let (a, b) = (self, other);
+                let abs_a = a.abs_private();
+                let abs_b = b.abs_private();
+
+                if abs_a <= HI && abs_b <= HI {
+                    // Overflow is impossible
+                    (a + b) / 2.
+                } else if abs_a < LO {
+                    // Not safe to halve a
+                    a + (b / 2.)
+                } else if abs_b < LO {
+                    // Not safe to halve b
+                    (a / 2.) + b
+                } else {
+                    // Not safe to halve a and b
+                    (a / 2.) + (b / 2.)
+                }
+            }
         }
     }
 
diff --git a/library/core/tests/num/mod.rs b/library/core/tests/num/mod.rs
index 0fed854318d..9d2912c4b22 100644
--- a/library/core/tests/num/mod.rs
+++ b/library/core/tests/num/mod.rs
@@ -729,7 +729,7 @@ assume_usize_width! {
 }
 
 macro_rules! test_float {
-    ($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr) => {
+    ($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr, $max_exp:expr) => {
         mod $modname {
             #[test]
             fn min() {
@@ -880,6 +880,27 @@ macro_rules! test_float {
                 assert!(($nan as $fty).midpoint(1.0).is_nan());
                 assert!((1.0 as $fty).midpoint($nan).is_nan());
                 assert!(($nan as $fty).midpoint($nan).is_nan());
+
+                // test if large differences in magnitude are still correctly computed.
+                // NOTE: that because of how small x and y are, x + y can never overflow
+                // so (x + y) / 2.0 is always correct
+                // in particular, `2.pow(i)` will  never be at the max exponent, so it could
+                // be safely doubled, while j is significantly smaller.
+                for i in $max_exp.saturating_sub(64)..$max_exp {
+                    for j in 0..64u8 {
+                        let large = <$fty>::from(2.0f32).powi(i);
+                        // a much smaller number, such that there is no chance of overflow to test
+                        // potential double rounding in midpoint's implementation.
+                        let small = <$fty>::from(2.0f32).powi($max_exp - 1)
+                            * <$fty>::EPSILON
+                            * <$fty>::from(j);
+
+                        let naive = (large + small) / 2.0;
+                        let midpoint = large.midpoint(small);
+
+                        assert_eq!(naive, midpoint);
+                    }
+                }
             }
             #[test]
             fn rem_euclid() {
@@ -912,7 +933,8 @@ test_float!(
     f32::NAN,
     f32::MIN,
     f32::MAX,
-    f32::MIN_POSITIVE
+    f32::MIN_POSITIVE,
+    f32::MAX_EXP
 );
 test_float!(
     f64,
@@ -922,5 +944,6 @@ test_float!(
     f64::NAN,
     f64::MIN,
     f64::MAX,
-    f64::MIN_POSITIVE
+    f64::MIN_POSITIVE,
+    f64::MAX_EXP
 );