about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2024-11-20 19:34:47 +0000
committerGitHub <noreply@github.com>2024-11-20 19:34:47 +0000
commite6946883ec15cf651ee18d2781c8d9adc8d08d1a (patch)
treea261377046aacb8acdc61179d912b269bc72d192
parent8d3d69434945684d1b260ff8476f6e242ad177d4 (diff)
parent8a5c187948ea2db82539e91a43ea8b8e5d13029a (diff)
downloadrust-e6946883ec15cf651ee18d2781c8d9adc8d08d1a.tar.gz
rust-e6946883ec15cf651ee18d2781c8d9adc8d08d1a.zip
Merge pull request #4026 from eduardosm/soft-sqrt
miri: implement square root without relying on host floats
-rw-r--r--src/tools/miri/src/intrinsics/mod.rs42
-rw-r--r--src/tools/miri/src/intrinsics/simd.rs39
-rw-r--r--src/tools/miri/src/lib.rs1
-rw-r--r--src/tools/miri/src/math.rs164
-rw-r--r--src/tools/miri/src/shims/x86/mod.rs27
-rw-r--r--src/tools/miri/tests/pass/float.rs14
6 files changed, 219 insertions, 68 deletions
diff --git a/src/tools/miri/src/intrinsics/mod.rs b/src/tools/miri/src/intrinsics/mod.rs
index 272dca1594e..9eebbc5d363 100644
--- a/src/tools/miri/src/intrinsics/mod.rs
+++ b/src/tools/miri/src/intrinsics/mod.rs
@@ -218,20 +218,19 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
             => {
                 let [f] = check_arg_count(args)?;
                 let f = this.read_scalar(f)?.to_f32()?;
-                // Using host floats (but it's fine, these operations do not have guaranteed precision).
-                let f_host = f.to_host();
+                // Using host floats except for sqrt (but it's fine, these operations do not have
+                // guaranteed precision).
                 let res = match intrinsic_name {
-                    "sinf32" => f_host.sin(),
-                    "cosf32" => f_host.cos(),
-                    "sqrtf32" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats
-                    "expf32" => f_host.exp(),
-                    "exp2f32" => f_host.exp2(),
-                    "logf32" => f_host.ln(),
-                    "log10f32" => f_host.log10(),
-                    "log2f32" => f_host.log2(),
+                    "sinf32" => f.to_host().sin().to_soft(),
+                    "cosf32" => f.to_host().cos().to_soft(),
+                    "sqrtf32" => math::sqrt(f),
+                    "expf32" => f.to_host().exp().to_soft(),
+                    "exp2f32" => f.to_host().exp2().to_soft(),
+                    "logf32" => f.to_host().ln().to_soft(),
+                    "log10f32" => f.to_host().log10().to_soft(),
+                    "log2f32" => f.to_host().log2().to_soft(),
                     _ => bug!(),
                 };
-                let res = res.to_soft();
                 let res = this.adjust_nan(res, &[f]);
                 this.write_scalar(res, dest)?;
             }
@@ -247,20 +246,19 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
             => {
                 let [f] = check_arg_count(args)?;
                 let f = this.read_scalar(f)?.to_f64()?;
-                // Using host floats (but it's fine, these operations do not have guaranteed precision).
-                let f_host = f.to_host();
+                // Using host floats except for sqrt (but it's fine, these operations do not have
+                // guaranteed precision).
                 let res = match intrinsic_name {
-                    "sinf64" => f_host.sin(),
-                    "cosf64" => f_host.cos(),
-                    "sqrtf64" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats
-                    "expf64" => f_host.exp(),
-                    "exp2f64" => f_host.exp2(),
-                    "logf64" => f_host.ln(),
-                    "log10f64" => f_host.log10(),
-                    "log2f64" => f_host.log2(),
+                    "sinf64" => f.to_host().sin().to_soft(),
+                    "cosf64" => f.to_host().cos().to_soft(),
+                    "sqrtf64" => math::sqrt(f),
+                    "expf64" => f.to_host().exp().to_soft(),
+                    "exp2f64" => f.to_host().exp2().to_soft(),
+                    "logf64" => f.to_host().ln().to_soft(),
+                    "log10f64" => f.to_host().log10().to_soft(),
+                    "log2f64" => f.to_host().log2().to_soft(),
                     _ => bug!(),
                 };
-                let res = res.to_soft();
                 let res = this.adjust_nan(res, &[f]);
                 this.write_scalar(res, dest)?;
             }
