about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/miri/src/shims/x86/mod.rs50
-rw-r--r--src/tools/miri/src/shims/x86/sse41.rs39
2 files changed, 52 insertions, 37 deletions
diff --git a/src/tools/miri/src/shims/x86/mod.rs b/src/tools/miri/src/shims/x86/mod.rs
index 6d361f5d2a5..d8c3b4826a9 100644
--- a/src/tools/miri/src/shims/x86/mod.rs
+++ b/src/tools/miri/src/shims/x86/mod.rs
@@ -616,6 +616,56 @@ fn horizontal_bin_op<'tcx>(
     Ok(())
 }
 
+/// Conditionally multiplies the packed floating-point elements in
+/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
+/// products (up to 4), and conditionally stores the sum in `dest` using
+/// the low 4 bits of `imm`.
+fn conditional_dot_product<'tcx>(
+    this: &mut crate::MiriInterpCx<'_, 'tcx>,
+    left: &OpTy<'tcx, Provenance>,
+    right: &OpTy<'tcx, Provenance>,
+    imm: &OpTy<'tcx, Provenance>,
+    dest: &PlaceTy<'tcx, Provenance>,
+) -> InterpResult<'tcx, ()> {
+    let (left, left_len) = this.operand_to_simd(left)?;
+    let (right, right_len) = this.operand_to_simd(right)?;
+    let (dest, dest_len) = this.place_to_simd(dest)?;
+
+    assert_eq!(left_len, right_len);
+    assert!(dest_len <= 4);
+
+    let imm = this.read_scalar(imm)?.to_u8()?;
+
+    let element_layout = left.layout.field(this, 0);
+
+    // Calculate dot product
+    // Elements are floating point numbers, but we can use `from_int`
+    // because the representation of 0.0 is all zero bits.
+    let mut sum = ImmTy::from_int(0u8, element_layout);
+    for i in 0..left_len {
+        if imm & (1 << i.checked_add(4).unwrap()) != 0 {
+            let left = this.read_immediate(&this.project_index(&left, i)?)?;
+            let right = this.read_immediate(&this.project_index(&right, i)?)?;
+
+            let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
+            sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
+        }
+    }
+
+    // Write to destination (conditioned to imm)
+    for i in 0..dest_len {
+        let dest = this.project_index(&dest, i)?;
+
+        if imm & (1 << i) != 0 {
+            this.write_immediate(*sum, &dest)?;
+        } else {
+            this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
+        }
+    }
+
+    Ok(())
+}
+
 /// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`.
 fn bin_op_folded<'tcx, T>(
     this: &crate::MiriInterpCx<'_, 'tcx>,
diff --git a/src/tools/miri/src/shims/x86/sse41.rs b/src/tools/miri/src/shims/x86/sse41.rs
index 2683105b00b..08e3404a224 100644
--- a/src/tools/miri/src/shims/x86/sse41.rs
+++ b/src/tools/miri/src/shims/x86/sse41.rs
@@ -1,8 +1,7 @@
-use rustc_middle::mir;
 use rustc_span::Symbol;
 use rustc_target::spec::abi::Abi;
 
-use super::{bin_op_folded, round_all, round_first};
+use super::{bin_op_folded, conditional_dot_product, round_all, round_first};
 use crate::*;
 use shims::foreign_items::EmulateForeignItemResult;
 
@@ -104,41 +103,7 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
                 let [left, right, imm] =
                     this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
 
-                let (left, left_len) = this.operand_to_simd(left)?;
-                let (right, right_len) = this.operand_to_simd(right)?;
-                let (dest, dest_len) = this.place_to_simd(dest)?;
-
-                assert_eq!(left_len, right_len);
-                assert!(dest_len <= 4);
-
-                let imm = this.read_scalar(imm)?.to_u8()?;
-
-                let element_layout = left.layout.field(this, 0);
-
-                // Calculate dot product
-                // Elements are floating point numbers, but we can use `from_int`
-                // because the representation of 0.0 is all zero bits.
-                let mut sum = ImmTy::from_int(0u8, element_layout);
-                for i in 0..left_len {
-                    if imm & (1 << i.checked_add(4).unwrap()) != 0 {
-                        let left = this.read_immediate(&this.project_index(&left, i)?)?;
-                        let right = this.read_immediate(&this.project_index(&right, i)?)?;
-
-                        let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
-                        sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
-                    }
-                }
-
-                // Write to destination (conditioned to imm)
-                for i in 0..dest_len {
-                    let dest = this.project_index(&dest, i)?;
-
-                    if imm & (1 << i) != 0 {
-                        this.write_immediate(*sum, &dest)?;
-                    } else {
-                        this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
-                    }
-                }
+                conditional_dot_product(this, left, right, imm, dest)?;
             }
             // Used to implement the _mm_floor_ss, _mm_ceil_ss and _mm_round_ss
             // functions. Rounds the first element of `right` according to `rounding`