about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2023-11-20 17:20:59 +0000
committerbors <bors@rust-lang.org>2023-11-20 17:20:59 +0000
commit607208cd8dd54ca6e22246fdb13486556ff1ce6d (patch)
tree40bd55d4346a7ee758ffb555bdc2186cc2c13a5c /src
parent6730f222ceb2780217e37e26d3a43f60f034a2dd (diff)
parent81303e7ea5209dbfbdf359b9311681d4151c4c5a (diff)
downloadrust-607208cd8dd54ca6e22246fdb13486556ff1ce6d.tar.gz
rust-607208cd8dd54ca6e22246fdb13486556ff1ce6d.zip
Auto merge of #3176 - eduardosm:cmp, r=RalfJung
Implement all 16 AVX compare operators for 128-bit SIMD vectors

`_mm_cmp_{ss,ps,sd,pd}` functions are AVX functions that use `llvm.x86.sse{,2}.` prefixed intrinsics, so they were "accidentally" partially implemented when SSE and SSE2 intrinsics were implemented.

The 16 AVX compare operators are now implemented and tested.
Diffstat (limited to 'src')
-rw-r--r--src/tools/miri/src/shims/x86/mod.rs120
-rw-r--r--src/tools/miri/src/shims/x86/sse.rs22
-rw-r--r--src/tools/miri/src/shims/x86/sse2.rs20
-rw-r--r--src/tools/miri/tests/pass/intrinsics-x86-aes-vaes.rs2
-rw-r--r--src/tools/miri/tests/pass/intrinsics-x86-avx.rs162
-rw-r--r--src/tools/miri/tests/pass/intrinsics-x86-avx512.rs2
-rw-r--r--src/tools/miri/tests/pass/intrinsics-x86-sse3-ssse3.rs2
-rw-r--r--src/tools/miri/tests/pass/intrinsics-x86-sse41.rs2
8 files changed, 261 insertions, 71 deletions
diff --git a/src/tools/miri/src/shims/x86/mod.rs b/src/tools/miri/src/shims/x86/mod.rs
index d88a3127ecc..2a2171134d4 100644
--- a/src/tools/miri/src/shims/x86/mod.rs
+++ b/src/tools/miri/src/shims/x86/mod.rs
@@ -119,53 +119,32 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
     }
 }
 