diff --git a/src/tools/miri/src/intrinsics/simd.rs b/src/tools/miri/src/intrinsics/simd.rs
index d5c417e7231..075b6f35e0e 100644
--- a/src/tools/miri/src/intrinsics/simd.rs
+++ b/src/tools/miri/src/intrinsics/simd.rs
@@ -104,42 +104,39 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                             let ty::Float(float_ty) = op.layout.ty.kind() else {
                                 span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
                             };
-                            // Using host floats (but it's fine, these operations do not have guaranteed precision).
+                            // Using host floats except for sqrt (but it's fine, these operations do not
+                            // have guaranteed precision).
                             match float_ty {
                                 FloatTy::F16 => unimplemented!("f16_f128"),
                                 FloatTy::F32 => {
                                     let f = op.to_scalar().to_f32()?;
-                                    let f_host = f.to_host();
                                     let res = match host_op {
-                                        "fsqrt" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats
-                                        "fsin" => f_host.sin(),
-                                        "fcos" => f_host.cos(),
-                                        "fexp" => f_host.exp(),
-                                        "fexp2" => f_host.exp2(),
-                                        "flog" => f_host.ln(),
-                                        "flog2" => f_host.log2(),
-                                        "flog10" => f_host.log10(),
+                                        "fsqrt" => math::sqrt(f),
+                                        "fsin" => f.to_host().sin().to_soft(),
+                                        "fcos" => f.to_host().cos().to_soft(),
+                                        "fexp" => f.to_host().exp().to_soft(),
+                                        "fexp2" => f.to_host().exp2().to_soft(),
+                                        "flog" => f.to_host().ln().to_soft(),
+                                        "flog2" => f.to_host().log2().to_soft(),
+                                        "flog10" => f.to_host().log10().to_soft(),
                                         _ => bug!(),
                                     };
-                                    let res = res.to_soft();
                                     let res = this.adjust_nan(res, &[f]);
                                     Scalar::from(res)
                                 }
                                 FloatTy::F64 => {
                                     let f = op.to_scalar().to_f64()?;
-                                    let f_host = f.to_host();
                                     let res = match host_op {
-                                        "fsqrt" => f_host.sqrt(),
-                                        "fsin" => f_host.sin(),
-                                        "fcos" => f_host.cos(),
-                                        "fexp" => f_host.exp(),
-                                        "fexp2" => f_host.exp2(),
-                                        "flog" => f_host.ln(),
-                                        "flog2" => f_host.log2(),
-                                        "flog10" => f_host.log10(),
+                                        "fsqrt" => math::sqrt(f),
+                                        "fsin" => f.to_host().sin().to_soft(),
+                                        "fcos" => f.to_host().cos().to_soft(),
+                                        "fexp" => f.to_host().exp().to_soft(),
+                                        "fexp2" => f.to_host().exp2().to_soft(),
+                                        "flog" => f.to_host().ln().to_soft(),
+                                        "flog2" => f.to_host().log2().to_soft(),
+                                        "flog10" => f.to_host().log10().to_soft(),
                                         _ => bug!(),
                                     };
-                                    let res = res.to_soft();
                                     let res = this.adjust_nan(res, &[f]);
                                     Scalar::from(res)
                                 }
diff --git a/src/tools/miri/src/lib.rs b/src/tools/miri/src/lib.rs
index e69651631bb..85c896563da 100644
--- a/src/tools/miri/src/lib.rs
+++ b/src/tools/miri/src/lib.rs
@@ -83,6 +83,7 @@ mod eval;
 mod helpers;
 mod intrinsics;
 mod machine;
+mod math;
 mod mono_hash_map;
 mod operator;
 mod provenance_gc;
diff --git a/src/tools/miri/src/math.rs b/src/tools/miri/src/math.rs
new file mode 100644
index 00000000000..ed3d2d55678
--- /dev/null
+++ b/src/tools/miri/src/math.rs
@@ -0,0 +1,164 @@
+use rand::Rng as _;
+use rand::distributions::Distribution as _;
+use rustc_apfloat::Float as _;
+use rustc_apfloat::ieee::IeeeFloat;
+
+/// Disturbes a floating-point result by a relative error on the order of (-2^scale, 2^scale).
+pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
+    ecx: &mut crate::MiriInterpCx<'_>,
+    val: F,
+    err_scale: i32,
+) -> F {
+    let rng = ecx.machine.rng.get_mut();
+    // Generate a random integer in the range [0, 2^PREC).
+    let dist = rand::distributions::Uniform::new(0, 1 << F::PRECISION);
+    let err = F::from_u128(dist.sample(rng))
+        .value
+        .scalbn(err_scale.strict_sub(F::PRECISION.try_into().unwrap()));
+    // give it a random sign
+    let err = if rng.gen::<bool>() { -err } else { err };
+    // multiple the value with (1+err)
+    (val * (F::from_u128(1).value + err).value).value
+}
+
+pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
+    match x.category() {
+        // preserve zero sign
+        rustc_apfloat::Category::Zero => x,
+        // propagate NaN
+        rustc_apfloat::Category::NaN => x,
+        // sqrt of negative number is NaN
+        _ if x.is_negative() => IeeeFloat::NAN,
+        // sqrt(∞) = ∞
+        rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
+        rustc_apfloat::Category::Normal => {
+            // Floating point precision, excluding the integer bit
+            let prec = i32::try_from(S::PRECISION).unwrap() - 1;
+
+            // x = 2^(exp - prec) * mant
+            // where mant is an integer with prec+1 bits
+            // mant is a u128, which should be large enough for the largest prec (112 for f128)
+            let mut exp = x.ilogb();
+            let mut mant = x.scalbn(prec - exp).to_u128(128).value;
+
+            if exp % 2 != 0 {
+                // Make exponent even, so it can be divided by 2
+                exp -= 1;
+                mant <<= 1;
+            }
+
+            // Bit-by-bit (base-2 digit-by-digit) sqrt of mant.
+            // mant is treated here as a fixed point number with prec fractional bits.
+            // mant will be shifted left by one bit to have an extra fractional bit, which
+            // will be used to determine the rounding direction.
+
+            // res is the truncated sqrt of mant, where one bit is added at each iteration.
+            let mut res = 0u128;
+            // rem is the remainder with the current res
+            // rem_i = 2^i * ((mant<<1) - res_i^2)
+            // starting with res = 0, rem = mant<<1
+            let mut rem = mant << 1;
+            // s_i = 2*res_i
+            let mut s = 0u128;
+            // d is used to iterate over bits, from high to low (d_i = 2^(-i))
+            let mut d = 1u128 << (prec + 1);
+
+            // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that
+            //  (res_i + b_j * 2^(-j))^2 <= mant<<1
+            // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b:
+            //  res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1
+            // And rearranging the terms:
+            //  b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2)
+            //  b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i
+
+            while d != 0 {
+                // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1:
+                // t = 2*res_i + 2^(-j)
+                let t = s + d;
+                if rem >= t {
+                    // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem
+                    res += d;
+                    s += d + d;
+                    rem -= t;
+                }
+                // Adjust rem for next iteration
+                rem <<= 1;
+                // Shift iterator
+                d >>= 1;
+            }
+
+            // Remove extra fractional bit from result, rounding to nearest.
+            // If the last bit is 0, then the nearest neighbor is definitely the lower one.
+            // If the last bit is 1, it sounds like this may either be a tie (if there's
+            // infinitely many 0s after this 1), or the nearest neighbor is the upper one.
+            // However, since square roots are either exact or irrational, and an exact root
+            // would lead to the last "extra" bit being 0, we can exclude a tie in this case.
+            // We therefore always round up if the last bit is 1. When the last bit is 0,
+            // adding 1 will not do anything since the shift will discard it.
+            res = (res + 1) >> 1;
+
+            // Build resulting value with res as mantissa and exp/2 as exponent
+            IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, QuadS, SingleS};
+
+    use super::sqrt;
+
+    #[test]
+    fn test_sqrt() {
+        #[track_caller]
+        fn test<S: rustc_apfloat::ieee::Semantics>(x: &str, expected: &str) {
+            let x: IeeeFloat<S> = x.parse().unwrap();
+            let expected: IeeeFloat<S> = expected.parse().unwrap();
+            let result = sqrt(x);
+            assert_eq!(result, expected);
+        }
+
+        fn exact_tests<S: rustc_apfloat::ieee::Semantics>() {
+            test::<S>("0", "0");
+            test::<S>("1", "1");
+            test::<S>("1.5625", "1.25");
+            test::<S>("2.25", "1.5");
+            test::<S>("4", "2");
+            test::<S>("5.0625", "2.25");
+            test::<S>("9", "3");
+            test::<S>("16", "4");
+            test::<S>("25", "5");
+            test::<S>("36", "6");
+            test::<S>("49", "7");
+            test::<S>("64", "8");
+            test::<S>("81", "9");
+            test::<S>("100", "10");
+
+            test::<S>("0.5625", "0.75");
+            test::<S>("0.25", "0.5");
+            test::<S>("0.0625", "0.25");
+            test::<S>("0.00390625", "0.0625");
+        }
+
+        exact_tests::<HalfS>();
+        exact_tests::<SingleS>();
+        exact_tests::<DoubleS>();
+        exact_tests::<QuadS>();
+
+        test::<SingleS>("2", "1.4142135");
+        test::<DoubleS>("2", "1.4142135623730951");
+
+        test::<SingleS>("1.1", "1.0488088");
+        test::<DoubleS>("1.1", "1.0488088481701516");
+
+        test::<SingleS>("2.2", "1.4832398");
+        test::<DoubleS>("2.2", "1.4832396974191326");
+
+        test::<SingleS>("1.22101e-40", "1.10499205e-20");
+        test::<DoubleS>("1.22101e-310", "1.1049932126488395e-155");
+
+        test::<SingleS>("3.4028235e38", "1.8446743e19");
+        test::<DoubleS>("1.7976931348623157e308", "1.3407807929942596e154");
+    }
+}
diff --git a/src/tools/miri/src/shims/x86/mod.rs b/src/tools/miri/src/shims/x86/mod.rs
index 66c8f3b4c2b..3e02a4b3637 100644
--- a/src/tools/miri/src/shims/x86/mod.rs
+++ b/src/tools/miri/src/shims/x86/mod.rs
@@ -1,4 +1,3 @@
-use rand::Rng as _;
 use rustc_abi::{ExternAbi, Size};
 use rustc_apfloat::Float;
 use rustc_apfloat::ieee::Single;
