about summary refs log tree commit diff
diff options
context:
space:
mode:
authorTrevor Gross <tmgross@umich.edu>2025-01-23 08:28:58 +0000
committerTrevor Gross <t.gross35@gmail.com>2025-02-05 15:10:47 -0600
commit9458abd20439c75584bd1b33a1ef7b6891ab5e8a (patch)
tree43f1cd57deb6956f2db8d6d64f99c579ed679f02
parent466cd81ff5e8950da6fae3b7a76d6768689da0b6 (diff)
downloadrust-9458abd20439c75584bd1b33a1ef7b6891ab5e8a.tar.gz
rust-9458abd20439c75584bd1b33a1ef7b6891ab5e8a.zip
Start converting `fma` to a generic function
This is the first step toward making `fma` usable for `f128`, and
possibly `f32` on platforms where growing to `f64` is not fast. This
does not yet work for anything other than `f64`.
-rw-r--r--library/compiler-builtins/libm/etc/function-definitions.json6
-rw-r--r--library/compiler-builtins/libm/src/math/fma.rs192
-rw-r--r--library/compiler-builtins/libm/src/math/generic/fma.rs227
-rw-r--r--library/compiler-builtins/libm/src/math/generic/mod.rs2
-rw-r--r--library/compiler-builtins/libm/src/math/support/float_traits.rs4
-rw-r--r--library/compiler-builtins/libm/src/math/support/int_traits.rs39
6 files changed, 278 insertions, 192 deletions
diff --git a/library/compiler-builtins/libm/etc/function-definitions.json b/library/compiler-builtins/libm/etc/function-definitions.json
index a1d3adf591f..243862075ff 100644
--- a/library/compiler-builtins/libm/etc/function-definitions.json
+++ b/library/compiler-builtins/libm/etc/function-definitions.json
@@ -344,13 +344,15 @@
     },
     "fma": {
         "sources": [
-            "src/math/fma.rs"
+            "src/math/fma.rs",
+            "src/math/generic/fma.rs"
         ],
         "type": "f64"
     },
     "fmaf": {
         "sources": [
-            "src/math/fmaf.rs"
+            "src/math/fmaf.rs",
+            "src/math/generic/fma.rs"
         ],
         "type": "f32"
     },
diff --git a/library/compiler-builtins/libm/src/math/fma.rs b/library/compiler-builtins/libm/src/math/fma.rs
index 826143d5a47..69cc3eb6726 100644
--- a/library/compiler-builtins/libm/src/math/fma.rs
+++ b/library/compiler-builtins/libm/src/math/fma.rs
@@ -1,195 +1,9 @@
-use core::{f32, f64};
-
-use super::scalbn;
-
-const ZEROINFNAN: i32 = 0x7ff - 0x3ff - 52 - 1;
-
-struct Num {
-    m: u64,
-    e: i32,
-    sign: i32,
-}
-
-fn normalize(x: f64) -> Num {
-    let x1p63: f64 = f64::from_bits(0x43e0000000000000); // 0x1p63 === 2 ^ 63
-
-    let mut ix: u64 = x.to_bits();
-    let mut e: i32 = (ix >> 52) as i32;
-    let sign: i32 = e & 0x800;
-    e &= 0x7ff;
-    if e == 0 {
-        ix = (x * x1p63).to_bits();
-        e = (ix >> 52) as i32 & 0x7ff;
-        e = if e != 0 { e - 63 } else { 0x800 };
-    }
-    ix &= (1 << 52) - 1;
-    ix |= 1 << 52;
-    ix <<= 1;
-    e -= 0x3ff + 52 + 1;
-    Num { m: ix, e, sign }
-}
-
-#[inline]
-fn mul(x: u64, y: u64) -> (u64, u64) {
-    let t = (x as u128).wrapping_mul(y as u128);
-    ((t >> 64) as u64, t as u64)
-}
-
-/// Floating multiply add (f64)
+/// Fused multiply add (f64)
 ///