-/// Floating point comparison operation
-///
-/// <https://www.felixcloutier.com/x86/cmpss>
-/// <https://www.felixcloutier.com/x86/cmpps>
-/// <https://www.felixcloutier.com/x86/cmpsd>
-/// <https://www.felixcloutier.com/x86/cmppd>
-#[derive(Copy, Clone)]
-enum FloatCmpOp {
-    Eq,
-    Lt,
-    Le,
-    Unord,
-    Neq,
-    /// Not less-than
-    Nlt,
-    /// Not less-or-equal
-    Nle,
-    /// Ordered, i.e. neither of them is NaN
-    Ord,
-}
-
-impl FloatCmpOp {
-    /// Convert from the `imm` argument used to specify the comparison
-    /// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
-    fn from_intrinsic_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
-        match imm {
-            0 => Ok(Self::Eq),
-            1 => Ok(Self::Lt),
-            2 => Ok(Self::Le),
-            3 => Ok(Self::Unord),
-            4 => Ok(Self::Neq),
-            5 => Ok(Self::Nlt),
-            6 => Ok(Self::Nle),
-            7 => Ok(Self::Ord),
-            imm => {
-                throw_unsup_format!("invalid `imm` parameter of {intrinsic}: {imm}");
-            }
-        }
-    }
-}
-
 #[derive(Copy, Clone)]
 enum FloatBinOp {
     /// Arithmetic operation
     Arith(mir::BinOp),
     /// Comparison
-    Cmp(FloatCmpOp),
+    ///
+    /// The semantics of this operator is a case distinction: we compare the two operands,
+    /// and then we return one of the four booleans `gt`, `lt`, `eq`, `unord` depending on
+    /// which class they fall into.
+    ///
+    /// AVX supports all 16 combinations, SSE only a subset
+    ///
+    /// <https://www.felixcloutier.com/x86/cmpss>
+    /// <https://www.felixcloutier.com/x86/cmpps>
+    /// <https://www.felixcloutier.com/x86/cmpsd>
+    /// <https://www.felixcloutier.com/x86/cmppd>
+    Cmp {
+        /// Result when lhs < rhs
+        gt: bool,
+        /// Result when lhs > rhs
+        lt: bool,
+        /// Result when lhs == rhs
+        eq: bool,
+        /// Result when lhs is NaN or rhs is NaN
+        unord: bool,
+    },
     /// Minimum value (with SSE semantics)
     ///
     /// <https://www.felixcloutier.com/x86/minss>
@@ -182,6 +161,44 @@ enum FloatBinOp {
     Max,
 }
 
+impl FloatBinOp {
+    /// Convert from the `imm` argument used to specify the comparison
+    /// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
+    fn cmp_from_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
+        // Only bits 0..=4 are used, remaining should be zero.
+        if imm & !0b1_1111 != 0 {
+            throw_unsup_format!("invalid `imm` parameter of {intrinsic}: 0x{imm:x}");
+        }
+        // Bit 4 specifies whether the operation is quiet or signaling, which
+        // we do not care in Miri.
+        // Bits 0..=2 specifies the operation.
+        // `gt` indicates the result to be returned when the LHS is strictly
+        // greater than the RHS, and so on.
+        let (gt, lt, eq, unord) = match imm & 0b111 {
+            // Equal
+            0x0 => (false, false, true, false),
+            // Less-than
+            0x1 => (false, true, false, false),
+            // Less-or-equal
+            0x2 => (false, true, true, false),
+            // Unordered (either is NaN)
+            0x3 => (false, false, false, true),
+            // Not equal
+            0x4 => (true, true, false, true),
+            // Not less-than
+            0x5 => (true, false, true, true),
+            // Not less-or-equal
+            0x6 => (true, false, false, true),
+            // Ordered (neither is NaN)
+            0x7 => (true, true, true, false),
+            _ => unreachable!(),
+        };
+        // When bit 3 is 1 (only possible in AVX), unord is toggled.
+        let unord = unord ^ (imm & 0b1000 != 0);
+        Ok(Self::Cmp { gt, lt, eq, unord })
+    }
+}
+
 /// Performs `which` scalar operation on `left` and `right` and returns
 /// the result.
 fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
@@ -195,20 +212,15 @@ fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
             let res = this.wrapping_binary_op(which, left, right)?;
             Ok(res.to_scalar())
         }
-        FloatBinOp::Cmp(which) => {
+        FloatBinOp::Cmp { gt, lt, eq, unord } => {
             let left = left.to_scalar().to_float::<F>()?;
             let right = right.to_scalar().to_float::<F>()?;
-            // FIXME: Make sure that these operations match the semantics
-            // of cmpps/cmpss/cmppd/cmpsd
-            let res = match which {
-                FloatCmpOp::Eq => left == right,
-                FloatCmpOp::Lt => left < right,
-                FloatCmpOp::Le => left <= right,
-                FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
-                FloatCmpOp::Neq => left != right,
-                FloatCmpOp::Nlt => !(left < right),
-                FloatCmpOp::Nle => !(left <= right),
-                FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),
+
+            let res = match left.partial_cmp(&right) {
+                None => unord,
+                Some(std::cmp::Ordering::Less) => lt,
+                Some(std::cmp::Ordering::Equal) => eq,
+                Some(std::cmp::Ordering::Greater) => gt,
             };
             Ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
         }
diff --git a/src/tools/miri/src/shims/x86/sse.rs b/src/tools/miri/src/shims/x86/sse.rs
index 831228b7a26..e15023c3c21 100644
--- a/src/tools/miri/src/shims/x86/sse.rs
+++ b/src/tools/miri/src/shims/x86/sse.rs
@@ -5,7 +5,7 @@ use rustc_target::spec::abi::Abi;
 
 use rand::Rng as _;
 
-use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
+use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp};
 use crate::*;
 use shims::foreign_items::EmulateForeignItemResult;
 
@@ -95,33 +95,41 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
 
                 unary_op_ps(this, which, op, dest)?;
             }
-            // Used to implement the _mm_cmp_ss function.
+            // Used to implement the _mm_cmp*_ss functions.
             // Performs a comparison operation on the first component of `left`
             // and `right`, returning 0 if false or `u32::MAX` if true. The remaining
             // components are copied from `left`.
