about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorTrevor Gross <t.gross35@gmail.com>2024-10-11 23:57:44 -0400
committerGitHub <noreply@github.com>2024-10-11 23:57:44 -0400
commit3f9aa50b70e8833656d557aa963e339faca0a4c6 (patch)
treebb838846810b68bd96127303c82deb61c40fa8ae /src
parentfb20e4d3b96d1de459d086980a8b99d5060ad9fe (diff)
parent0d8a978e8a55b08778ec6ee861c2c5ed6703eb6c (diff)
downloadrust-3f9aa50b70e8833656d557aa963e339faca0a4c6.tar.gz
rust-3f9aa50b70e8833656d557aa963e339faca0a4c6.zip
Rollup merge of #124874 - jedbrown:float-mul-add-fast, r=saethlin
intrinsics fmuladdf{32,64}: expose llvm.fmuladd.* semantics

Add intrinsics `fmuladd{f32,f64}`. This computes `(a * b) + c`, to be fused if the code generator determines that (i) the target instruction set has support for a fused operation, and (ii) that the fused operation is more efficient than the equivalent, separate pair of `mul` and `add` instructions.

https://llvm.org/docs/LangRef.html#llvm-fmuladd-intrinsic

The codegen_cranelift uses the `fma` function from libc, which is a correct implementation, but without the desired performance semantic. I think this requires an update to cranelift to expose a suitable instruction in its IR.

I have not tested with codegen_gcc, but it should behave the same way (using `fma` from libc).

---
This topic has been discussed a few times on Zulip and was suggested, for example, by `@workingjubilee` in [Effect of fma disabled](https://rust-lang.zulipchat.com/#narrow/stream/122651-general/topic/Effect.20of.20fma.20disabled/near/274179331).
Diffstat (limited to 'src')
-rw-r--r--src/tools/miri/src/intrinsics/mod.rs31
-rw-r--r--src/tools/miri/tests/pass/float.rs18
-rw-r--r--src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs44
3 files changed, 93 insertions, 0 deletions
diff --git a/src/tools/miri/src/intrinsics/mod.rs b/src/tools/miri/src/intrinsics/mod.rs
index 665dd7c441a..9f772cfa982 100644
--- a/src/tools/miri/src/intrinsics/mod.rs
+++ b/src/tools/miri/src/intrinsics/mod.rs
@@ -295,6 +295,37 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 this.write_scalar(res, dest)?;
             }
 
+            "fmuladdf32" => {
+                let [a, b, c] = check_arg_count(args)?;
+                let a = this.read_scalar(a)?.to_f32()?;
+                let b = this.read_scalar(b)?.to_f32()?;
+                let c = this.read_scalar(c)?.to_f32()?;
+                let fuse: bool = this.machine.rng.get_mut().gen();
+                let res = if fuse {
+                    // FIXME: Using host floats, to work around https://github.com/rust-lang/rustc_apfloat/issues/11
+                    a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
+                } else {
+                    ((a * b).value + c).value
+                };
+                let res = this.adjust_nan(res, &[a, b, c]);
+                this.write_scalar(res, dest)?;
+            }
+            "fmuladdf64" => {
+                let [a, b, c] = check_arg_count(args)?;
+                let a = this.read_scalar(a)?.to_f64()?;
+                let b = this.read_scalar(b)?.to_f64()?;
+                let c = this.read_scalar(c)?.to_f64()?;
+                let fuse: bool = this.machine.rng.get_mut().gen();
+                let res = if fuse {
+                    // FIXME: Using host floats, to work around https://github.com/rust-lang/rustc_apfloat/issues/11
+                    a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
+                } else {
+                    ((a * b).value + c).value
+                };
+                let res = this.adjust_nan(res, &[a, b, c]);
+                this.write_scalar(res, dest)?;
+            }
+
             "powf32" => {
                 let [f1, f2] = check_arg_count(args)?;
                 let f1 = this.read_scalar(f1)?.to_f32()?;
diff --git a/src/tools/miri/tests/pass/float.rs b/src/tools/miri/tests/pass/float.rs
index 6ab18a5345e..853d3e80517 100644
--- a/src/tools/miri/tests/pass/float.rs
+++ b/src/tools/miri/tests/pass/float.rs
@@ -30,6 +30,7 @@ fn main() {
     libm();
     test_fast();
     test_algebraic();
+    test_fmuladd();
 }
 
 trait Float: Copy + PartialEq + Debug {
@@ -1041,3 +1042,20 @@ fn test_algebraic() {
     test_operations_f32(11., 2.);
     test_operations_f32(10., 15.);
 }
+
+fn test_fmuladd() {
+    use std::intrinsics::{fmuladdf32, fmuladdf64};
+
+    #[inline(never)]
+    pub fn test_operations_f32(a: f32, b: f32, c: f32) {
+        assert_approx_eq!(unsafe { fmuladdf32(a, b, c) }, a * b + c);
+    }
+
+    #[inline(never)]
+    pub fn test_operations_f64(a: f64, b: f64, c: f64) {
+        assert_approx_eq!(unsafe { fmuladdf64(a, b, c) }, a * b + c);
+    }
+
+    test_operations_f32(0.1, 0.2, 0.3);
+    test_operations_f64(1.1, 1.2, 1.3);
+}
diff --git a/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs b/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs
new file mode 100644
index 00000000000..b46cf1ddf65
--- /dev/null
+++ b/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs
@@ -0,0 +1,44 @@
+#![feature(core_intrinsics)]
+use std::intrinsics::{fmuladdf32, fmuladdf64};
+
+fn main() {
+    let mut saw_zero = false;
+    let mut saw_nonzero = false;
+    for _ in 0..50 {
+        let a = std::hint::black_box(0.1_f64);
+        let b = std::hint::black_box(0.2);
+        let c = std::hint::black_box(-a * b);
+        // It is unspecified whether the following operation is fused or not. The
+        // following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
+        let x = unsafe { fmuladdf64(a, b, c) };
+        if x == 0.0 {
+            saw_zero = true;
+        } else {
+            saw_nonzero = true;
+        }
+    }
+    assert!(
+        saw_zero && saw_nonzero,
+        "`fmuladdf64` failed to be evaluated as both fused and unfused"
+    );
+
+    let mut saw_zero = false;
+    let mut saw_nonzero = false;
+    for _ in 0..50 {
+        let a = std::hint::black_box(0.1_f32);
+        let b = std::hint::black_box(0.2);
+        let c = std::hint::black_box(-a * b);
+        // It is unspecified whether the following operation is fused or not. The
+        // following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
+        let x = unsafe { fmuladdf32(a, b, c) };
+        if x == 0.0 {
+            saw_zero = true;
+        } else {
+            saw_nonzero = true;
+        }
+    }
+    assert!(
+        saw_zero && saw_nonzero,
+        "`fmuladdf32` failed to be evaluated as both fused and unfused"
+    );
+}