-/// Computes `(x*y)+z`, rounded as one ternary operation:
-/// Computes the value (as if) to infinite precision and rounds once to the result format,
-/// according to the rounding mode characterized by the value of FLT_ROUNDS.
+/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
 #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
 pub fn fma(x: f64, y: f64, z: f64) -> f64 {
-    let x1p63: f64 = f64::from_bits(0x43e0000000000000); // 0x1p63 === 2 ^ 63
-    let x0_ffffff8p_63 = f64::from_bits(0x3bfffffff0000000); // 0x0.ffffff8p-63
-
-    /* normalize so top 10bits and last bit are 0 */
-    let nx = normalize(x);
-    let ny = normalize(y);
-    let nz = normalize(z);
-
-    if nx.e >= ZEROINFNAN || ny.e >= ZEROINFNAN {
-        return x * y + z;
-    }
-    if nz.e >= ZEROINFNAN {
-        if nz.e > ZEROINFNAN {
-            /* z==0 */
-            return x * y + z;
-        }
-        return z;
-    }
-
-    /* mul: r = x*y */
-    let zhi: u64;
-    let zlo: u64;
-    let (mut rhi, mut rlo) = mul(nx.m, ny.m);
-    /* either top 20 or 21 bits of rhi and last 2 bits of rlo are 0 */
-
-    /* align exponents */
-    let mut e: i32 = nx.e + ny.e;
-    let mut d: i32 = nz.e - e;
-    /* shift bits z<<=kz, r>>=kr, so kz+kr == d, set e = e+kr (== ez-kz) */
-    if d > 0 {
-        if d < 64 {
-            zlo = nz.m << d;
-            zhi = nz.m >> (64 - d);
-        } else {
-            zlo = 0;
-            zhi = nz.m;
-            e = nz.e - 64;
-            d -= 64;
-            if d == 0 {
-            } else if d < 64 {
-                rlo = (rhi << (64 - d)) | (rlo >> d) | ((rlo << (64 - d)) != 0) as u64;
-                rhi = rhi >> d;
-            } else {
-                rlo = 1;
-                rhi = 0;
-            }
-        }
-    } else {
-        zhi = 0;
-        d = -d;
-        if d == 0 {
-            zlo = nz.m;
-        } else if d < 64 {
-            zlo = (nz.m >> d) | ((nz.m << (64 - d)) != 0) as u64;
-        } else {
-            zlo = 1;
-        }
-    }
-
-    /* add */
-    let mut sign: i32 = nx.sign ^ ny.sign;
-    let samesign: bool = (sign ^ nz.sign) == 0;
-    let mut nonzero: i32 = 1;
-    if samesign {
-        /* r += z */
-        rlo = rlo.wrapping_add(zlo);
-        rhi += zhi + (rlo < zlo) as u64;
-    } else {
-        /* r -= z */
-        let (res, borrow) = rlo.overflowing_sub(zlo);
-        rlo = res;
-        rhi = rhi.wrapping_sub(zhi.wrapping_add(borrow as u64));
-        if (rhi >> 63) != 0 {
-            rlo = (rlo as i64).wrapping_neg() as u64;
-            rhi = (rhi as i64).wrapping_neg() as u64 - (rlo != 0) as u64;
-            sign = (sign == 0) as i32;
-        }
-        nonzero = (rhi != 0) as i32;
-    }
-
-    /* set rhi to top 63bit of the result (last bit is sticky) */
-    if nonzero != 0 {
-        e += 64;
-        d = rhi.leading_zeros() as i32 - 1;
-        /* note: d > 0 */
-        rhi = (rhi << d) | (rlo >> (64 - d)) | ((rlo << d) != 0) as u64;
-    } else if rlo != 0 {
-        d = rlo.leading_zeros() as i32 - 1;
-        if d < 0 {
-            rhi = (rlo >> 1) | (rlo & 1);
-        } else {
-            rhi = rlo << d;
-        }
-    } else {
-        /* exact +-0 */
-        return x * y + z;
-    }
-    e -= d;
-
-    /* convert to double */
-    let mut i: i64 = rhi as i64; /* i is in [1<<62,(1<<63)-1] */
-    if sign != 0 {
-        i = -i;
-    }
-    let mut r: f64 = i as f64; /* |r| is in [0x1p62,0x1p63] */
-
-    if e < -1022 - 62 {
-        /* result is subnormal before rounding */
-        if e == -1022 - 63 {
-            let mut c: f64 = x1p63;
-            if sign != 0 {
-                c = -c;
-            }
-            if r == c {
-                /* min normal after rounding, underflow depends
-                on arch behaviour which can be imitated by
-                a double to float conversion */
-                let fltmin: f32 = (x0_ffffff8p_63 * f32::MIN_POSITIVE as f64 * r) as f32;
-                return f64::MIN_POSITIVE / f32::MIN_POSITIVE as f64 * fltmin as f64;
-            }
-            /* one bit is lost when scaled, add another top bit to
-            only round once at conversion if it is inexact */
-            if (rhi << 53) != 0 {
-                i = ((rhi >> 1) | (rhi & 1) | (1 << 62)) as i64;
-                if sign != 0 {
-                    i = -i;
-                }
-                r = i as f64;
-                r = 2. * r - c; /* remove top bit */
-
-                /* raise underflow portably, such that it
-                cannot be optimized away */
-                {
-                    let tiny: f64 = f64::MIN_POSITIVE / f32::MIN_POSITIVE as f64 * r;
-                    r += (tiny * tiny) * (r - r);
-                }
-            }
-        } else {
-            /* only round once when scaled */
-            d = 10;
-            i = (((rhi >> d) | ((rhi << (64 - d)) != 0) as u64) << d) as i64;
-            if sign != 0 {
-                i = -i;
-            }
-            r = i as f64;
-        }
-    }
-    scalbn(r, e)
+    return super::generic::fma(x, y, z);
 }
 
 #[cfg(test)]
