about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2025-09-18 20:47:26 +0000
committerGitHub <noreply@github.com>2025-09-18 20:47:26 +0000
commit11ffdb54f63c8eac2322a3ac1a3e4e4b8b5a99a6 (patch)
tree464e5e7c25ed258b9d69396be5555a4c040fd275 /src
parent045e5e3586375db8464c5cf88e4ea89fbcefe60f (diff)
parent77f2d865549e8f2e1ad656c8202de9bfa6d2e354 (diff)
downloadrust-11ffdb54f63c8eac2322a3ac1a3e4e4b8b5a99a6.tar.gz
rust-11ffdb54f63c8eac2322a3ac1a3e4e4b8b5a99a6.zip
Merge pull request #4592 from RalfJung/sqrt
implement sqrt for f16 and f128
Diffstat (limited to 'src')
-rw-r--r--src/tools/miri/src/intrinsics/math.rs36
-rw-r--r--src/tools/miri/src/math.rs12
-rw-r--r--src/tools/miri/tests/pass/float.rs44
3 files changed, 54 insertions, 38 deletions
diff --git a/src/tools/miri/src/intrinsics/math.rs b/src/tools/miri/src/intrinsics/math.rs
index 21d4a92e7d2..b9c99f28594 100644
--- a/src/tools/miri/src/intrinsics/math.rs
+++ b/src/tools/miri/src/intrinsics/math.rs
@@ -1,5 +1,5 @@
 use rand::Rng;
-use rustc_apfloat::{self, Float, Round};
+use rustc_apfloat::{self, Float, FloatConvert, Round};
 use rustc_middle::mir;
 use rustc_middle::ty::{self, FloatTy};
 
@@ -7,6 +7,20 @@ use self::helpers::{ToHost, ToSoft};
 use super::check_intrinsic_arg_count;
 use crate::*;
 
