about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/compiler-builtins/libm/etc/function-definitions.json6
-rw-r--r--library/compiler-builtins/libm/src/math/arch/aarch64.rs126
-rw-r--r--library/compiler-builtins/libm/src/math/arch/mod.rs21
-rw-r--r--library/compiler-builtins/libm/src/math/fma.rs6
-rw-r--r--library/compiler-builtins/libm/src/math/fma_wide.rs6
-rw-r--r--library/compiler-builtins/libm/src/math/rint.rs10
-rw-r--r--library/compiler-builtins/libm/src/math/sqrt.rs1
-rw-r--r--library/compiler-builtins/libm/src/math/sqrtf.rs1
-rw-r--r--library/compiler-builtins/libm/src/math/sqrtf16.rs6
9 files changed, 155 insertions, 28 deletions
diff --git a/library/compiler-builtins/libm/etc/function-definitions.json b/library/compiler-builtins/libm/etc/function-definitions.json
index 64a775ba9f1..bca58402f4e 100644
--- a/library/compiler-builtins/libm/etc/function-definitions.json
+++ b/library/compiler-builtins/libm/etc/function-definitions.json
@@ -342,12 +342,14 @@
     },
     "fma": {
         "sources": [
+            "src/math/arch/aarch64.rs",
             "src/math/fma.rs"
         ],
         "type": "f64"
     },
     "fmaf": {
         "sources": [
+            "src/math/arch/aarch64.rs",
             "src/math/fma_wide.rs"
         ],
         "type": "f32"
@@ -806,6 +808,7 @@
     },
     "rintf16": {
         "sources": [
+            "src/math/arch/aarch64.rs",
             "src/math/rint.rs"
         ],
         "type": "f16"
@@ -928,6 +931,7 @@
     },
     "sqrt": {
         "sources": [
+            "src/math/arch/aarch64.rs",
             "src/math/arch/i686.rs",
             "src/math/arch/wasm32.rs",
             "src/math/generic/sqrt.rs",
@@ -937,6 +941,7 @@
     },
     "sqrtf": {
         "sources": [
+            "src/math/arch/aarch64.rs",
             "src/math/arch/i686.rs",
             "src/math/arch/wasm32.rs",
             "src/math/generic/sqrt.rs",
@@ -953,6 +958,7 @@
     },
     "sqrtf16": {
         "sources": [
+            "src/math/arch/aarch64.rs",
             "src/math/generic/sqrt.rs",
             "src/math/sqrtf16.rs"
         ],
diff --git a/library/compiler-builtins/libm/src/math/arch/aarch64.rs b/library/compiler-builtins/libm/src/math/arch/aarch64.rs
index 374ec11bfec..020bb731cdc 100644
--- a/library/compiler-builtins/libm/src/math/arch/aarch64.rs
+++ b/library/compiler-builtins/libm/src/math/arch/aarch64.rs
@@ -1,33 +1,115 @@
-use core::arch::aarch64::{
-    float32x2_t, float64x1_t, vdup_n_f32, vdup_n_f64, vget_lane_f32, vget_lane_f64, vrndn_f32,
-    vrndn_f64,
-};
+//! Architecture-specific support for aarch64 with neon.
 
-pub fn rint(x: f64) -> f64 {
-    // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
-    let x_vec: float64x1_t = unsafe { vdup_n_f64(x) };
+use core::arch::asm;
 
-    // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
-    let result_vec: float64x1_t = unsafe { vrndn_f64(x_vec) };
+pub fn fma(mut x: f64, y: f64, z: f64) -> f64 {
+    // SAFETY: `fmadd` is available with neon and has no side effects.
+    unsafe {
+        asm!(
+            "fmadd {x:d}, {x:d}, {y:d}, {z:d}",
+            x = inout(vreg) x,
+            y = in(vreg) y,
+            z = in(vreg) z,
+            options(nomem, nostack, pure)
+        );
+    }
+    x
+}
 
-    // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
-    let result: f64 = unsafe { vget_lane_f64::<0>(result_vec) };
+pub fn fmaf(mut x: f32, y: f32, z: f32) -> f32 {
+    // SAFETY: `fmadd` is available with neon and has no side effects.
+    unsafe {
+        asm!(
+            "fmadd {x:s}, {x:s}, {y:s}, {z:s}",
+            x = inout(vreg) x,
+            y = in(vreg) y,
+            z = in(vreg) z,
+            options(nomem, nostack, pure)
+        );
+    }
+    x
+}
 
-    result
+pub fn rint(mut x: f64) -> f64 {
+    // SAFETY: `frintn` is available with neon and has no side effects.
+    //
+    // `frintn` is always round-to-nearest which does not match the C specification, but Rust does
+    // not support rounding modes.
+    unsafe {
+        asm!(
+            "frintn {x:d}, {x:d}",
+            x = inout(vreg) x,
+            options(nomem, nostack, pure)
+        );
+    }
+    x
 }
 
-pub fn rintf(x: f32) -> f32 {
-    // There's a scalar form of this instruction (FRINTN) but core::arch doesn't expose it, so we
-    // have to use the vector form and drop the other lanes afterwards.
+pub fn rintf(mut x: f32) -> f32 {
+    // SAFETY: `frintn` is available with neon and has no side effects.
+    //
+    // `frintn` is always round-to-nearest which does not match the C specification, but Rust does
+    // not support rounding modes.
+    unsafe {
+        asm!(
+            "frintn {x:s}, {x:s}",
+            x = inout(vreg) x,
+            options(nomem, nostack, pure)
+        );
+    }
+    x
+}
 
-    // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
-    let x_vec: float32x2_t = unsafe { vdup_n_f32(x) };
+#[cfg(all(f16_enabled, target_feature = "fp16"))]
+pub fn rintf16(mut x: f16) -> f16 {
+    // SAFETY: `frintn` is available for `f16` with `fp16` (implies `neon`) and has no side effects.
+    //
+    // `frintn` is always round-to-nearest which does not match the C specification, but Rust does
+    // not support rounding modes.
+    unsafe {
+        asm!(
+            "frintn {x:h}, {x:h}",
+            x = inout(vreg) x,
+            options(nomem, nostack, pure)
+        );
+    }
+    x
+}
 
-    // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
-    let result_vec: float32x2_t = unsafe { vrndn_f32(x_vec) };
+pub fn sqrt(mut x: f64) -> f64 {
+    // SAFETY: `fsqrt` is available with neon and has no side effects.
+    unsafe {
+        asm!(
+            "fsqrt {x:d}, {x:d}",
+            x = inout(vreg) x,
+            options(nomem, nostack, pure)
+        );
+    }
+    x
+}
 
-    // SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
-    let result: f32 = unsafe { vget_lane_f32::<0>(result_vec) };
+pub fn sqrtf(mut x: f32) -> f32 {
+    // SAFETY: `fsqrt` is available with neon and has no side effects.
+    unsafe {
+        asm!(
+            "fsqrt {x:s}, {x:s}",
+            x = inout(vreg) x,
+            options(nomem, nostack, pure)
+        );
+    }
+    x
+}
 
-    result
+#[cfg(all(f16_enabled, target_feature = "fp16"))]
+pub fn sqrtf16(mut x: f16) -> f16 {
+    // SAFETY: `fsqrt` is available for `f16` with `fp16` (implies `neon`) and has no
+    // side effects.
+    unsafe {
+        asm!(
+            "fsqrt {x:h}, {x:h}",
+            x = inout(vreg) x,
+            options(nomem, nostack, pure)
+        );
+    }
+    x
 }
diff --git a/library/compiler-builtins/libm/src/math/arch/mod.rs b/library/compiler-builtins/libm/src/math/arch/mod.rs
index 091d7650a5a..d9f2aad66d4 100644
--- a/library/compiler-builtins/libm/src/math/arch/mod.rs
+++ b/library/compiler-builtins/libm/src/math/arch/mod.rs
@@ -18,12 +18,25 @@ cfg_if! {
         mod i686;
         pub use i686::{sqrt, sqrtf};
     } else if #[cfg(all(
-        target_arch = "aarch64", // TODO: also arm64ec?
-        target_feature = "neon",
-        target_endian = "little", // see https://github.com/rust-lang/stdarch/issues/1484
+        any(target_arch = "aarch64", target_arch = "arm64ec"),
+        target_feature = "neon"
     ))] {
         mod aarch64;
-        pub use aarch64::{rint, rintf};
+
+        pub use aarch64::{
+            fma,
+            fmaf,
+            rint,
+            rintf,
+            sqrt,
+            sqrtf,
+        };
+
+        #[cfg(all(f16_enabled, target_feature = "fp16"))]
+        pub use aarch64::{
+            rintf16,
+            sqrtf16,
+        };
     }
 }
 
diff --git a/library/compiler-builtins/libm/src/math/fma.rs b/library/compiler-builtins/libm/src/math/fma.rs
index 049f573cc92..789b0836afb 100644
--- a/library/compiler-builtins/libm/src/math/fma.rs
+++ b/library/compiler-builtins/libm/src/math/fma.rs
@@ -9,6 +9,12 @@ use super::{CastFrom, CastInto, Float, Int, MinInt};
 /// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
 #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
 pub fn fma(x: f64, y: f64, z: f64) -> f64 {
+    select_implementation! {
+        name: fma,
+        use_arch: all(target_arch = "aarch64", target_feature = "neon"),
+        args: x, y, z,
+    }
+
     fma_round(x, y, z, Round::Nearest).val
 }
 
diff --git a/library/compiler-builtins/libm/src/math/fma_wide.rs b/library/compiler-builtins/libm/src/math/fma_wide.rs
index d0cf33baf7a..8e908a14f21 100644
--- a/library/compiler-builtins/libm/src/math/fma_wide.rs
+++ b/library/compiler-builtins/libm/src/math/fma_wide.rs
@@ -17,6 +17,12 @@ pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 {
 /// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
 #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
 pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
+    select_implementation! {
+        name: fmaf,
+        use_arch: all(target_arch = "aarch64", target_feature = "neon"),
+        args: x, y, z,
+    }
+
     fma_wide_round(x, y, z, Round::Nearest).val
 }
 
diff --git a/library/compiler-builtins/libm/src/math/rint.rs b/library/compiler-builtins/libm/src/math/rint.rs
index 8a5cbeab497..e1c32c94355 100644
--- a/library/compiler-builtins/libm/src/math/rint.rs
+++ b/library/compiler-builtins/libm/src/math/rint.rs
@@ -4,6 +4,12 @@ use super::support::Round;
 #[cfg(f16_enabled)]
 #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
 pub fn rintf16(x: f16) -> f16 {
+    select_implementation! {
+        name: rintf16,
+        use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
+        args: x,
+    }
+
     super::generic::rint_round(x, Round::Nearest).val
 }
 
@@ -13,8 +19,8 @@ pub fn rintf(x: f32) -> f32 {
     select_implementation! {
         name: rintf,
         use_arch: any(
+            all(target_arch = "aarch64", target_feature = "neon"),
             all(target_arch = "wasm32", intrinsics_enabled),
-            all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"),
         ),
         args: x,
     }
@@ -28,8 +34,8 @@ pub fn rint(x: f64) -> f64 {
     select_implementation! {
         name: rint,
         use_arch: any(
+            all(target_arch = "aarch64", target_feature = "neon"),
             all(target_arch = "wasm32", intrinsics_enabled),
-            all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"),
         ),
         args: x,
     }
diff --git a/library/compiler-builtins/libm/src/math/sqrt.rs b/library/compiler-builtins/libm/src/math/sqrt.rs
index 0e1d0cd2c1c..2bfc42bcfed 100644
--- a/library/compiler-builtins/libm/src/math/sqrt.rs
+++ b/library/compiler-builtins/libm/src/math/sqrt.rs
@@ -4,6 +4,7 @@ pub fn sqrt(x: f64) -> f64 {
     select_implementation! {
         name: sqrt,
         use_arch: any(
+            all(target_arch = "aarch64", target_feature = "neon"),
             all(target_arch = "wasm32", intrinsics_enabled),
             target_feature = "sse2"
         ),
diff --git a/library/compiler-builtins/libm/src/math/sqrtf.rs b/library/compiler-builtins/libm/src/math/sqrtf.rs
index 2e69a4b6694..c28a705e378 100644
--- a/library/compiler-builtins/libm/src/math/sqrtf.rs
+++ b/library/compiler-builtins/libm/src/math/sqrtf.rs
@@ -4,6 +4,7 @@ pub fn sqrtf(x: f32) -> f32 {
     select_implementation! {
         name: sqrtf,
         use_arch: any(
+            all(target_arch = "aarch64", target_feature = "neon"),
             all(target_arch = "wasm32", intrinsics_enabled),
             target_feature = "sse2"
         ),
diff --git a/library/compiler-builtins/libm/src/math/sqrtf16.rs b/library/compiler-builtins/libm/src/math/sqrtf16.rs
index 549bf902c72..7bedb7f8bbb 100644
--- a/library/compiler-builtins/libm/src/math/sqrtf16.rs
+++ b/library/compiler-builtins/libm/src/math/sqrtf16.rs
@@ -1,5 +1,11 @@
 /// The square root of `x` (f16).
 #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
 pub fn sqrtf16(x: f16) -> f16 {
+    select_implementation! {
+        name: sqrtf16,
+        use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
+        args: x,
+    }
+
     return super::generic::sqrt(x);
 }