diff --git a/library/compiler-builtins/libm/src/math/generic/fma.rs b/library/compiler-builtins/libm/src/math/generic/fma.rs
new file mode 100644
index 00000000000..3d5459f1a04
--- /dev/null
+++ b/library/compiler-builtins/libm/src/math/generic/fma.rs
@@ -0,0 +1,227 @@
+use core::{f32, f64};
+
+use super::super::support::{DInt, HInt, IntTy};
+use super::super::{CastFrom, CastInto, Float, Int, MinInt};
+
+const ZEROINFNAN: i32 = 0x7ff - 0x3ff - 52 - 1;
+
+/// Fused multiply-add that works when there is not a larger float size available. Currently this
+/// is still specialized only for `f64`. Computes `(x * y) + z`.
+#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
+pub fn fma<F>(x: F, y: F, z: F) -> F
+where
+    F: Float + FmaHelper,
+    F: CastFrom<F::SignedInt>,
+    F: CastFrom<i8>,
+    F::Int: HInt,
+    u32: CastInto<F::Int>,
+{
+    let one = IntTy::<F>::ONE;
+    let zero = IntTy::<F>::ZERO;
+    let magic = F::from_parts(false, F::BITS - 1 + F::EXP_BIAS, zero);
+
+    /* normalize so top 10bits and last bit are 0 */
+    let nx = Norm::from_float(x);
+    let ny = Norm::from_float(y);
+    let nz = Norm::from_float(z);
+
+    if nx.e >= ZEROINFNAN || ny.e >= ZEROINFNAN {
+        return x * y + z;
+    }
+    if nz.e >= ZEROINFNAN {
+        if nz.e > ZEROINFNAN {
+            /* z==0 */
+            return x * y + z;
+        }
+        return z;
+    }
+
+    /* mul: r = x*y */
+    let zhi: F::Int;
+    let zlo: F::Int;
+    let (mut rlo, mut rhi) = nx.m.widen_mul(ny.m).lo_hi();
+
+    /* either top 20 or 21 bits of rhi and last 2 bits of rlo are 0 */
+
+    /* align exponents */
+    let mut e: i32 = nx.e + ny.e;
+    let mut d: i32 = nz.e - e;
+    let sbits = F::BITS as i32;
+
+    /* shift bits z<<=kz, r>>=kr, so kz+kr == d, set e = e+kr (== ez-kz) */
+    if d > 0 {
+        if d < sbits {
+            zlo = nz.m << d;
+            zhi = nz.m >> (sbits - d);
+        } else {
+            zlo = zero;
+            zhi = nz.m;
+            e = nz.e - sbits;
+            d -= sbits;
+            if d == 0 {
+            } else if d < sbits {
+                rlo = (rhi << (sbits - d))
+                    | (rlo >> d)
+                    | IntTy::<F>::from((rlo << (sbits - d)) != zero);
+                rhi = rhi >> d;
+            } else {
+                rlo = one;
+                rhi = zero;
+            }
+        }
+    } else {
+        zhi = zero;
+        d = -d;
+        if d == 0 {
+            zlo = nz.m;
+        } else if d < sbits {
+            zlo = (nz.m >> d) | IntTy::<F>::from((nz.m << (sbits - d)) != zero);
+        } else {
+            zlo = one;
+        }
+    }
+
+    /* add */
+    let mut neg = nx.neg ^ ny.neg;
+    let samesign: bool = !neg ^ nz.neg;
+    let mut nonzero: i32 = 1;
+    if samesign {
+        /* r += z */
+        rlo = rlo.wrapping_add(zlo);
+        rhi += zhi + IntTy::<F>::from(rlo < zlo);
+    } else {
+        /* r -= z */
+        let (res, borrow) = rlo.overflowing_sub(zlo);
+        rlo = res;
+        rhi = rhi.wrapping_sub(zhi.wrapping_add(IntTy::<F>::from(borrow)));
+        if (rhi >> (F::BITS - 1)) != zero {
+            rlo = rlo.signed().wrapping_neg().unsigned();
+            rhi = rhi.signed().wrapping_neg().unsigned() - IntTy::<F>::from(rlo != zero);
+            neg = !neg;
+        }
+        nonzero = (rhi != zero) as i32;
+    }
+
+    /* set rhi to top 63bit of the result (last bit is sticky) */
+    if nonzero != 0 {
+        e += sbits;
+        d = rhi.leading_zeros() as i32 - 1;
+        /* note: d > 0 */
+        rhi = (rhi << d) | (rlo >> (sbits - d)) | IntTy::<F>::from((rlo << d) != zero);
+    } else if rlo != zero {
+        d = rlo.leading_zeros() as i32 - 1;
+        if d < 0 {
+            rhi = (rlo >> 1) | (rlo & one);
+        } else {
+            rhi = rlo << d;
+        }
+    } else {
+        /* exact +-0 */
+        return x * y + z;
+    }
+    e -= d;
+
+    /* convert to double */
+    let mut i: F::SignedInt = rhi.signed(); /* i is in [1<<62,(1<<63)-1] */
+    if neg {
+        i = -i;
+    }
+
+    let mut r: F = F::cast_from_lossy(i); /* |r| is in [0x1p62,0x1p63] */
+
+    if e < -(F::EXP_BIAS as i32 - 1) - (sbits - 2) {
+        /* result is subnormal before rounding */
+        if e == -(F::EXP_BIAS as i32 - 1) - (sbits - 1) {
+            let mut c: F = magic;
+            if neg {
+                c = -c;
+            }
+            if r == c {
+                /* min normal after rounding, underflow depends
+                 * on arch behaviour which can be imitated by
+                 * a double to float conversion */
+                return r.raise_underflow();
+            }
+            /* one bit is lost when scaled, add another top bit to
+             * only round once at conversion if it is inexact */
+            if (rhi << F::SIG_BITS) != zero {
+                let iu: F::Int = (rhi >> 1) | (rhi & one) | (one << 62);
+                i = iu.signed();
+                if neg {
+                    i = -i;
+                }
+                r = F::cast_from_lossy(i);
+                r = F::cast_from(2i8) * r - c; /* remove top bit */
+
+                /* raise underflow portably, such that it
+                 * cannot be optimized away */
+                r += r.raise_underflow2();
+            }
+        } else {
+            /* only round once when scaled */
+            d = 10;
+            i = (((rhi >> d) | IntTy::<F>::from(rhi << (F::BITS as i32 - d) != zero)) << d)
+                .signed();
+            if neg {
+                i = -i;
+            }
+            r = F::cast_from(i);
+        }
+    }
+
+    super::scalbn(r, e)
+}
+
+/// Representation of `F` that has handled subnormals.
+struct Norm<F: Float> {
+    /// Normalized significand with one guard bit.
+    m: F::Int,
+    /// Unbiased exponent, normalized.
+    e: i32,
+    neg: bool,
+}
+
+impl<F: Float> Norm<F> {
+    fn from_float(x: F) -> Self {
+        let mut ix = x.to_bits();
+        let mut e = x.exp() as i32;
+        let neg = x.is_sign_negative();
+        if e == 0 {
+            // Normalize subnormals by multiplication
+            let magic = F::from_parts(false, F::BITS - 1 + F::EXP_BIAS, F::Int::ZERO);
+            let scaled = x * magic;
+            ix = scaled.to_bits();
+            e = scaled.exp() as i32;
+            e = if e != 0 { e - (F::BITS as i32 - 1) } else { 0x800 };
+        }
+
+        e -= F::EXP_BIAS as i32 + 52 + 1;
+
+        ix &= F::SIG_MASK;
+        ix |= F::IMPLICIT_BIT;
+        ix <<= 1; // add a guard bit
+
+        Self { m: ix, e, neg }
+    }
+}
+
+/// Type-specific helpers that are not needed outside of fma.
+pub trait FmaHelper {
+    fn raise_underflow(self) -> Self;
+    fn raise_underflow2(self) -> Self;
+}
+
+impl FmaHelper for f64 {
+    fn raise_underflow(self) -> Self {
+        let x0_ffffff8p_63 = f64::from_bits(0x3bfffffff0000000); // 0x0.ffffff8p-63
+        let fltmin: f32 = (x0_ffffff8p_63 * f32::MIN_POSITIVE as f64 * self) as f32;
+        f64::MIN_POSITIVE / f32::MIN_POSITIVE as f64 * fltmin as f64
+    }
+
+    fn raise_underflow2(self) -> Self {
+        /* raise underflow portably, such that it
+         * cannot be optimized away */
+        let tiny: f64 = f64::MIN_POSITIVE / f32::MIN_POSITIVE as f64 * self;
+        (tiny * tiny) * (self - self)
+    }
+}
diff --git a/library/compiler-builtins/libm/src/math/generic/mod.rs b/library/compiler-builtins/libm/src/math/generic/mod.rs
index 68686b0b255..e19cc83a9ac 100644
--- a/library/compiler-builtins/libm/src/math/generic/mod.rs
+++ b/library/compiler-builtins/libm/src/math/generic/mod.rs
@@ -3,6 +3,7 @@ mod copysign;
 mod fabs;
 mod fdim;
 mod floor;