@@ -408,38 +407,20 @@ fn unary_op_f32<'tcx>(
             let div = (Single::from_u128(1).value / op).value;
             // Apply a relative error with a magnitude on the order of 2^-12 to simulate the
             // inaccuracy of RCP.
-            let res = apply_random_float_error(ecx, div, -12);
+            let res = math::apply_random_float_error(ecx, div, -12);
             interp_ok(Scalar::from_f32(res))
         }
         FloatUnaryOp::Rsqrt => {
-            let op = op.to_scalar().to_u32()?;
-            // FIXME using host floats
-            let sqrt = Single::from_bits(f32::from_bits(op).sqrt().to_bits().into());
-            let rsqrt = (Single::from_u128(1).value / sqrt).value;
+            let op = op.to_scalar().to_f32()?;
+            let rsqrt = (Single::from_u128(1).value / math::sqrt(op)).value;
             // Apply a relative error with a magnitude on the order of 2^-12 to simulate the
             // inaccuracy of RSQRT.
-            let res = apply_random_float_error(ecx, rsqrt, -12);
+            let res = math::apply_random_float_error(ecx, rsqrt, -12);
             interp_ok(Scalar::from_f32(res))
         }
     }
 }
 
-/// Disturbes a floating-point result by a relative error on the order of (-2^scale, 2^scale).
-#[expect(clippy::arithmetic_side_effects)] // floating point arithmetic cannot panic
-fn apply_random_float_error<F: rustc_apfloat::Float>(
-    ecx: &mut crate::MiriInterpCx<'_>,
-    val: F,
-    err_scale: i32,
-) -> F {
-    let rng = ecx.machine.rng.get_mut();
-    // generates rand(0, 2^64) * 2^(scale - 64) = rand(0, 1) * 2^scale
-    let err = F::from_u128(rng.gen::<u64>().into()).value.scalbn(err_scale.strict_sub(64));
-    // give it a random sign
-    let err = if rng.gen::<bool>() { -err } else { err };
-    // multiple the value with (1+err)
-    (val * (F::from_u128(1).value + err).value).value
-}
-
 /// Performs `which` operation on the first component of `op` and copies
 /// the other components. The result is stored in `dest`.
 fn unary_op_ss<'tcx>(
