about summary refs log tree commit diff
diff options
context:
space:
mode:
authorsayantn <sayantan.chakraborty@students.iiserpune.ac.in>2024-07-07 16:02:36 +0530
committerAmanieu d'Antras <amanieu@gmail.com>2024-07-26 12:20:06 +0100
commitc878b773d5e7e21b31d025387aa48bb62be2bb14 (patch)
tree27ba4fd91cf47a036ebca2127eee9a2b71a7a8ed
parenta1ad6bf8be51bfbe656acfc100873284932574e2 (diff)
downloadrust-c878b773d5e7e21b31d025387aa48bb62be2bb14.tar.gz
rust-c878b773d5e7e21b31d025387aa48bb62be2bb14.zip
AVX512FP16 Part 0: Types
-rw-r--r--library/stdarch/crates/core_arch/src/lib.rs3
-rw-r--r--library/stdarch/crates/core_arch/src/simd.rs73
-rw-r--r--library/stdarch/crates/core_arch/src/x86/mod.rs79
-rw-r--r--library/stdarch/crates/core_arch/src/x86/test.rs33
-rw-r--r--library/stdarch/crates/stdarch-verify/src/lib.rs4
-rw-r--r--library/stdarch/crates/stdarch-verify/tests/x86-intel.rs19
6 files changed, 208 insertions, 3 deletions
diff --git a/library/stdarch/crates/core_arch/src/lib.rs b/library/stdarch/crates/core_arch/src/lib.rs
index 19011490746..a7a02783e04 100644
--- a/library/stdarch/crates/core_arch/src/lib.rs
+++ b/library/stdarch/crates/core_arch/src/lib.rs
@@ -34,7 +34,8 @@
     target_feature_11,
     generic_arg_infer,
     asm_experimental_arch,
-    sha512_sm_x86
+    sha512_sm_x86,
+    f16
 )]
 #![cfg_attr(test, feature(test, abi_vectorcall, stdarch_internal))]
 #![deny(clippy::missing_inline_in_public_items)]
