about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2023-12-28 12:14:01 +0100
committerRalf Jung <post@ralfj.de>2023-12-28 12:14:06 +0100
commitee42d1eb9f5983f8f281e87989617caba45007b6 (patch)
tree83317d5e42c7b009436ca5a6a29f68911272e91c
parent0f98c0e6102cadfc8c0b759d09c7f1b73b07564c (diff)
downloadrust-ee42d1eb9f5983f8f281e87989617caba45007b6.tar.gz
rust-ee42d1eb9f5983f8f281e87989617caba45007b6.zip
NaN non-determinism for SIMD intrinsics
-rw-r--r--src/tools/miri/src/shims/intrinsics/simd.rs162
-rw-r--r--src/tools/miri/tests/pass/float_nan.rs58
2 files changed, 145 insertions, 75 deletions
diff --git a/src/tools/miri/src/shims/intrinsics/simd.rs b/src/tools/miri/src/shims/intrinsics/simd.rs
index 2c8493d8aad..d749182ed5e 100644
--- a/src/tools/miri/src/shims/intrinsics/simd.rs
+++ b/src/tools/miri/src/shims/intrinsics/simd.rs
@@ -5,10 +5,17 @@ use rustc_span::{sym, Symbol};
 use rustc_target::abi::{Endian, HasDataLayout};
 
 use crate::helpers::{
-    bool_to_simd_element, check_arg_count, round_to_next_multiple_of, simd_element_to_bool,
+    bool_to_simd_element, check_arg_count, round_to_next_multiple_of, simd_element_to_bool, ToHost,
+    ToSoft,
 };
 use crate::*;
 
+#[derive(Copy, Clone)]
+pub(crate) enum MinMax {
+    Min,
+    Max,
+}
+
 impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
 pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
     /// Calls the simd intrinsic `intrinsic`; the `simd_` prefix has already been removed.
@@ -67,13 +74,17 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     let op = this.read_immediate(&this.project_index(&op, i)?)?;
                     let dest = this.project_index(&dest, i)?;
                     let val = match which {
-                        Op::MirOp(mir_op) => this.wrapping_unary_op(mir_op, &op)?.to_scalar(),
+                        Op::MirOp(mir_op) => {
+                            // This already does NaN adjustments
+                            this.wrapping_unary_op(mir_op, &op)?.to_scalar()
+                        }
                         Op::Abs => {
                             // Works for f32 and f64.
                             let ty::Float(float_ty) = op.layout.ty.kind() else {
                                 span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
                             };
                             let op = op.to_scalar();
+                            // "Bitwise" operation, no NaN adjustments
                             match float_ty {
                                 FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
                                 FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
@@ -86,14 +97,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                             // FIXME using host floats
                             match float_ty {
                                 FloatTy::F32 => {
-                                    let f = f32::from_bits(op.to_scalar().to_u32()?);
-                                    let res = f.sqrt();
-                                    Scalar::from_u32(res.to_bits())
+                                    let f = op.to_scalar().to_f32()?;
+                                    let res = f.to_host().sqrt().to_soft();
+                                    let res = this.adjust_nan(res, &[f]);
+                                    Scalar::from(res)
                                 }
                                 FloatTy::F64 => {
-                                    let f = f64::from_bits(op.to_scalar().to_u64()?);
-                                    let res = f.sqrt();
-                                    Scalar::from_u64(res.to_bits())
+                                    let f = op.to_scalar().to_f64()?;
+                                    let res = f.to_host().sqrt().to_soft();
+                                    let res = this.adjust_nan(res, &[f]);
+                                    Scalar::from(res)
                                 }
                             }
                         }
@@ -105,11 +118,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                                 FloatTy::F32 => {
                                     let f = op.to_scalar().to_f32()?;
                                     let res = f.round_to_integral(rounding).value;
+                                    let res = this.adjust_nan(res, &[f]);
                                     Scalar::from_f32(res)
                                 }
                                 FloatTy::F64 => {
                                     let f = op.to_scalar().to_f64()?;
                                     let res = f.round_to_integral(rounding).value;
+                                    let res = this.adjust_nan(res, &[f]);
                                     Scalar::from_f64(res)
                                 }
                             }
@@ -157,8 +172,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 enum Op {
                     MirOp(BinOp),
                     SaturatingOp(BinOp),
-                    FMax,
-                    FMin,
+                    FMinMax(MinMax),
                     WrappingOffset,
                 }
                 let which = match intrinsic_name {
@@ -178,8 +192,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     "le" => Op::MirOp(BinOp::Le),
                     "gt" => Op::MirOp(BinOp::Gt),
                     "ge" => Op::MirOp(BinOp::Ge),
-                    "fmax" => Op::FMax,
-                    "fmin" => Op::FMin,
+                    "fmax" => Op::FMinMax(MinMax::Max),
+                    "fmin" => Op::FMinMax(MinMax::Min),
                     "saturating_add" => Op::SaturatingOp(BinOp::Add),
                     "saturating_sub" => Op::SaturatingOp(BinOp::Sub),
                     "arith_offset" => Op::WrappingOffset,
@@ -192,6 +206,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     let dest = this.project_index(&dest, i)?;
                     let val = match which {
                         Op::MirOp(mir_op) => {
+                            // This does NaN adjustments.
                             let (val, overflowed) = this.overflowing_binary_op(mir_op, &left, &right)?;
                             if matches!(mir_op, BinOp::Shl | BinOp::Shr) {
                                 // Shifts have extra UB as SIMD operations that the MIR binop does not have.
@@ -225,11 +240,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                             let offset_ptr = ptr.wrapping_signed_offset(offset_bytes, this);
                             Scalar::from_maybe_pointer(offset_ptr, this)
                         }
-                        Op::FMax => {
-                            fmax_op(&left, &right)?
-                        }
-                        Op::FMin => {
-                            fmin_op(&left, &right)?
+                        Op::FMinMax(op) => {
+                            this.fminmax_op(op, &left, &right)?
                         }
                     };
                     this.write_scalar(val, &dest)?;
@@ -259,18 +271,20 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     };
                     let val = match float_ty {
                         FloatTy::F32 => {
-                            let a = f32::from_bits(a.to_u32()?);
-                            let b = f32::from_bits(b.to_u32()?);
-                            let c = f32::from_bits(c.to_u32()?);
-                            let res = a.mul_add(b, c);
-                            Scalar::from_u32(res.to_bits())
+                            let a = a.to_f32()?;
+                            let b = b.to_f32()?;
+                            let c = c.to_f32()?;
+                            let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
+                            let res = this.adjust_nan(res, &[a, b, c]);
+                            Scalar::from(res)
                         }
                         FloatTy::F64 => {
-                            let a = f64::from_bits(a.to_u64()?);
-                            let b = f64::from_bits(b.to_u64()?);
-                            let c = f64::from_bits(c.to_u64()?);
-                            let res = a.mul_add(b, c);
-                            Scalar::from_u64(res.to_bits())
+                            let a = a.to_f64()?;
+                            let b = b.to_f64()?;
+                            let c = c.to_f64()?;
+                            let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
+                            let res = this.adjust_nan(res, &[a, b, c]);
+                            Scalar::from(res)
                         }
                     };
                     this.write_scalar(val, &dest)?;
@@ -295,8 +309,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 enum Op {
                     MirOp(BinOp),
                     MirOpBool(BinOp),
-                    Max,
-                    Min,
+                    MinMax(MinMax),
                 }
                 let which = match intrinsic_name {
                     "reduce_and" => Op::MirOp(BinOp::BitAnd),
@@ -304,8 +317,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     "reduce_xor" => Op::MirOp(BinOp::BitXor),
                     "reduce_any" => Op::MirOpBool(BinOp::BitOr),
                     "reduce_all" => Op::MirOpBool(BinOp::BitAnd),
-                    "reduce_max" => Op::Max,
-                    "reduce_min" => Op::Min,
+                    "reduce_max" => Op::MinMax(MinMax::Max),
+                    "reduce_min" => Op::MinMax(MinMax::Min),
                     _ => unreachable!(),
                 };
 
@@ -325,24 +338,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                             let op = imm_from_bool(simd_element_to_bool(op)?);
                             this.wrapping_binary_op(mir_op, &res, &op)?
                         }
-                        Op::Max => {
-                            if matches!(res.layout.ty.kind(), ty::Float(_)) {
-                                ImmTy::from_scalar(fmax_op(&res, &op)?, res.layout)
-                            } else {
-                                // Just boring integers, so NaNs to worry about
-                                if this.wrapping_binary_op(BinOp::Ge, &res, &op)?.to_scalar().to_bool()? {
-                                    res
-                                } else {
-                                    op
-                                }
-                            }
-                        }
-                        Op::Min => {
+                        Op::MinMax(mmop) => {
                             if matches!(res.layout.ty.kind(), ty::Float(_)) {
-                                ImmTy::from_scalar(fmin_op(&res, &op)?, res.layout)
+                                ImmTy::from_scalar(this.fminmax_op(mmop, &res, &op)?, res.layout)
                             } else {
                                 // Just boring integers, so NaNs to worry about
-                                if this.wrapping_binary_op(BinOp::Le, &res, &op)?.to_scalar().to_bool()? {
+                                let mirop = match mmop {
+                                    MinMax::Min => BinOp::Le,
+                                    MinMax::Max => BinOp::Ge,
+                                };
+                                if this.wrapping_binary_op(mirop, &res, &op)?.to_scalar().to_bool()? {
                                     res
                                 } else {
                                     op
@@ -709,6 +714,43 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         }
         Ok(())
     }
+
+    fn fminmax_op(
+        &self,
+        op: MinMax,
+        left: &ImmTy<'tcx, Provenance>,
+        right: &ImmTy<'tcx, Provenance>,
+    ) -> InterpResult<'tcx, Scalar<Provenance>> {
+        let this = self.eval_context_ref();
+        assert_eq!(left.layout.ty, right.layout.ty);
+        let ty::Float(float_ty) = left.layout.ty.kind() else {
+            bug!("fmax operand is not a float")
+        };
+        let left = left.to_scalar();
+        let right = right.to_scalar();
+        Ok(match float_ty {
+            FloatTy::F32 => {
+                let left = left.to_f32()?;
+                let right = right.to_f32()?;
+                let res = match op {
+                    MinMax::Min => left.min(right),
+                    MinMax::Max => left.max(right),
+                };
+                let res = this.adjust_nan(res, &[left, right]);
+                Scalar::from_f32(res)
+            }
+            FloatTy::F64 => {
+                let left = left.to_f64()?;
+                let right = right.to_f64()?;
+                let res = match op {
+                    MinMax::Min => left.min(right),
+                    MinMax::Max => left.max(right),
+                };
+                let res = this.adjust_nan(res, &[left, right]);
+                Scalar::from_f64(res)
+            }
+        })
+    }
 }
 
 fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
@@ -719,31 +761,3 @@ fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
         Endian::Big => vec_len - 1 - idx, // reverse order of bits
     }
 }
-
-fn fmax_op<'tcx>(
-    left: &ImmTy<'tcx, Provenance>,
-    right: &ImmTy<'tcx, Provenance>,
-) -> InterpResult<'tcx, Scalar<Provenance>> {
-    assert_eq!(left.layout.ty, right.layout.ty);
-    let ty::Float(float_ty) = left.layout.ty.kind() else { bug!("fmax operand is not a float") };
-    let left = left.to_scalar();
-    let right = right.to_scalar();
-    Ok(match float_ty {
-        FloatTy::F32 => Scalar::from_f32(left.to_f32()?.max(right.to_f32()?)),
-        FloatTy::F64 => Scalar::from_f64(left.to_f64()?.max(right.to_f64()?)),
-    })
-}
-
-fn fmin_op<'tcx>(
-    left: &ImmTy<'tcx, Provenance>,
-    right: &ImmTy<'tcx, Provenance>,
-) -> InterpResult<'tcx, Scalar<Provenance>> {
-    assert_eq!(left.layout.ty, right.layout.ty);
-    let ty::Float(float_ty) = left.layout.ty.kind() else { bug!("fmin operand is not a float") };
-    let left = left.to_scalar();
-    let right = right.to_scalar();
-    Ok(match float_ty {
-        FloatTy::F32 => Scalar::from_f32(left.to_f32()?.min(right.to_f32()?)),
-        FloatTy::F64 => Scalar::from_f64(left.to_f64()?.min(right.to_f64()?)),
-    })
-}
diff --git a/src/tools/miri/tests/pass/float_nan.rs b/src/tools/miri/tests/pass/float_nan.rs
index 5e717bdca00..fff103a776f 100644
--- a/src/tools/miri/tests/pass/float_nan.rs
+++ b/src/tools/miri/tests/pass/float_nan.rs
@@ -1,4 +1,4 @@
-#![feature(float_gamma)]
+#![feature(float_gamma, portable_simd, core_intrinsics, platform_intrinsics)]
 use std::collections::HashSet;
 use std::fmt;
 use std::hash::Hash;