+            // _mm_cmp_ss is actually an AVX function where the operation is specified
+            // by a const parameter.
+            // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ss are SSE functions
+            // with hard-coded operations.
             "cmp.ss" => {
                 let [left, right, imm] =
                     this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
 
-                let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
+                let which = FloatBinOp::cmp_from_imm(
                     this.read_scalar(imm)?.to_i8()?,
                     "llvm.x86.sse.cmp.ss",
-                )?);
+                )?;
 
                 bin_op_simd_float_first::<Single>(this, which, left, right, dest)?;
             }
-            // Used to implement the _mm_cmp_ps function.
+            // Used to implement the _mm_cmp*_ps functions.
             // Performs a comparison operation on each component of `left`
             // and `right`. For each component, returns 0 if false or u32::MAX
             // if true.
+            // _mm_cmp_ps is actually an AVX function where the operation is specified
+            // by a const parameter.
+            // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ps are SSE functions
+            // with hard-coded operations.
             "cmp.ps" => {
                 let [left, right, imm] =
                     this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
 
-                let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
+                let which = FloatBinOp::cmp_from_imm(
                     this.read_scalar(imm)?.to_i8()?,
                     "llvm.x86.sse.cmp.ps",
-                )?);
+                )?;
 
                 bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
             }
diff --git a/src/tools/miri/src/shims/x86/sse2.rs b/src/tools/miri/src/shims/x86/sse2.rs
index 3f2b9f5f0ad..55520771cf6 100644
--- a/src/tools/miri/src/shims/x86/sse2.rs
+++ b/src/tools/miri/src/shims/x86/sse2.rs
@@ -4,7 +4,7 @@ use rustc_middle::ty::Ty;
 use rustc_span::Symbol;
 use rustc_target::spec::abi::Abi;
 
-use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
+use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp};
 use crate::*;
 use shims::foreign_items::EmulateForeignItemResult;
 
@@ -461,18 +461,22 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
                     this.write_scalar(res, &dest)?;
                 }
             }
-            // Used to implement the _mm_cmp*_sd function.
+            // Used to implement the _mm_cmp*_sd functions.
             // Performs a comparison operation on the first component of `left`
             // and `right`, returning 0 if false or `u64::MAX` if true. The remaining
             // components are copied from `left`.
+            // _mm_cmp_sd is actually an AVX function where the operation is specified
+            // by a const parameter.
+            // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_sd are SSE2 functions
+            // with hard-coded operations.
             "cmp.sd" => {
                 let [left, right, imm] =
                     this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
 
-                let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
+                let which = FloatBinOp::cmp_from_imm(
                     this.read_scalar(imm)?.to_i8()?,
                     "llvm.x86.sse2.cmp.sd",
-                )?);
+                )?;
 
                 bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
             }
@@ -480,14 +484,18 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
             // Performs a comparison operation on each component of `left`
             // and `right`. For each component, returns 0 if false or `u64::MAX`
             // if true.
+            // _mm_cmp_pd is actually an AVX function where the operation is specified
+            // by a const parameter.
+            // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_pd are SSE2 functions
+            // with hard-coded operations.
             "cmp.pd" => {
                 let [left, right, imm] =
                     this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
 
-                let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
+                let which = FloatBinOp::cmp_from_imm(
                     this.read_scalar(imm)?.to_i8()?,
                     "llvm.x86.sse2.cmp.pd",
-                )?);
+                )?;
 
                 bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
             }
diff --git a/src/tools/miri/tests/pass/intrinsics-x86-aes-vaes.rs b/src/tools/miri/tests/pass/intrinsics-x86-aes-vaes.rs
index 090b1db0af0..55d1bacdf45 100644
--- a/src/tools/miri/tests/pass/intrinsics-x86-aes-vaes.rs
+++ b/src/tools/miri/tests/pass/intrinsics-x86-aes-vaes.rs
@@ -1,5 +1,5 @@
 // Ignore everything except x86 and x86_64
-// Any additional target are added to CI should be ignored here
+// Any new targets that are added to CI should be ignored here.
 // (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.)
 //@ignore-target-aarch64
 //@ignore-target-arm