diff --git a/src/tools/miri/tests/pass/float.rs b/src/tools/miri/tests/pass/float.rs
index 66843ca584b..4de315e3589 100644
--- a/src/tools/miri/tests/pass/float.rs
+++ b/src/tools/miri/tests/pass/float.rs
@@ -959,10 +959,20 @@ pub fn libm() {
         unsafe { ldexp(a, b) }
     }
 
-    assert_approx_eq!(64f32.sqrt(), 8f32);
-    assert_approx_eq!(64f64.sqrt(), 8f64);
+    assert_eq!(64_f32.sqrt(), 8_f32);
+    assert_eq!(64_f64.sqrt(), 8_f64);
+    assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY);
+    assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY);
+    assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
+    assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
+    assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
+    assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
     assert!((-5.0_f32).sqrt().is_nan());
     assert!((-5.0_f64).sqrt().is_nan());
+    assert!(f32::NEG_INFINITY.sqrt().is_nan());
+    assert!(f64::NEG_INFINITY.sqrt().is_nan());
+    assert!(f32::NAN.sqrt().is_nan());
+    assert!(f64::NAN.sqrt().is_nan());
 
     assert_approx_eq!(25f32.powi(-2), 0.0016f32);
     assert_approx_eq!(23.2f64.powi(2), 538.24f64);