+mod fma;
 mod fmax;
 mod fmin;
 mod fmod;
@@ -17,6 +18,7 @@ pub use copysign::copysign;
 pub use fabs::fabs;
 pub use fdim::fdim;
 pub use floor::floor;
+pub use fma::fma;
 pub use fmax::fmax;
 pub use fmin::fmin;
 pub use fmod::fmod;
diff --git a/library/compiler-builtins/libm/src/math/support/float_traits.rs b/library/compiler-builtins/libm/src/math/support/float_traits.rs
index 1fe2cb424b9..24cf7d4b05c 100644
--- a/library/compiler-builtins/libm/src/math/support/float_traits.rs
+++ b/library/compiler-builtins/libm/src/math/support/float_traits.rs
@@ -23,7 +23,9 @@ pub trait Float:
     type Int: Int<OtherSign = Self::SignedInt, Unsigned = Self::Int>;
 
     /// A int of the same width as the float
-    type SignedInt: Int + MinInt<OtherSign = Self::Int, Unsigned = Self::Int>;
+    type SignedInt: Int
+        + MinInt<OtherSign = Self::Int, Unsigned = Self::Int>
+        + ops::Neg<Output = Self::SignedInt>;
 
     const ZERO: Self;
     const NEG_ZERO: Self;
diff --git a/library/compiler-builtins/libm/src/math/support/int_traits.rs b/library/compiler-builtins/libm/src/math/support/int_traits.rs
index b403c658cb6..793a0f3069f 100644
--- a/library/compiler-builtins/libm/src/math/support/int_traits.rs
+++ b/library/compiler-builtins/libm/src/math/support/int_traits.rs
@@ -52,10 +52,14 @@ pub trait Int:
     + ops::Sub<Output = Self>
     + ops::Mul<Output = Self>
     + ops::Div<Output = Self>