diff --git a/src/tools/miri/tests/pass/intrinsics-x86-avx.rs b/src/tools/miri/tests/pass/intrinsics-x86-avx.rs
new file mode 100644
index 00000000000..933e3d4153a
--- /dev/null
+++ b/src/tools/miri/tests/pass/intrinsics-x86-avx.rs
@@ -0,0 +1,162 @@
+// Ignore everything except x86 and x86_64
+// Any new targets that are added to CI should be ignored here.
+// (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.)
+//@ignore-target-aarch64
+//@ignore-target-arm
+//@ignore-target-avr
+//@ignore-target-s390x
+//@ignore-target-thumbv7em
+//@ignore-target-wasm32
+//@compile-flags: -C target-feature=+avx
+
+#[cfg(target_arch = "x86")]
+use std::arch::x86::*;
+#[cfg(target_arch = "x86_64")]
+use std::arch::x86_64::*;
+use std::mem::transmute;
+
+fn main() {
+    assert!(is_x86_feature_detected!("avx"));
+
+    unsafe {
+        test_avx();
+    }
+}
+
+#[target_feature(enable = "avx")]
+unsafe fn test_avx() {
+    fn expected_cmp<F: PartialOrd>(imm: i32, lhs: F, rhs: F, if_t: F, if_f: F) -> F {
+        let res = match imm {
+            _CMP_EQ_OQ => lhs == rhs,
+            _CMP_LT_OS => lhs < rhs,
+            _CMP_LE_OS => lhs <= rhs,
+            _CMP_UNORD_Q => lhs.partial_cmp(&rhs).is_none(),
+            _CMP_NEQ_UQ => lhs != rhs,
+            _CMP_NLT_UQ => !(lhs < rhs),
+            _CMP_NLE_UQ => !(lhs <= rhs),
+            _CMP_ORD_Q => lhs.partial_cmp(&rhs).is_some(),
+            _CMP_EQ_UQ => lhs == rhs || lhs.partial_cmp(&rhs).is_none(),
+            _CMP_NGE_US => !(lhs >= rhs),
+            _CMP_NGT_US => !(lhs > rhs),
+            _CMP_FALSE_OQ => false,
+            _CMP_NEQ_OQ => lhs != rhs && lhs.partial_cmp(&rhs).is_some(),
+            _CMP_GE_OS => lhs >= rhs,
+            _CMP_GT_OS => lhs > rhs,
+            _CMP_TRUE_US => true,
+            _ => unreachable!(),
+        };
+        if res { if_t } else { if_f }
+    }
+    fn expected_cmp_f32(imm: i32, lhs: f32, rhs: f32) -> f32 {
+        expected_cmp(imm, lhs, rhs, f32::from_bits(u32::MAX), 0.0)
+    }
+    fn expected_cmp_f64(imm: i32, lhs: f64, rhs: f64) -> f64 {
+        expected_cmp(imm, lhs, rhs, f64::from_bits(u64::MAX), 0.0)
+    }
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm_cmp_ss<const IMM: i32>() {
+        let values = [
+            (1.0, 1.0),
+            (0.0, 1.0),
+            (1.0, 0.0),
+            (f32::NAN, 0.0),
+            (0.0, f32::NAN),
+            (f32::NAN, f32::NAN),
+        ];
+
+        for (lhs, rhs) in values {
+            let a = _mm_setr_ps(lhs, 2.0, 3.0, 4.0);
+            let b = _mm_setr_ps(rhs, 5.0, 6.0, 7.0);
+            let r: [u32; 4] = transmute(_mm_cmp_ss::<IMM>(a, b));
+            let e: [u32; 4] =
+                transmute(_mm_setr_ps(expected_cmp_f32(IMM, lhs, rhs), 2.0, 3.0, 4.0));
+            assert_eq!(r, e);
+        }
+    }
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm_cmp_ps<const IMM: i32>() {
+        let values = [
+            (1.0, 1.0),
+            (0.0, 1.0),
+            (1.0, 0.0),
+            (f32::NAN, 0.0),
+            (0.0, f32::NAN),
+            (f32::NAN, f32::NAN),
+        ];
+
+        for (lhs, rhs) in values {
+            let a = _mm_set1_ps(lhs);
+            let b = _mm_set1_ps(rhs);
+            let r: [u32; 4] = transmute(_mm_cmp_ps::<IMM>(a, b));
+            let e: [u32; 4] = transmute(_mm_set1_ps(expected_cmp_f32(IMM, lhs, rhs)));
+            assert_eq!(r, e);
+        }
+    }
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm_cmp_sd<const IMM: i32>() {
+        let values = [
+            (1.0, 1.0),
+            (0.0, 1.0),
+            (1.0, 0.0),
+            (f64::NAN, 0.0),
+            (0.0, f64::NAN),
+            (f64::NAN, f64::NAN),
+        ];
+
+        for (lhs, rhs) in values {
+            let a = _mm_setr_pd(lhs, 2.0);
+            let b = _mm_setr_pd(rhs, 3.0);
+            let r: [u64; 2] = transmute(_mm_cmp_sd::<IMM>(a, b));
+            let e: [u64; 2] = transmute(_mm_setr_pd(expected_cmp_f64(IMM, lhs, rhs), 2.0));
+            assert_eq!(r, e);
+        }
+    }
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_mm_cmp_pd<const IMM: i32>() {
+        let values = [
+            (1.0, 1.0),
+            (0.0, 1.0),
+            (1.0, 0.0),
+            (f64::NAN, 0.0),
+            (0.0, f64::NAN),
+            (f64::NAN, f64::NAN),
+        ];
+
+        for (lhs, rhs) in values {
+            let a = _mm_set1_pd(lhs);
+            let b = _mm_set1_pd(rhs);
+            let r: [u64; 2] = transmute(_mm_cmp_pd::<IMM>(a, b));
+            let e: [u64; 2] = transmute(_mm_set1_pd(expected_cmp_f64(IMM, lhs, rhs)));
+            assert_eq!(r, e);
+        }
+    }
+
+    #[target_feature(enable = "avx")]
+    unsafe fn test_cmp<const IMM: i32>() {
+        test_mm_cmp_ss::<IMM>();
+        test_mm_cmp_ps::<IMM>();
+        test_mm_cmp_sd::<IMM>();
+        test_mm_cmp_pd::<IMM>();
+    }
+
+    test_cmp::<_CMP_EQ_OQ>();
+    test_cmp::<_CMP_LT_OS>();
+    test_cmp::<_CMP_LE_OS>();
+    test_cmp::<_CMP_UNORD_Q>();
+    test_cmp::<_CMP_NEQ_UQ>();
+    test_cmp::<_CMP_NLT_UQ>();
+    test_cmp::<_CMP_NLE_UQ>();
+    test_cmp::<_CMP_ORD_Q>();
+    test_cmp::<_CMP_EQ_UQ>();
+    test_cmp::<_CMP_NGE_US>();
+    test_cmp::<_CMP_NGT_US>();
+    test_cmp::<_CMP_FALSE_OQ>();
+    test_cmp::<_CMP_NEQ_OQ>();
+    test_cmp::<_CMP_GE_OS>();
+    test_cmp::<_CMP_GT_OS>();
+    test_cmp::<_CMP_TRUE_US>();
+}
diff --git a/src/tools/miri/tests/pass/intrinsics-x86-avx512.rs b/src/tools/miri/tests/pass/intrinsics-x86-avx512.rs
index c38158dc797..394412a2354 100644
--- a/src/tools/miri/tests/pass/intrinsics-x86-avx512.rs
+++ b/src/tools/miri/tests/pass/intrinsics-x86-avx512.rs
@@ -1,5 +1,5 @@
 // Ignore everything except x86 and x86_64
-// Any additional target are added to CI should be ignored here
+// Any new targets that are added to CI should be ignored here.
 // (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.)
 //@ignore-target-aarch64
 //@ignore-target-arm
diff --git a/src/tools/miri/tests/pass/intrinsics-x86-sse3-ssse3.rs b/src/tools/miri/tests/pass/intrinsics-x86-sse3-ssse3.rs
index 0805d9bc300..7566be4431b 100644
--- a/src/tools/miri/tests/pass/intrinsics-x86-sse3-ssse3.rs
+++ b/src/tools/miri/tests/pass/intrinsics-x86-sse3-ssse3.rs
@@ -1,5 +1,5 @@
 // Ignore everything except x86 and x86_64
-// Any additional target are added to CI should be ignored here
+// Any new targets that are added to CI should be ignored here.
 // (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.)
 //@ignore-target-aarch64
 //@ignore-target-arm
diff --git a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs
index 8c565a2d6e0..13856d29d3f 100644
--- a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs
+++ b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs
@@ -1,5 +1,5 @@
 // Ignore everything except x86 and x86_64
-// Any additional target are added to CI should be ignored here
+// Any new targets that are added to CI should be ignored here.
 // (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.)
 //@ignore-target-aarch64
 //@ignore-target-arm