diff options
| -rw-r--r-- | src/tools/miri/src/shims/x86/mod.rs | 50 | ||||
| -rw-r--r-- | src/tools/miri/src/shims/x86/sse41.rs | 39 |
2 files changed, 52 insertions, 37 deletions
diff --git a/src/tools/miri/src/shims/x86/mod.rs b/src/tools/miri/src/shims/x86/mod.rs index 6d361f5d2a5..d8c3b4826a9 100644 --- a/src/tools/miri/src/shims/x86/mod.rs +++ b/src/tools/miri/src/shims/x86/mod.rs @@ -616,6 +616,56 @@ fn horizontal_bin_op<'tcx>( Ok(()) } +/// Conditionally multiplies the packed floating-point elements in +/// `left` and `right` using the high 4 bits in `imm`, sums the calculated +/// products (up to 4), and conditionally stores the sum in `dest` using +/// the low 4 bits of `imm`. +fn conditional_dot_product<'tcx>( + this: &mut crate::MiriInterpCx<'_, 'tcx>, + left: &OpTy<'tcx, Provenance>, + right: &OpTy<'tcx, Provenance>, + imm: &OpTy<'tcx, Provenance>, + dest: &PlaceTy<'tcx, Provenance>, +) -> InterpResult<'tcx, ()> { + let (left, left_len) = this.operand_to_simd(left)?; + let (right, right_len) = this.operand_to_simd(right)?; + let (dest, dest_len) = this.place_to_simd(dest)?; + + assert_eq!(left_len, right_len); + assert!(dest_len <= 4); + + let imm = this.read_scalar(imm)?.to_u8()?; + + let element_layout = left.layout.field(this, 0); + + // Calculate dot product + // Elements are floating point numbers, but we can use `from_int` + // because the representation of 0.0 is all zero bits. + let mut sum = ImmTy::from_int(0u8, element_layout); + for i in 0..left_len { + if imm & (1 << i.checked_add(4).unwrap()) != 0 { + let left = this.read_immediate(&this.project_index(&left, i)?)?; + let right = this.read_immediate(&this.project_index(&right, i)?)?; + + let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?; + sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?; + } + } + + // Write to destination (conditioned to imm) + for i in 0..dest_len { + let dest = this.project_index(&dest, i)?; + + if imm & (1 << i) != 0 { + this.write_immediate(*sum, &dest)?; + } else { + this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?; + } + } + + Ok(()) +} + /// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`. fn bin_op_folded<'tcx, T>( this: &crate::MiriInterpCx<'_, 'tcx>, diff --git a/src/tools/miri/src/shims/x86/sse41.rs b/src/tools/miri/src/shims/x86/sse41.rs index 2683105b00b..08e3404a224 100644 --- a/src/tools/miri/src/shims/x86/sse41.rs +++ b/src/tools/miri/src/shims/x86/sse41.rs @@ -1,8 +1,7 @@ -use rustc_middle::mir; use rustc_span::Symbol; use rustc_target::spec::abi::Abi; -use super::{bin_op_folded, round_all, round_first}; +use super::{bin_op_folded, conditional_dot_product, round_all, round_first}; use crate::*; use shims::foreign_items::EmulateForeignItemResult; @@ -104,41 +103,7 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: let [left, right, imm] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; - let (left, left_len) = this.operand_to_simd(left)?; - let (right, right_len) = this.operand_to_simd(right)?; - let (dest, dest_len) = this.place_to_simd(dest)?; - - assert_eq!(left_len, right_len); - assert!(dest_len <= 4); - - let imm = this.read_scalar(imm)?.to_u8()?; - - let element_layout = left.layout.field(this, 0); - - // Calculate dot product - // Elements are floating point numbers, but we can use `from_int` - // because the representation of 0.0 is all zero bits. - let mut sum = ImmTy::from_int(0u8, element_layout); - for i in 0..left_len { - if imm & (1 << i.checked_add(4).unwrap()) != 0 { - let left = this.read_immediate(&this.project_index(&left, i)?)?; - let right = this.read_immediate(&this.project_index(&right, i)?)?; - - let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?; - sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?; - } - } - - // Write to destination (conditioned to imm) - for i in 0..dest_len { - let dest = this.project_index(&dest, i)?; - - if imm & (1 << i) != 0 { - this.write_immediate(*sum, &dest)?; - } else { - this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?; - } - } + conditional_dot_product(this, left, right, imm, dest)?; } // Used to implement the _mm_floor_ss, _mm_ceil_ss and _mm_round_ss // functions. Rounds the first element of `right` according to `rounding` |