diff --git a/library/stdarch/crates/core_arch/src/simd.rs b/library/stdarch/crates/core_arch/src/simd.rs
index 4c637f49f3f..30823341023 100644
--- a/library/stdarch/crates/core_arch/src/simd.rs
+++ b/library/stdarch/crates/core_arch/src/simd.rs
@@ -3,9 +3,10 @@
 #![allow(non_camel_case_types)]
 
 macro_rules! simd_ty {
-    ($id:ident [$ety:ident]: $($elem_name:ident),*) => {
+    ($(#[$stability:meta])? $id:ident [$ety:ident]: $($elem_name:ident),*) => {
         #[repr(simd)]
         #[derive(Copy, Clone, Debug, PartialEq)]
+        $(#[$stability])?
         pub(crate) struct $id { $(pub $elem_name: $ety),* }
 
         #[allow(clippy::use_self)]
@@ -186,9 +187,20 @@ simd_ty!(
 simd_ty!(i32x4[i32]: x0, x1, x2, x3);
 simd_ty!(i64x2[i64]: x0, x1);
 
+simd_ty!(
+    #[unstable(feature = "f16", issue = "116909")]
+    f16x8[f16]:
+    x0,
+    x1,
+    x2,
+    x3,
+    x4,
+    x5,
+    x6,
+    x7
+);
 simd_ty!(f32x4[f32]: x0, x1, x2, x3);
 simd_ty!(f64x2[f64]: x0, x1);
-simd_ty!(f64x4[f64]: x0, x1, x2, x3);
 
 simd_m_ty!(
     m8x16[i8]:
@@ -360,6 +372,26 @@ simd_ty!(
 simd_ty!(i64x4[i64]: x0, x1, x2, x3);
 
 simd_ty!(
+    #[unstable(feature = "f16", issue = "116909")]
+    f16x16[f16]:
+    x0,
+    x1,
+    x2,
+    x3,
+    x4,
+    x5,
+    x6,
+    x7,
+    x8,
+    x9,
+    x10,
+    x11,
+    x12,
+    x13,
+    x14,
+    x15
+);
+simd_ty!(
     f32x8[f32]:
     x0,
     x1,
@@ -370,6 +402,7 @@ simd_ty!(
     x6,
     x7
 );
+simd_ty!(f64x4[f64]: x0, x1, x2, x3);
 
 simd_m_ty!(
     m8x32[i8]:
@@ -689,6 +722,42 @@ simd_ty!(
 );
 
 simd_ty!(
+    #[unstable(feature = "f16", issue = "116909")]
+    f16x32[f16]:
+    x0,
+    x1,
+    x2,
+    x3,
+    x4,
+    x5,
+    x6,
+    x7,
+    x8,
+    x9,
+    x10,
+    x11,
+    x12,
+    x13,
+    x14,
+    x15,
+    x16,
+    x17,
+    x18,
+    x19,
+    x20,
+    x21,
+    x22,
+    x23,
+    x24,
+    x25,
+    x26,
+    x27,
+    x28,
+    x29,
+    x30,
+    x31
+);
+simd_ty!(
     f32x16[f32]:
     x0,
     x1,
diff --git a/library/stdarch/crates/core_arch/src/x86/mod.rs b/library/stdarch/crates/core_arch/src/x86/mod.rs
index 9365fe10a2d..d3d4381cc7f 100644
--- a/library/stdarch/crates/core_arch/src/x86/mod.rs
+++ b/library/stdarch/crates/core_arch/src/x86/mod.rs
@@ -335,6 +335,41 @@ types! {
         u16, u16, u16, u16, u16, u16, u16, u16,
         u16, u16, u16, u16, u16, u16, u16, u16
     );
+
+    /// 128-bit wide set of 8 `f16` types, x86-specific
+    ///
+    /// This type is the same as the `__m128h` type defined by Intel,
+    /// representing a 128-bit SIMD register which internally is consisted of
+    /// 8 packed `f16` instances. its purpose is for f16 related intrinsic
+    /// implementations.
+    #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
+    pub struct __m128h(f16, f16, f16, f16, f16, f16, f16, f16);
+
+    /// 256-bit wide set of 16 `f16` types, x86-specific
+    ///
+    /// This type is the same as the `__m256h` type defined by Intel,
+    /// representing a 256-bit SIMD register which internally is consisted of
+    /// 16 packed `f16` instances. its purpose is for f16 related intrinsic
+    /// implementations.
+    #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
+    pub struct __m256h(
+        f16, f16, f16, f16, f16, f16, f16, f16,
+        f16, f16, f16, f16, f16, f16, f16, f16
+    );
+
+    /// 512-bit wide set of 32 `f16` types, x86-specific
+    ///
+    /// This type is the same as the `__m512h` type defined by Intel,
+    /// representing a 512-bit SIMD register which internally is consisted of
+    /// 32 packed `f16` instances. its purpose is for f16 related intrinsic
+    /// implementations.
+    #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
+    pub struct __m512h(
+        f16, f16, f16, f16, f16, f16, f16, f16,
+        f16, f16, f16, f16, f16, f16, f16, f16,
+        f16, f16, f16, f16, f16, f16, f16, f16,
+        f16, f16, f16, f16, f16, f16, f16, f16
+    );
 }
 
 /// The BFloat16 type used in AVX-512 intrinsics.
@@ -761,6 +796,50 @@ impl m512bhExt for __m512bh {
     }
 }
 
+#[allow(non_camel_case_types)]
+pub(crate) trait m128hExt: Sized {
+    fn as_m128h(self) -> __m128h;
+
+    #[inline]
+    fn as_f16x8(self) -> crate::core_arch::simd::f16x8 {
+        unsafe { transmute(self.as_m128h()) }
+    }
+}
+
+impl m128hExt for __m128h {
+    #[inline]
+    fn as_m128h(self) -> Self {
+        self
+    }
+}
+
+#[allow(non_camel_case_types)]
+pub(crate) trait m256hExt: Sized {
+    fn as_m256h(self) -> __m256h;
+
+    #[inline]
+    fn as_f16x16(self) -> crate::core_arch::simd::f16x16 {
+        unsafe { transmute(self.as_m256h()) }
+    }
+}
+
+impl m256hExt for __m256h {
+    #[inline]
+    fn as_m256h(self) -> Self {
+        self
+    }
+}
+
+#[allow(non_camel_case_types)]
+pub(crate) trait m512hExt: Sized {
+    fn as_m512h(self) -> __m512h;
+
+    #[inline]
+    fn as_f16x32(self) -> crate::core_arch::simd::f16x32 {
+        unsafe { transmute(self.as_m512h()) }
+    }
+}
+
 mod eflags;
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub use self::eflags::*;
diff --git a/library/stdarch/crates/core_arch/src/x86/test.rs b/library/stdarch/crates/core_arch/src/x86/test.rs
index 2c88650af38..ebb67356a4e 100644
--- a/library/stdarch/crates/core_arch/src/x86/test.rs
+++ b/library/stdarch/crates/core_arch/src/x86/test.rs
@@ -36,6 +36,17 @@ pub unsafe fn get_m128(a: __m128, idx: usize) -> f32 {
     transmute::<_, [f32; 4]>(a)[idx]
 }
 
+#[track_caller]
+#[target_feature(enable = "avx512fp16")]
+#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
+pub unsafe fn assert_eq_m128h(a: __m128h, b: __m128h) {
+    // FIXME: use `_mm_cmp_ph_mask::<_CMP_EQ_OQ>` when it's implemented
+    let r = _mm_cmpeq_epi16_mask(transmute(a), transmute(b));
+    if r != 0b1111_1111 {
+        panic!("{:?} != {:?}", a, b);
+    }
+}
+
 // not actually an intrinsic but useful in various tests as we proted from
 // `i64x2::new` which is backwards from `_mm_set_epi64x`
 #[target_feature(enable = "sse2")]
@@ -77,6 +88,17 @@ pub unsafe fn get_m256(a: __m256, idx: usize) -> f32 {
     transmute::<_, [f32; 8]>(a)[idx]
 }
 
+#[track_caller]
+#[target_feature(enable = "avx512fp16")]
+#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
+pub unsafe fn assert_eq_m256h(a: __m256h, b: __m256h) {
+    // FIXME: use `_mm256_cmp_ph_mask::<_CMP_EQ_OQ>` when it's implemented
+    let r = _mm256_cmpeq_epi16_mask(transmute(a), transmute(b));
+    if r != 0b11111111_11111111 {
+        panic!("{:?} != {:?}", a, b);
+    }
+}
+
 #[target_feature(enable = "avx512f")]
 pub unsafe fn get_m512(a: __m512, idx: usize) -> f32 {
     transmute::<_, [f32; 16]>(a)[idx]
@@ -139,3 +161,14 @@ pub unsafe fn assert_eq_m512d(a: __m512d, b: __m512d) {
         panic!("{:?} != {:?}", a, b);
     }
 }
+
+#[track_caller]
+#[target_feature(enable = "avx512fp16")]
+#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
+pub unsafe fn assert_eq_m512h(a: __m512h, b: __m512h) {
+    // FIXME: use `_mm512_cmp_ph_mask::<_CMP_EQ_OQ>` when it's implemented
+    let r = _mm512_cmpeq_epi16_mask(transmute(a), transmute(b));
+    if r != 0b11111111_11111111_11111111_11111111 {
+        panic!("{:?} != {:?}", a, b);
+    }
+}
diff --git a/library/stdarch/crates/stdarch-verify/src/lib.rs b/library/stdarch/crates/stdarch-verify/src/lib.rs
index 106aeabdb01..efb5d50e26a 100644
--- a/library/stdarch/crates/stdarch-verify/src/lib.rs
+++ b/library/stdarch/crates/stdarch-verify/src/lib.rs
@@ -182,14 +182,17 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream {
             "__m128" => quote! { &M128 },
             "__m128bh" => quote! { &M128BH },
             "__m128d" => quote! { &M128D },
+            "__m128h" => quote! { &M128H },
             "__m128i" => quote! { &M128I },
             "__m256" => quote! { &M256 },
             "__m256bh" => quote! { &M256BH },
             "__m256d" => quote! { &M256D },
+            "__m256h" => quote! { &M256H },
             "__m256i" => quote! { &M256I },
             "__m512" => quote! { &M512 },
             "__m512bh" => quote! { &M512BH },
             "__m512d" => quote! { &M512D },
+            "__m512h" => quote! { &M512H },
             "__m512i" => quote! { &M512I },
             "__mmask8" => quote! { &MMASK8 },
             "__mmask16" => quote! { &MMASK16 },
@@ -201,6 +204,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream {
             "_MM_PERM_ENUM" => quote! { &MM_PERM_ENUM },
             "bool" => quote! { &BOOL },
             "bf16" => quote! { &BF16 },
+            "f16" => quote! { &F16 },
             "f32" => quote! { &F32 },
             "f64" => quote! { &F64 },
             "i16" => quote! { &I16 },
diff --git a/library/stdarch/crates/stdarch-verify/tests/x86-intel.rs b/library/stdarch/crates/stdarch-verify/tests/x86-intel.rs
index d035b4edff1..fadaa6a4b13 100644
--- a/library/stdarch/crates/stdarch-verify/tests/x86-intel.rs
+++ b/library/stdarch/crates/stdarch-verify/tests/x86-intel.rs
@@ -24,6 +24,7 @@ struct Function {
 }
 
 static BF16: Type = Type::BFloat16;
+static F16: Type = Type::PrimFloat(16);
 static F32: Type = Type::PrimFloat(32);
 static F64: Type = Type::PrimFloat(64);
 static I8: Type = Type::PrimSigned(8);
@@ -41,14 +42,17 @@ static M128: Type = Type::M128;
 static M128BH: Type = Type::M128BH;
 static M128I: Type = Type::M128I;
 static M128D: Type = Type::M128D;
+static M128H: Type = Type::M128H;
 static M256: Type = Type::M256;
 static M256BH: Type = Type::M256BH;
 static M256I: Type = Type::M256I;
 static M256D: Type = Type::M256D;
+static M256H: Type = Type::M256H;
 static M512: Type = Type::M512;
 static M512BH: Type = Type::M512BH;
 static M512I: Type = Type::M512I;
 static M512D: Type = Type::M512D;
+static M512H: Type = Type::M512H;
 static MMASK8: Type = Type::MMASK8;
 static MMASK16: Type = Type::MMASK16;
 static MMASK32: Type = Type::MMASK32;
@@ -73,14 +77,17 @@ enum Type {
     M128,
     M128BH,
     M128D,
+    M128H,
     M128I,
     M256,
     M256BH,
     M256D,
+    M256H,
     M256I,
     M512,
     M512BH,
     M512D,
+    M512H,
     M512I,
     MMASK8,
     MMASK16,
@@ -221,13 +228,16 @@ fn verify_all_signatures() {
                 "_mm_undefined_ps",
                 "_mm_undefined_pd",
                 "_mm_undefined_si128",
+                "_mm_undefined_ph",
                 "_mm256_undefined_ps",
                 "_mm256_undefined_pd",
                 "_mm256_undefined_si256",
+                "_mm256_undefined_ph",
                 "_mm512_undefined_ps",
                 "_mm512_undefined_pd",
                 "_mm512_undefined_epi32",
                 "_mm512_undefined",
+                "_mm512_undefined_ph",
                 // Has doc-tests instead
                 "_mm256_shuffle_epi32",
                 "_mm256_unpackhi_epi8",
@@ -483,6 +493,9 @@ fn matches(rust: &Function, intel: &Intrinsic) -> Result<(), String> {
             // The XML file names BF16 as "avx512_bf16", while Rust calls
             // it "avx512bf16".
             "avx512_bf16" => String::from("avx512bf16"),
+            // The XML file names FP16 as "avx512_fp16", while Rust calls
+            // it "avx512fp16".
+            "avx512_fp16" => String::from("avx512fp16"),
             // The XML file names AVX-VNNI as "avx_vnni", while Rust calls
             // it "avxvnni"
             "avx_vnni" => String::from("avxvnni"),
@@ -709,6 +722,7 @@ fn equate(
         }
     }
     match (t, &intel[..]) {
+        (&Type::PrimFloat(16), "_Float16") => {}
         (&Type::PrimFloat(32), "float") => {}
         (&Type::PrimFloat(64), "double") => {}
         (&Type::PrimSigned(8), "__int8" | "char") => {}
@@ -728,14 +742,17 @@ fn equate(
         (&Type::M128BH, "__m128bh") => {}
         (&Type::M128I, "__m128i") => {}
         (&Type::M128D, "__m128d") => {}
+        (&Type::M128H, "__m128h") => {}
         (&Type::M256, "__m256") => {}
         (&Type::M256BH, "__m256bh") => {}
         (&Type::M256I, "__m256i") => {}
         (&Type::M256D, "__m256d") => {}
+        (&Type::M256H, "__m256h") => {}
         (&Type::M512, "__m512") => {}
         (&Type::M512BH, "__m512bh") => {}
         (&Type::M512I, "__m512i") => {}
         (&Type::M512D, "__m512d") => {}
+        (&Type::M512H, "__m512h") => {}
         (&Type::MMASK64, "__mmask64") => {}
         (&Type::MMASK32, "__mmask32") => {}
         (&Type::MMASK16, "__mmask16") => {}
@@ -771,6 +788,7 @@ fn equate(
         (&Type::MutPtr(&Type::M512D), "__m512d*") => {}
 
         (&Type::ConstPtr(_), "void const*") => {}
+        (&Type::ConstPtr(&Type::PrimFloat(16)), "_Float16 const*") => {}
         (&Type::ConstPtr(&Type::PrimFloat(32)), "float const*") => {}
         (&Type::ConstPtr(&Type::PrimFloat(64)), "double const*") => {}
         (&Type::ConstPtr(&Type::PrimSigned(8)), "char const*") => {}
@@ -785,6 +803,7 @@ fn equate(
         (&Type::ConstPtr(&Type::M128BH), "__m128bh const*") => {}
         (&Type::ConstPtr(&Type::M128I), "__m128i const*") => {}
         (&Type::ConstPtr(&Type::M128D), "__m128d const*") => {}
+        (&Type::ConstPtr(&Type::M128H), "__m128h const*") => {}
         (&Type::ConstPtr(&Type::M256), "__m256 const*") => {}
         (&Type::ConstPtr(&Type::M256BH), "__m256bh const*") => {}
         (&Type::ConstPtr(&Type::M256I), "__m256i const*") => {}