about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2024-12-04 08:46:25 +0000
committerGitHub <noreply@github.com>2024-12-04 08:46:25 +0000
commitb4b0e0356c6301ff65381c4df7541c2b631936ef (patch)
tree06e5e7eb4717073158cd6d1af0532961ce8500d6
parent86d6dc008ec2473368a044e21af43f4495dc408f (diff)
parent91bd957a21bc936f2d8732c658c0f4a1737361b3 (diff)
downloadrust-b4b0e0356c6301ff65381c4df7541c2b631936ef.tar.gz
rust-b4b0e0356c6301ff65381c4df7541c2b631936ef.zip
Merge pull request #4071 from RalfJung/simd_relaxed_fma
implement simd_relaxed_fma
-rw-r--r--src/tools/miri/src/intrinsics/simd.rs17
-rw-r--r--src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs91
-rw-r--r--src/tools/miri/tests/pass/intrinsics/portable-simd.rs22
3 files changed, 97 insertions, 33 deletions
diff --git a/src/tools/miri/src/intrinsics/simd.rs b/src/tools/miri/src/intrinsics/simd.rs
index 075b6f35e0e..54bdd3f02c2 100644
--- a/src/tools/miri/src/intrinsics/simd.rs
+++ b/src/tools/miri/src/intrinsics/simd.rs
@@ -1,4 +1,5 @@
 use either::Either;
+use rand::Rng;
 use rustc_abi::{Endian, HasDataLayout};
 use rustc_apfloat::{Float, Round};
 use rustc_middle::ty::FloatTy;
@@ -286,7 +287,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                     this.write_scalar(val, &dest)?;
                 }
             }
