about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src
diff options
context:
space:
mode:
authorCaleb Zulawski <caleb.zulawski@gmail.com>2021-01-02 23:49:28 -0500
committerCaleb Zulawski <caleb.zulawski@gmail.com>2021-01-03 01:06:54 -0500
commit07db2bfe395a2787376c3cf068bbfe9e1e2eb4d0 (patch)
tree67e52cf7e497445eadf5852f8a065e4e89af0240 /compiler/rustc_codegen_llvm/src
parentfd85ca02f6da4531bcec65b8309ecf38e0eb8938 (diff)
downloadrust-07db2bfe395a2787376c3cf068bbfe9e1e2eb4d0.tar.gz
rust-07db2bfe395a2787376c3cf068bbfe9e1e2eb4d0.zip
Implement floating point SIMD intrinsics over all vector widths, and limit SIMD vector lengths.
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
-rw-r--r--compiler/rustc_codegen_llvm/src/context.rs117
-rw-r--r--compiler/rustc_codegen_llvm/src/intrinsic.rs133
2 files changed, 55 insertions, 195 deletions
diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs
index 8dd40308075..df3ef26d1ad 100644
--- a/compiler/rustc_codegen_llvm/src/context.rs
+++ b/compiler/rustc_codegen_llvm/src/context.rs
@@ -498,25 +498,6 @@ impl CodegenCx<'b, 'tcx> {
         let t_f32 = self.type_f32();
         let t_f64 = self.type_f64();
 
-        macro_rules! vector_types {
-            ($id_out:ident: $elem_ty:ident, $len:expr) => {
-                let $id_out = self.type_vector($elem_ty, $len);
-            };
-            ($($id_out:ident: $elem_ty:ident, $len:expr;)*) => {
-                $(vector_types!($id_out: $elem_ty, $len);)*
-            }
-        }
-        vector_types! {
-            t_v2f32: t_f32, 2;
-            t_v4f32: t_f32, 4;
-            t_v8f32: t_f32, 8;
-            t_v16f32: t_f32, 16;
-
-            t_v2f64: t_f64, 2;
-            t_v4f64: t_f64, 4;
-            t_v8f64: t_f64, 8;
-        }
-
         ifn!("llvm.wasm.trunc.saturate.unsigned.i32.f32", fn(t_f32) -> t_i32);
         ifn!("llvm.wasm.trunc.saturate.unsigned.i32.f64", fn(t_f64) -> t_i32);
         ifn!("llvm.wasm.trunc.saturate.unsigned.i64.f32", fn(t_f32) -> t_i64);
@@ -540,124 +521,40 @@ impl CodegenCx<'b, 'tcx> {
         ifn!("llvm.sideeffect", fn() -> void);
 
         ifn!("llvm.powi.f32", fn(t_f32, t_i32) -> t_f32);
-        ifn!("llvm.powi.v2f32", fn(t_v2f32, t_i32) -> t_v2f32);
-        ifn!("llvm.powi.v4f32", fn(t_v4f32, t_i32) -> t_v4f32);
-        ifn!("llvm.powi.v8f32", fn(t_v8f32, t_i32) -> t_v8f32);
-        ifn!("llvm.powi.v16f32", fn(t_v16f32, t_i32) -> t_v16f32);
         ifn!("llvm.powi.f64", fn(t_f64, t_i32) -> t_f64);
-        ifn!("llvm.powi.v2f64", fn(t_v2f64, t_i32) -> t_v2f64);
-        ifn!("llvm.powi.v4f64", fn(t_v4f64, t_i32) -> t_v4f64);
-        ifn!("llvm.powi.v8f64", fn(t_v8f64, t_i32) -> t_v8f64);
 
         ifn!("llvm.pow.f32", fn(t_f32, t_f32) -> t_f32);
-        ifn!("llvm.pow.v2f32", fn(t_v2f32, t_v2f32) -> t_v2f32);
-        ifn!("llvm.pow.v4f32", fn(t_v4f32, t_v4f32) -> t_v4f32);
-        ifn!("llvm.pow.v8f32", fn(t_v8f32, t_v8f32) -> t_v8f32);
-        ifn!("llvm.pow.v16f32", fn(t_v16f32, t_v16f32) -> t_v16f32);
         ifn!("llvm.pow.f64", fn(t_f64, t_f64) -> t_f64);
-        ifn!("llvm.pow.v2f64", fn(t_v2f64, t_v2f64) -> t_v2f64);
-        ifn!("llvm.pow.v4f64", fn(t_v4f64, t_v4f64) -> t_v4f64);
-        ifn!("llvm.pow.v8f64", fn(t_v8f64, t_v8f64) -> t_v8f64);
 
         ifn!("llvm.sqrt.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.sqrt.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.sqrt.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.sqrt.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.sqrt.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.sqrt.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.sqrt.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.sqrt.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.sqrt.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.sin.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.sin.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.sin.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.sin.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.sin.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.sin.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.sin.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.sin.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.sin.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.cos.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.cos.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.cos.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.cos.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.cos.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.cos.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.cos.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.cos.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.cos.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.exp.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.exp.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.exp.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.exp.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.exp.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.exp.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.exp.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.exp.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.exp.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.exp2.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.exp2.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.exp2.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.exp2.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.exp2.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.exp2.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.exp2.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.exp2.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.exp2.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.log.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.log.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.log.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.log.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.log.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.log.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.log.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.log.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.log.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.log10.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.log10.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.log10.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.log10.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.log10.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.log10.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.log10.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.log10.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.log10.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.log2.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.log2.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.log2.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.log2.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.log2.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.log2.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.log2.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.log2.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.log2.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.fma.f32", fn(t_f32, t_f32, t_f32) -> t_f32);
-        ifn!("llvm.fma.v2f32", fn(t_v2f32, t_v2f32, t_v2f32) -> t_v2f32);
-        ifn!("llvm.fma.v4f32", fn(t_v4f32, t_v4f32, t_v4f32) -> t_v4f32);
-        ifn!("llvm.fma.v8f32", fn(t_v8f32, t_v8f32, t_v8f32) -> t_v8f32);
-        ifn!("llvm.fma.v16f32", fn(t_v16f32, t_v16f32, t_v16f32) -> t_v16f32);
         ifn!("llvm.fma.f64", fn(t_f64, t_f64, t_f64) -> t_f64);
-        ifn!("llvm.fma.v2f64", fn(t_v2f64, t_v2f64, t_v2f64) -> t_v2f64);
-        ifn!("llvm.fma.v4f64", fn(t_v4f64, t_v4f64, t_v4f64) -> t_v4f64);
-        ifn!("llvm.fma.v8f64", fn(t_v8f64, t_v8f64, t_v8f64) -> t_v8f64);
 
         ifn!("llvm.fabs.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.fabs.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.fabs.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.fabs.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.fabs.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.fabs.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.fabs.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.fabs.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.fabs.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.minnum.f32", fn(t_f32, t_f32) -> t_f32);
         ifn!("llvm.minnum.f64", fn(t_f64, t_f64) -> t_f64);
@@ -665,24 +562,10 @@ impl CodegenCx<'b, 'tcx> {
         ifn!("llvm.maxnum.f64", fn(t_f64, t_f64) -> t_f64);
 
         ifn!("llvm.floor.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.floor.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.floor.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.floor.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.floor.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.floor.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.floor.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.floor.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.floor.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.ceil.f32", fn(t_f32) -> t_f32);
-        ifn!("llvm.ceil.v2f32", fn(t_v2f32) -> t_v2f32);
-        ifn!("llvm.ceil.v4f32", fn(t_v4f32) -> t_v4f32);
-        ifn!("llvm.ceil.v8f32", fn(t_v8f32) -> t_v8f32);
-        ifn!("llvm.ceil.v16f32", fn(t_v16f32) -> t_v16f32);
         ifn!("llvm.ceil.f64", fn(t_f64) -> t_f64);
-        ifn!("llvm.ceil.v2f64", fn(t_v2f64) -> t_v2f64);
-        ifn!("llvm.ceil.v4f64", fn(t_v4f64) -> t_v4f64);
-        ifn!("llvm.ceil.v8f64", fn(t_v8f64) -> t_v8f64);
 
         ifn!("llvm.trunc.f32", fn(t_f32) -> t_f32);
         ifn!("llvm.trunc.f64", fn(t_f64) -> t_f64);
diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index bf0d499e6c4..81728b8def9 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -1009,7 +1009,7 @@ fn generic_simd_intrinsic(
     }
 
     fn simd_simple_float_intrinsic(
-        name: &str,
+        name: Symbol,
         in_elem: &::rustc_middle::ty::TyS<'_>,
         in_ty: &::rustc_middle::ty::TyS<'_>,
         in_len: u64,
@@ -1036,93 +1036,70 @@ fn generic_simd_intrinsic(
                 }
             }
         }
-        let ety = match in_elem.kind() {
-            ty::Float(f) if f.bit_width() == 32 => {
-                if in_len < 2 || in_len > 16 {
-                    return_error!(
-                        "unsupported floating-point vector `{}` with length `{}` \
-                         out-of-range [2, 16]",
-                        in_ty,
-                        in_len
-                    );
-                }
-                "f32"
-            }
-            ty::Float(f) if f.bit_width() == 64 => {
-                if in_len < 2 || in_len > 8 {
+
+        let (elem_ty_str, elem_ty) = if let ty::Float(f) = in_elem.kind() {
+            let elem_ty = bx.cx.type_float_from_ty(*f);
+            match f.bit_width() {
+                32 => ("f32", elem_ty),
+                64 => ("f64", elem_ty),
+                _ => {
                     return_error!(
-                        "unsupported floating-point vector `{}` with length `{}` \
-                                   out-of-range [2, 8]",
-                        in_ty,
-                        in_len
+                        "unsupported element type `{}` of floating-point vector `{}`",
+                        f.name_str(),
+                        in_ty
                     );
                 }
-                "f64"
-            }
-            ty::Float(f) => {
-                return_error!(
-                    "unsupported element type `{}` of floating-point vector `{}`",
-                    f.name_str(),
-                    in_ty
-                );
-            }
-            _ => {
-                return_error!("`{}` is not a floating-point type", in_ty);
             }
+        } else {
+            return_error!("`{}` is not a floating-point type", in_ty);
         };
 
-        let llvm_name = &format!("llvm.{0}.v{1}{2}", name, in_len, ety);
-        let intrinsic = bx.get_intrinsic(&llvm_name);
-        let c =
-            bx.call(intrinsic, &args.iter().map(|arg| arg.immediate()).collect::<Vec<_>>(), None);
+        let vec_ty = bx.type_vector(elem_ty, in_len);
+
+        let (intr_name, fn_ty) = match name {
+            sym::simd_fsqrt => ("sqrt", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_fsin => ("sin", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_fcos => ("cos", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_fabs => ("fabs", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_floor => ("floor", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_ceil => ("ceil", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_fexp => ("exp", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_fexp2 => ("exp2", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_flog10 => ("log10", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_flog2 => ("log2", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_flog => ("log", bx.type_func(&[vec_ty], vec_ty)),
+            sym::simd_fpowi => ("powi", bx.type_func(&[vec_ty, bx.type_i32()], vec_ty)),
+            sym::simd_fpow => ("pow", bx.type_func(&[vec_ty, vec_ty], vec_ty)),
+            sym::simd_fma => ("fma", bx.type_func(&[vec_ty, vec_ty, vec_ty], vec_ty)),
+            _ => return_error!("unrecognized intrinsic `{}`", name),
+        };
+
+        let llvm_name = &format!("llvm.{0}.v{1}{2}", intr_name, in_len, elem_ty_str);
+        let f = bx.declare_cfn(&llvm_name, fn_ty);
+        llvm::SetUnnamedAddress(f, llvm::UnnamedAddr::No);
+        let c = bx.call(f, &args.iter().map(|arg| arg.immediate()).collect::<Vec<_>>(), None);
         unsafe { llvm::LLVMRustSetHasUnsafeAlgebra(c) };
         Ok(c)
     }
 
-    match name {
-        sym::simd_fsqrt => {
-            return simd_simple_float_intrinsic("sqrt", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_fsin => {
-            return simd_simple_float_intrinsic("sin", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_fcos => {
-            return simd_simple_float_intrinsic("cos", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_fabs => {
-            return simd_simple_float_intrinsic("fabs", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_floor => {
-            return simd_simple_float_intrinsic("floor", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_ceil => {
-            return simd_simple_float_intrinsic("ceil", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_fexp => {
-            return simd_simple_float_intrinsic("exp", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_fexp2 => {
-            return simd_simple_float_intrinsic("exp2", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_flog10 => {
-            return simd_simple_float_intrinsic("log10", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_flog2 => {
-            return simd_simple_float_intrinsic("log2", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_flog => {
-            return simd_simple_float_intrinsic("log", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_fpowi => {
-            return simd_simple_float_intrinsic("powi", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_fpow => {
-            return simd_simple_float_intrinsic("pow", in_elem, in_ty, in_len, bx, span, args);
-        }
-        sym::simd_fma => {
-            return simd_simple_float_intrinsic("fma", in_elem, in_ty, in_len, bx, span, args);
-        }
-        _ => { /* fallthrough */ }
+    if std::matches!(
+        name,
+        sym::simd_fsqrt
+            | sym::simd_fsin
+            | sym::simd_fcos
+            | sym::simd_fabs
+            | sym::simd_floor
+            | sym::simd_ceil
+            | sym::simd_fexp
+            | sym::simd_fexp2
+            | sym::simd_flog10
+            | sym::simd_flog2
+            | sym::simd_flog
+            | sym::simd_fpowi
+            | sym::simd_fpow
+            | sym::simd_fma
+    ) {
+        return simd_simple_float_intrinsic(name, in_elem, in_ty, in_len, bx, span, args);
     }
 
     // FIXME: use: