about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-10-06 15:44:37 +0000
committerbors <bors@rust-lang.org>2023-10-06 15:44:37 +0000
commit375ff3e5ce54156db1443dca021e72a6ff8ec75c (patch)
treee49a9fe3ab75fa27dbac8c783a472d54f734d9cf /src
parent3c511bb41780f50cf7eef5a5e1e8a8804ae13ab9 (diff)
parente1e880e9c6a8c756195720a51c0c0b5b1819d959 (diff)
downloadrust-375ff3e5ce54156db1443dca021e72a6ff8ec75c.tar.gz
rust-375ff3e5ce54156db1443dca021e72a6ff8ec75c.zip
Auto merge of #3110 - eduardosm:rounding-without-host-floats, r=RalfJung
Do not use host floats in `simd_{ceil,floor,round,trunc}`
Diffstat (limited to 'src')
-rw-r--r--src/tools/miri/src/shims/intrinsics/simd.rs57
1 files changed, 27 insertions, 30 deletions
diff --git a/src/tools/miri/src/shims/intrinsics/simd.rs b/src/tools/miri/src/shims/intrinsics/simd.rs
index 200f37efa27..70f90aac2c2 100644
--- a/src/tools/miri/src/shims/intrinsics/simd.rs
+++ b/src/tools/miri/src/shims/intrinsics/simd.rs
@@ -33,27 +33,20 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                 assert_eq!(dest_len, op_len);
 
                 #[derive(Copy, Clone)]
-                enum HostFloatOp {
-                    Ceil,
-                    Floor,
-                    Round,
-                    Trunc,
-                    Sqrt,
-                }
-                #[derive(Copy, Clone)]
                 enum Op {
                     MirOp(mir::UnOp),
                     Abs,
-                    HostOp(HostFloatOp),
+                    Sqrt,
+                    Round(rustc_apfloat::Round),
                 }
                 let which = match intrinsic_name {
                     "neg" => Op::MirOp(mir::UnOp::Neg),
                     "fabs" => Op::Abs,
-                    "ceil" => Op::HostOp(HostFloatOp::Ceil),
-                    "floor" => Op::HostOp(HostFloatOp::Floor),
-                    "round" => Op::HostOp(HostFloatOp::Round),
-                    "trunc" => Op::HostOp(HostFloatOp::Trunc),
-                    "fsqrt" => Op::HostOp(HostFloatOp::Sqrt),
+                    "fsqrt" => Op::Sqrt,
+                    "ceil" => Op::Round(rustc_apfloat::Round::TowardPositive),
+                    "floor" => Op::Round(rustc_apfloat::Round::TowardNegative),
+                    "round" => Op::Round(rustc_apfloat::Round::NearestTiesToAway),
+                    "trunc" => Op::Round(rustc_apfloat::Round::TowardZero),
                     _ => unreachable!(),
                 };
 
@@ -73,7 +66,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                                 FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
                             }
                         }
-                        Op::HostOp(host_op) => {
+                        Op::Sqrt => {
                             let ty::Float(float_ty) = op.layout.ty.kind() else {
                                 span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
                             };
@@ -81,28 +74,32 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                             match float_ty {
                                 FloatTy::F32 => {
                                     let f = f32::from_bits(op.to_scalar().to_u32()?);
-                                    let res = match host_op {
-                                        HostFloatOp::Ceil => f.ceil(),
-                                        HostFloatOp::Floor => f.floor(),
-                                        HostFloatOp::Round => f.round(),
-                                        HostFloatOp::Trunc => f.trunc(),
-                                        HostFloatOp::Sqrt => f.sqrt(),
-                                    };
+                                    let res = f.sqrt();
                                     Scalar::from_u32(res.to_bits())
                                 }
                                 FloatTy::F64 => {
                                     let f = f64::from_bits(op.to_scalar().to_u64()?);
-                                    let res = match host_op {
-                                        HostFloatOp::Ceil => f.ceil(),
-                                        HostFloatOp::Floor => f.floor(),
-                                        HostFloatOp::Round => f.round(),
-                                        HostFloatOp::Trunc => f.trunc(),
-                                        HostFloatOp::Sqrt => f.sqrt(),
-                                    };
+                                    let res = f.sqrt();
                                     Scalar::from_u64(res.to_bits())
                                 }
                             }
-
+                        }
+                        Op::Round(rounding) => {
+                            let ty::Float(float_ty) = op.layout.ty.kind() else {
+                                span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
+                            };
+                            match float_ty {
+                                FloatTy::F32 => {
+                                    let f = op.to_scalar().to_f32()?;
+                                    let res = f.round_to_integral(rounding).value;
+                                    Scalar::from_f32(res)
+                                }
+                                FloatTy::F64 => {
+                                    let f = op.to_scalar().to_f64()?;
+                                    let res = f.round_to_integral(rounding).value;
+                                    Scalar::from_f64(res)
+                                }
+                            }
                         }
                     };
                     this.write_scalar(val, &dest)?;