diff options
| author | Ralf Jung <post@ralfj.de> | 2025-09-18 20:47:26 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-18 20:47:26 +0000 |
| commit | 11ffdb54f63c8eac2322a3ac1a3e4e4b8b5a99a6 (patch) | |
| tree | 464e5e7c25ed258b9d69396be5555a4c040fd275 /src | |
| parent | 045e5e3586375db8464c5cf88e4ea89fbcefe60f (diff) | |
| parent | 77f2d865549e8f2e1ad656c8202de9bfa6d2e354 (diff) | |
| download | rust-11ffdb54f63c8eac2322a3ac1a3e4e4b8b5a99a6.tar.gz rust-11ffdb54f63c8eac2322a3ac1a3e4e4b8b5a99a6.zip | |
Merge pull request #4592 from RalfJung/sqrt
implement sqrt for f16 and f128
Diffstat (limited to 'src')
| -rw-r--r-- | src/tools/miri/src/intrinsics/math.rs | 36 | ||||
| -rw-r--r-- | src/tools/miri/src/math.rs | 12 | ||||
| -rw-r--r-- | src/tools/miri/tests/pass/float.rs | 44 |
3 files changed, 54 insertions, 38 deletions
diff --git a/src/tools/miri/src/intrinsics/math.rs b/src/tools/miri/src/intrinsics/math.rs index 21d4a92e7d2..b9c99f28594 100644 --- a/src/tools/miri/src/intrinsics/math.rs +++ b/src/tools/miri/src/intrinsics/math.rs @@ -1,5 +1,5 @@ use rand::Rng; -use rustc_apfloat::{self, Float, Round}; +use rustc_apfloat::{self, Float, FloatConvert, Round}; use rustc_middle::mir; use rustc_middle::ty::{self, FloatTy}; @@ -7,6 +7,20 @@ use self::helpers::{ToHost, ToSoft}; use super::check_intrinsic_arg_count; use crate::*; +fn sqrt<'tcx, F: Float + FloatConvert<F> + Into<Scalar>>( + this: &mut MiriInterpCx<'tcx>, + args: &[OpTy<'tcx>], + dest: &MPlaceTy<'tcx>, +) -> InterpResult<'tcx> { + let [f] = check_intrinsic_arg_count(args)?; + let f = this.read_scalar(f)?; + let f: F = f.to_float()?; + // Sqrt is specified to be fully precise. + let res = math::sqrt(f); + let res = this.adjust_nan(res, &[f]); + this.write_scalar(res, dest) +} + impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {} pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { fn emulate_math_intrinsic( @@ -20,22 +34,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { match intrinsic_name { // Operations we can do with soft-floats. - "sqrtf32" => { - let [f] = check_intrinsic_arg_count(args)?; - let f = this.read_scalar(f)?.to_f32()?; - // Sqrt is specified to be fully precise. - let res = math::sqrt(f); - let res = this.adjust_nan(res, &[f]); - this.write_scalar(res, dest)?; - } - "sqrtf64" => { - let [f] = check_intrinsic_arg_count(args)?; - let f = this.read_scalar(f)?.to_f64()?; - // Sqrt is specified to be fully precise. - let res = math::sqrt(f); - let res = this.adjust_nan(res, &[f]); - this.write_scalar(res, dest)?; - } + "sqrtf16" => sqrt::<rustc_apfloat::ieee::Half>(this, args, dest)?, + "sqrtf32" => sqrt::<rustc_apfloat::ieee::Single>(this, args, dest)?, + "sqrtf64" => sqrt::<rustc_apfloat::ieee::Double>(this, args, dest)?, + "sqrtf128" => sqrt::<rustc_apfloat::ieee::Quad>(this, args, dest)?, "fmaf32" => { let [a, b, c] = check_intrinsic_arg_count(args)?; diff --git a/src/tools/miri/src/math.rs b/src/tools/miri/src/math.rs index 401e6dd7aab..50472ed3638 100644 --- a/src/tools/miri/src/math.rs +++ b/src/tools/miri/src/math.rs @@ -2,7 +2,7 @@ use std::ops::Neg; use std::{f32, f64}; use rand::Rng as _; -use rustc_apfloat::Float as _; +use rustc_apfloat::Float; use rustc_apfloat::ieee::{DoubleS, IeeeFloat, Semantics, SingleS}; use rustc_middle::ty::{self, FloatTy, ScalarInt}; @@ -317,19 +317,19 @@ where } } -pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> { +pub(crate) fn sqrt<F: Float>(x: F) -> F { match x.category() { // preserve zero sign rustc_apfloat::Category::Zero => x, // propagate NaN rustc_apfloat::Category::NaN => x, // sqrt of negative number is NaN - _ if x.is_negative() => IeeeFloat::NAN, + _ if x.is_negative() => F::NAN, // sqrt(∞) = ∞ - rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY, + rustc_apfloat::Category::Infinity => F::INFINITY, rustc_apfloat::Category::Normal => { // Floating point precision, excluding the integer bit - let prec = i32::try_from(S::PRECISION).unwrap() - 1; + let prec = i32::try_from(F::PRECISION).unwrap() - 1; // x = 2^(exp - prec) * mant // where mant is an integer with prec+1 bits @@ -394,7 +394,7 @@ pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFl res = (res + 1) >> 1; // Build resulting value with res as mantissa and exp/2 as exponent - IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec) + F::from_u128(res).value.scalbn(exp / 2 - prec) } } } diff --git a/src/tools/miri/tests/pass/float.rs b/src/tools/miri/tests/pass/float.rs index 9f1b3f612b2..3ce5ea8356b 100644 --- a/src/tools/miri/tests/pass/float.rs +++ b/src/tools/miri/tests/pass/float.rs @@ -281,6 +281,35 @@ fn basic() { assert_eq!(34.2f64.abs(), 34.2f64); assert_eq!((-1.0f128).abs(), 1.0f128); assert_eq!(34.2f128.abs(), 34.2f128); + + assert_eq!(64_f16.sqrt(), 8_f16); + assert_eq!(64_f32.sqrt(), 8_f32); + assert_eq!(64_f64.sqrt(), 8_f64); + assert_eq!(64_f128.sqrt(), 8_f128); + assert_eq!(f16::INFINITY.sqrt(), f16::INFINITY); + assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY); + assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY); + assert_eq!(f128::INFINITY.sqrt(), f128::INFINITY); + assert_eq!(0.0_f16.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!(0.0_f128.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f16).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f128).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert!((-5.0_f16).sqrt().is_nan()); + assert!((-5.0_f32).sqrt().is_nan()); + assert!((-5.0_f64).sqrt().is_nan()); + assert!((-5.0_f128).sqrt().is_nan()); + assert!(f16::NEG_INFINITY.sqrt().is_nan()); + assert!(f32::NEG_INFINITY.sqrt().is_nan()); + assert!(f64::NEG_INFINITY.sqrt().is_nan()); + assert!(f128::NEG_INFINITY.sqrt().is_nan()); + assert!(f16::NAN.sqrt().is_nan()); + assert!(f32::NAN.sqrt().is_nan()); + assert!(f64::NAN.sqrt().is_nan()); + assert!(f128::NAN.sqrt().is_nan()); } /// Test casts from floats to ints and back @@ -1012,21 +1041,6 @@ pub fn libm() { unsafe { ldexp(a, b) } } - assert_eq!(64_f32.sqrt(), 8_f32); - assert_eq!(64_f64.sqrt(), 8_f64); - assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY); - assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY); - assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); - assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); - assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); - assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); - assert!((-5.0_f32).sqrt().is_nan()); - assert!((-5.0_f64).sqrt().is_nan()); - assert!(f32::NEG_INFINITY.sqrt().is_nan()); - assert!(f64::NEG_INFINITY.sqrt().is_nan()); - assert!(f32::NAN.sqrt().is_nan()); - assert!(f64::NAN.sqrt().is_nan()); - assert_approx_eq!(25f32.powi(-2), 0.0016f32); assert_approx_eq!(23.2f64.powi(2), 538.24f64); |