+    + ops::Shl<i32, Output = Self>
+    + ops::Shl<u32, Output = Self>
+    + ops::Shr<i32, Output = Self>
     + ops::Shr<u32, Output = Self>
     + ops::BitXor<Output = Self>
     + ops::BitAnd<Output = Self>
     + cmp::Ord
+    + From<bool>
     + CastFrom<i32>
     + CastFrom<u16>
     + CastFrom<u32>
@@ -92,6 +96,7 @@ pub trait Int:
     fn wrapping_shr(self, other: u32) -> Self;
     fn rotate_left(self, other: u32) -> Self;
     fn overflowing_add(self, other: Self) -> (Self, bool);
+    fn overflowing_sub(self, other: Self) -> (Self, bool);
     fn leading_zeros(self) -> u32;
     fn ilog2(self) -> u32;
 }
@@ -150,6 +155,10 @@ macro_rules! int_impl_common {
             <Self>::overflowing_add(self, other)
         }
 
+        fn overflowing_sub(self, other: Self) -> (Self, bool) {
+            <Self>::overflowing_sub(self, other)
+        }
+
         fn leading_zeros(self) -> u32 {
             <Self>::leading_zeros(self)
         }
@@ -399,6 +408,30 @@ macro_rules! cast_into {
     )*};
 }
 
+macro_rules! cast_into_float {
+    ($ty:ty) => {
+        #[cfg(f16_enabled)]
+        cast_into_float!($ty; f16);
+
+        cast_into_float!($ty; f32, f64);
+
+        #[cfg(f128_enabled)]
+        cast_into_float!($ty; f128);
+    };
+    ($ty:ty; $($into:ty),*) => {$(
+        impl CastInto<$into> for $ty {
+            fn cast(self) -> $into {
+                debug_assert_eq!(self as $into as $ty, self, "inexact float cast");
+                self as $into
+            }
+
+            fn cast_lossy(self) -> $into {
+                self as $into
+            }
+        }
+    )*};
+}
+
 cast_into!(usize);
 cast_into!(isize);
 cast_into!(u8);
@@ -411,3 +444,9 @@ cast_into!(u64);
 cast_into!(i64);
 cast_into!(u128);
 cast_into!(i128);
+
+cast_into_float!(i8);
+cast_into_float!(i16);
+cast_into_float!(i32);
+cast_into_float!(i64);
+cast_into_float!(i128);