-            "fma" => {
+            "fma" | "relaxed_fma" => {
                 let [a, b, c] = check_arg_count(args)?;
                 let (a, a_len) = this.project_to_simd(a)?;
                 let (b, b_len) = this.project_to_simd(b)?;
@@ -303,6 +304,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                     let c = this.read_scalar(&this.project_index(&c, i)?)?;
                     let dest = this.project_index(&dest, i)?;
 
+                    let fuse: bool = intrinsic_name == "fma" || this.machine.rng.get_mut().gen();
+
                     // Works for f32 and f64.
                     // FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
                     let ty::Float(float_ty) = dest.layout.ty.kind() else {
@@ -314,7 +317,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                             let a = a.to_f32()?;
                             let b = b.to_f32()?;
                             let c = c.to_f32()?;
-                            let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
+                            let res = if fuse {
+                                a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
+                            } else {
+                                ((a * b).value + c).value
+                            };
                             let res = this.adjust_nan(res, &[a, b, c]);
                             Scalar::from(res)
                         }
@@ -322,7 +329,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                             let a = a.to_f64()?;
                             let b = b.to_f64()?;
                             let c = c.to_f64()?;
-                            let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
+                            let res = if fuse {
+                                a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
+                            } else {
+                                ((a * b).value + c).value
+                            };
                             let res = this.adjust_nan(res, &[a, b, c]);
                             Scalar::from(res)
                         }
diff --git a/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs b/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs
index b46cf1ddf65..b688405c4b1 100644
--- a/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs
+++ b/src/tools/miri/tests/pass/intrinsics/fmuladd_nondeterministic.rs
@@ -1,44 +1,75 @@
-#![feature(core_intrinsics)]
+#![feature(core_intrinsics, portable_simd)]
+use std::intrinsics::simd::simd_relaxed_fma;
 use std::intrinsics::{fmuladdf32, fmuladdf64};
+use std::simd::prelude::*;
 
-fn main() {
-    let mut saw_zero = false;
-    let mut saw_nonzero = false;
+fn ensure_both_happen(f: impl Fn() -> bool) -> bool {
+    let mut saw_true = false;
+    let mut saw_false = false;
     for _ in 0..50 {
-        let a = std::hint::black_box(0.1_f64);
-        let b = std::hint::black_box(0.2);
-        let c = std::hint::black_box(-a * b);
-        // It is unspecified whether the following operation is fused or not. The
-        // following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
-        let x = unsafe { fmuladdf64(a, b, c) };
-        if x == 0.0 {
-            saw_zero = true;
+        let b = f();
+        if b {
+            saw_true = true;
         } else {
-            saw_nonzero = true;
+            saw_false = true;
+        }
+        if saw_true && saw_false {
+            return true;
         }
     }
+    false
+}
+
+fn main() {
     assert!(
-        saw_zero && saw_nonzero,
+        ensure_both_happen(|| {
+            let a = std::hint::black_box(0.1_f64);
+            let b = std::hint::black_box(0.2);
+            let c = std::hint::black_box(-a * b);
+            // It is unspecified whether the following operation is fused or not. The
+            // following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
+            let x = unsafe { fmuladdf64(a, b, c) };
+            x == 0.0
+        }),
         "`fmuladdf64` failed to be evaluated as both fused and unfused"
     );
 
-    let mut saw_zero = false;
-    let mut saw_nonzero = false;
-    for _ in 0..50 {
-        let a = std::hint::black_box(0.1_f32);
-        let b = std::hint::black_box(0.2);
-        let c = std::hint::black_box(-a * b);
-        // It is unspecified whether the following operation is fused or not. The
-        // following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
-        let x = unsafe { fmuladdf32(a, b, c) };
-        if x == 0.0 {
-            saw_zero = true;
-        } else {
-            saw_nonzero = true;
-        }
-    }
     assert!(
-        saw_zero && saw_nonzero,
+        ensure_both_happen(|| {
+            let a = std::hint::black_box(0.1_f32);
+            let b = std::hint::black_box(0.2);
+            let c = std::hint::black_box(-a * b);
+            // It is unspecified whether the following operation is fused or not. The
+            // following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
+            let x = unsafe { fmuladdf32(a, b, c) };
+            x == 0.0
+        }),
         "`fmuladdf32` failed to be evaluated as both fused and unfused"
     );
+
+    assert!(
+        ensure_both_happen(|| {
+            let a = f32x4::splat(std::hint::black_box(0.1));
+            let b = f32x4::splat(std::hint::black_box(0.2));
+            let c = std::hint::black_box(-a * b);
+            let x = unsafe { simd_relaxed_fma(a, b, c) };
+            // Whether we fuse or not is a per-element decision, so sometimes these should be
+            // the same and sometimes not.
+            x[0] == x[1]
+        }),
+        "`simd_relaxed_fma` failed to be evaluated as both fused and unfused"
+    );
+
+    assert!(
+        ensure_both_happen(|| {
+            let a = f64x4::splat(std::hint::black_box(0.1));
+            let b = f64x4::splat(std::hint::black_box(0.2));
+            let c = std::hint::black_box(-a * b);
+            let x = unsafe { simd_relaxed_fma(a, b, c) };
+            // Whether we fuse or not is a per-element decision, so sometimes these should be
+            // the same and sometimes not.
+            x[0] == x[1]
+        }),
+        "`simd_relaxed_fma` failed to be evaluated as both fused and unfused"
+    );
 }
diff --git a/src/tools/miri/tests/pass/intrinsics/portable-simd.rs b/src/tools/miri/tests/pass/intrinsics/portable-simd.rs
index f560669dd63..acd3502f528 100644
--- a/src/tools/miri/tests/pass/intrinsics/portable-simd.rs
+++ b/src/tools/miri/tests/pass/intrinsics/portable-simd.rs
@@ -40,6 +40,17 @@ fn simd_ops_f32() {
         f32x4::splat(-3.2).mul_add(b, f32x4::splat(f32::NEG_INFINITY)),
         f32x4::splat(f32::NEG_INFINITY)
     );
+
+    unsafe {
+        assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a);
+        assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a);
+        assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b);
+        assert_eq!(
+            intrinsics::simd_relaxed_fma(f32x4::splat(-3.2), b, f32x4::splat(f32::NEG_INFINITY)),
+            f32x4::splat(f32::NEG_INFINITY)
+        );
+    }
+
     assert_eq!((a * a).sqrt(), a);
     assert_eq!((b * b).sqrt(), b.abs());
 
@@ -94,6 +105,17 @@ fn simd_ops_f64() {
         f64x4::splat(-3.2).mul_add(b, f64x4::splat(f64::NEG_INFINITY)),
         f64x4::splat(f64::NEG_INFINITY)
     );
+
+    unsafe {
+        assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a);
+        assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a);
+        assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b);
+        assert_eq!(
+            intrinsics::simd_relaxed_fma(f64x4::splat(-3.2), b, f64x4::splat(f64::NEG_INFINITY)),
+            f64x4::splat(f64::NEG_INFINITY)
+        );
+    }
+
     assert_eq!((a * a).sqrt(), a);
     assert_eq!((b * b).sqrt(), b.abs());