@@ -535,6 +535,61 @@ fn test_casts() {
     );
 }
 
+fn test_simd() {
+    use std::intrinsics::simd::*;
+    use std::simd::*;
+
+    extern "platform-intrinsic" {
+        fn simd_fsqrt<T>(x: T) -> T;
+        fn simd_ceil<T>(x: T) -> T;
+        fn simd_fma<T>(x: T, y: T, z: T) -> T;
+    }
+
+    let nan = F32::nan(Neg, Quiet, 0).as_f32();
+    check_all_outcomes(
+        HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
+        || F32::from(unsafe { simd_div(f32x4::splat(0.0), f32x4::splat(0.0)) }[0]),
+    );
+    check_all_outcomes(
+        HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
+        || F32::from(unsafe { simd_fmin(f32x4::splat(nan), f32x4::splat(nan)) }[0]),
+    );
+    check_all_outcomes(
+        HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
+        || F32::from(unsafe { simd_fmax(f32x4::splat(nan), f32x4::splat(nan)) }[0]),
+    );
+    check_all_outcomes(
+        HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
+        || {
+            F32::from(
+                unsafe { simd_fma(f32x4::splat(nan), f32x4::splat(nan), f32x4::splat(nan)) }[0],
+            )
+        },
+    );
+    check_all_outcomes(
+        HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
+        || F32::from(unsafe { simd_reduce_add_ordered::<_, f32>(f32x4::splat(nan), nan) }),
+    );
+    check_all_outcomes(
+        HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
+        || F32::from(unsafe { simd_reduce_max::<_, f32>(f32x4::splat(nan)) }),
+    );
+    check_all_outcomes(
+        HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
+        || F32::from(unsafe { simd_fsqrt(f32x4::splat(nan)) }[0]),
+    );
+    check_all_outcomes(
+        HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]),
+        || F32::from(unsafe { simd_ceil(f32x4::splat(nan)) }[0]),
+    );
+
+    // Casts
+    check_all_outcomes(
+        HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]),
+        || F64::from(unsafe { simd_cast::<f32x4, f64x4>(f32x4::splat(nan)) }[0]),
+    );
+}
+
 fn main() {
     // Check our constants against std, just to be sure.
     // We add 1 since our numbers are the number of bits stored
@@ -546,4 +601,5 @@ fn main() {
     test_f32();
     test_f64();
     test_casts();
+    test_simd();
 }