+fn sqrt<'tcx, F: Float + FloatConvert<F> + Into<Scalar>>(
+    this: &mut MiriInterpCx<'tcx>,
+    args: &[OpTy<'tcx>],
+    dest: &MPlaceTy<'tcx>,
+) -> InterpResult<'tcx> {
+    let [f] = check_intrinsic_arg_count(args)?;
+    let f = this.read_scalar(f)?;
+    let f: F = f.to_float()?;
+    // Sqrt is specified to be fully precise.
+    let res = math::sqrt(f);
+    let res = this.adjust_nan(res, &[f]);
+    this.write_scalar(res, dest)
+}
+
 impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
 pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
     fn emulate_math_intrinsic(
@@ -20,22 +34,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
 
         match intrinsic_name {
             // Operations we can do with soft-floats.
-            "sqrtf32" => {
-                let [f] = check_intrinsic_arg_count(args)?;
-                let f = this.read_scalar(f)?.to_f32()?;
-                // Sqrt is specified to be fully precise.
-                let res = math::sqrt(f);
-                let res = this.adjust_nan(res, &[f]);
-                this.write_scalar(res, dest)?;
-            }
-            "sqrtf64" => {
-                let [f] = check_intrinsic_arg_count(args)?;
-                let f = this.read_scalar(f)?.to_f64()?;
-                // Sqrt is specified to be fully precise.
-                let res = math::sqrt(f);
-                let res = this.adjust_nan(res, &[f]);
-                this.write_scalar(res, dest)?;
-            }
+            "sqrtf16" => sqrt::<rustc_apfloat::ieee::Half>(this, args, dest)?,
+            "sqrtf32" => sqrt::<rustc_apfloat::ieee::Single>(this, args, dest)?,
+            "sqrtf64" => sqrt::<rustc_apfloat::ieee::Double>(this, args, dest)?,
+            "sqrtf128" => sqrt::<rustc_apfloat::ieee::Quad>(this, args, dest)?,
 
             "fmaf32" => {
                 let [a, b, c] = check_intrinsic_arg_count(args)?;
diff --git a/src/tools/miri/src/math.rs b/src/tools/miri/src/math.rs
index 401e6dd7aab..50472ed3638 100644
--- a/src/tools/miri/src/math.rs
+++ b/src/tools/miri/src/math.rs
@@ -2,7 +2,7 @@ use std::ops::Neg;
 use std::{f32, f64};
 
 use rand::Rng as _;
-use rustc_apfloat::Float as _;
+use rustc_apfloat::Float;
 use rustc_apfloat::ieee::{DoubleS, IeeeFloat, Semantics, SingleS};
 use rustc_middle::ty::{self, FloatTy, ScalarInt};
 
@@ -317,19 +317,19 @@ where
     }
 }
 
-pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
+pub(crate) fn sqrt<F: Float>(x: F) -> F {
     match x.category() {
         // preserve zero sign
         rustc_apfloat::Category::Zero => x,
         // propagate NaN
         rustc_apfloat::Category::NaN => x,
         // sqrt of negative number is NaN
-        _ if x.is_negative() => IeeeFloat::NAN,
+        _ if x.is_negative() => F::NAN,
         // sqrt(∞) = ∞
-        rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
+        rustc_apfloat::Category::Infinity => F::INFINITY,
         rustc_apfloat::Category::Normal => {
             // Floating point precision, excluding the integer bit
-            let prec = i32::try_from(S::PRECISION).unwrap() - 1;
+            let prec = i32::try_from(F::PRECISION).unwrap() - 1;
 
             // x = 2^(exp - prec) * mant
             // where mant is an integer with prec+1 bits
@@ -394,7 +394,7 @@ pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFl
             res = (res + 1) >> 1;
 
             // Build resulting value with res as mantissa and exp/2 as exponent
-            IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
+            F::from_u128(res).value.scalbn(exp / 2 - prec)
         }
     }
 }
diff --git a/src/tools/miri/tests/pass/float.rs b/src/tools/miri/tests/pass/float.rs
index 9f1b3f612b2..3ce5ea8356b 100644
--- a/src/tools/miri/tests/pass/float.rs
+++ b/src/tools/miri/tests/pass/float.rs
@@ -281,6 +281,35 @@ fn basic() {
     assert_eq!(34.2f64.abs(), 34.2f64);
     assert_eq!((-1.0f128).abs(), 1.0f128);
     assert_eq!(34.2f128.abs(), 34.2f128);
+
+    assert_eq!(64_f16.sqrt(), 8_f16);
+    assert_eq!(64_f32.sqrt(), 8_f32);
+    assert_eq!(64_f64.sqrt(), 8_f64);
+    assert_eq!(64_f128.sqrt(), 8_f128);
+    assert_eq!(f16::INFINITY.sqrt(), f16::INFINITY);
+    assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY);
+    assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY);
+    assert_eq!(f128::INFINITY.sqrt(), f128::INFINITY);
+    assert_eq!(0.0_f16.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
+    assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
+    assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
+    assert_eq!(0.0_f128.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
+    assert_eq!((-0.0_f16).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
+    assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
+    assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
+    assert_eq!((-0.0_f128).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
+    assert!((-5.0_f16).sqrt().is_nan());
+    assert!((-5.0_f32).sqrt().is_nan());
+    assert!((-5.0_f64).sqrt().is_nan());
+    assert!((-5.0_f128).sqrt().is_nan());
+    assert!(f16::NEG_INFINITY.sqrt().is_nan());
+    assert!(f32::NEG_INFINITY.sqrt().is_nan());
+    assert!(f64::NEG_INFINITY.sqrt().is_nan());
+    assert!(f128::NEG_INFINITY.sqrt().is_nan());
+    assert!(f16::NAN.sqrt().is_nan());
+    assert!(f32::NAN.sqrt().is_nan());
+    assert!(f64::NAN.sqrt().is_nan());
+    assert!(f128::NAN.sqrt().is_nan());
 }
 
 /// Test casts from floats to ints and back
@@ -1012,21 +1041,6 @@ pub fn libm() {
         unsafe { ldexp(a, b) }
     }
 
-    assert_eq!(64_f32.sqrt(), 8_f32);
-    assert_eq!(64_f64.sqrt(), 8_f64);
-    assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY);
-    assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY);
-    assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
-    assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
-    assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
-    assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
-    assert!((-5.0_f32).sqrt().is_nan());
-    assert!((-5.0_f64).sqrt().is_nan());
-    assert!(f32::NEG_INFINITY.sqrt().is_nan());
-    assert!(f64::NEG_INFINITY.sqrt().is_nan());
-    assert!(f32::NAN.sqrt().is_nan());
-    assert!(f64::NAN.sqrt().is_nan());
-
     assert_approx_eq!(25f32.powi(-2), 0.0016f32);
     assert_approx_eq!(23.2f64.powi(2), 538.24f64);