diff options
| author | Caleb Zulawski <caleb.zulawski@gmail.com> | 2021-01-02 23:49:28 -0500 |
|---|---|---|
| committer | Caleb Zulawski <caleb.zulawski@gmail.com> | 2021-01-03 01:06:54 -0500 |
| commit | 07db2bfe395a2787376c3cf068bbfe9e1e2eb4d0 (patch) | |
| tree | 67e52cf7e497445eadf5852f8a065e4e89af0240 /compiler/rustc_codegen_llvm/src | |
| parent | fd85ca02f6da4531bcec65b8309ecf38e0eb8938 (diff) | |
| download | rust-07db2bfe395a2787376c3cf068bbfe9e1e2eb4d0.tar.gz rust-07db2bfe395a2787376c3cf068bbfe9e1e2eb4d0.zip | |
Implement floating point SIMD intrinsics over all vector widths, and limit SIMD vector lengths.
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/context.rs | 117 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/intrinsic.rs | 133 |
2 files changed, 55 insertions, 195 deletions
diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 8dd40308075..df3ef26d1ad 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -498,25 +498,6 @@ impl CodegenCx<'b, 'tcx> { let t_f32 = self.type_f32(); let t_f64 = self.type_f64(); - macro_rules! vector_types { - ($id_out:ident: $elem_ty:ident, $len:expr) => { - let $id_out = self.type_vector($elem_ty, $len); - }; - ($($id_out:ident: $elem_ty:ident, $len:expr;)*) => { - $(vector_types!($id_out: $elem_ty, $len);)* - } - } - vector_types! { - t_v2f32: t_f32, 2; - t_v4f32: t_f32, 4; - t_v8f32: t_f32, 8; - t_v16f32: t_f32, 16; - - t_v2f64: t_f64, 2; - t_v4f64: t_f64, 4; - t_v8f64: t_f64, 8; - } - ifn!("llvm.wasm.trunc.saturate.unsigned.i32.f32", fn(t_f32) -> t_i32); ifn!("llvm.wasm.trunc.saturate.unsigned.i32.f64", fn(t_f64) -> t_i32); ifn!("llvm.wasm.trunc.saturate.unsigned.i64.f32", fn(t_f32) -> t_i64); @@ -540,124 +521,40 @@ impl CodegenCx<'b, 'tcx> { ifn!("llvm.sideeffect", fn() -> void); ifn!("llvm.powi.f32", fn(t_f32, t_i32) -> t_f32); - ifn!("llvm.powi.v2f32", fn(t_v2f32, t_i32) -> t_v2f32); - ifn!("llvm.powi.v4f32", fn(t_v4f32, t_i32) -> t_v4f32); - ifn!("llvm.powi.v8f32", fn(t_v8f32, t_i32) -> t_v8f32); - ifn!("llvm.powi.v16f32", fn(t_v16f32, t_i32) -> t_v16f32); ifn!("llvm.powi.f64", fn(t_f64, t_i32) -> t_f64); - ifn!("llvm.powi.v2f64", fn(t_v2f64, t_i32) -> t_v2f64); - ifn!("llvm.powi.v4f64", fn(t_v4f64, t_i32) -> t_v4f64); - ifn!("llvm.powi.v8f64", fn(t_v8f64, t_i32) -> t_v8f64); ifn!("llvm.pow.f32", fn(t_f32, t_f32) -> t_f32); - ifn!("llvm.pow.v2f32", fn(t_v2f32, t_v2f32) -> t_v2f32); - ifn!("llvm.pow.v4f32", fn(t_v4f32, t_v4f32) -> t_v4f32); - ifn!("llvm.pow.v8f32", fn(t_v8f32, t_v8f32) -> t_v8f32); - ifn!("llvm.pow.v16f32", fn(t_v16f32, t_v16f32) -> t_v16f32); ifn!("llvm.pow.f64", fn(t_f64, t_f64) -> t_f64); - ifn!("llvm.pow.v2f64", fn(t_v2f64, t_v2f64) -> t_v2f64); - ifn!("llvm.pow.v4f64", fn(t_v4f64, t_v4f64) -> t_v4f64); - ifn!("llvm.pow.v8f64", fn(t_v8f64, t_v8f64) -> t_v8f64); ifn!("llvm.sqrt.f32", fn(t_f32) -> t_f32); - ifn!("llvm.sqrt.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.sqrt.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.sqrt.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.sqrt.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.sqrt.f64", fn(t_f64) -> t_f64); - ifn!("llvm.sqrt.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.sqrt.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.sqrt.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.sin.f32", fn(t_f32) -> t_f32); - ifn!("llvm.sin.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.sin.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.sin.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.sin.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.sin.f64", fn(t_f64) -> t_f64); - ifn!("llvm.sin.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.sin.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.sin.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.cos.f32", fn(t_f32) -> t_f32); - ifn!("llvm.cos.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.cos.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.cos.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.cos.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.cos.f64", fn(t_f64) -> t_f64); - ifn!("llvm.cos.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.cos.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.cos.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.exp.f32", fn(t_f32) -> t_f32); - ifn!("llvm.exp.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.exp.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.exp.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.exp.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.exp.f64", fn(t_f64) -> t_f64); - ifn!("llvm.exp.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.exp.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.exp.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.exp2.f32", fn(t_f32) -> t_f32); - ifn!("llvm.exp2.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.exp2.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.exp2.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.exp2.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.exp2.f64", fn(t_f64) -> t_f64); - ifn!("llvm.exp2.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.exp2.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.exp2.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.log.f32", fn(t_f32) -> t_f32); - ifn!("llvm.log.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.log.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.log.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.log.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.log.f64", fn(t_f64) -> t_f64); - ifn!("llvm.log.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.log.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.log.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.log10.f32", fn(t_f32) -> t_f32); - ifn!("llvm.log10.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.log10.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.log10.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.log10.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.log10.f64", fn(t_f64) -> t_f64); - ifn!("llvm.log10.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.log10.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.log10.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.log2.f32", fn(t_f32) -> t_f32); - ifn!("llvm.log2.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.log2.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.log2.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.log2.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.log2.f64", fn(t_f64) -> t_f64); - ifn!("llvm.log2.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.log2.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.log2.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.fma.f32", fn(t_f32, t_f32, t_f32) -> t_f32); - ifn!("llvm.fma.v2f32", fn(t_v2f32, t_v2f32, t_v2f32) -> t_v2f32); - ifn!("llvm.fma.v4f32", fn(t_v4f32, t_v4f32, t_v4f32) -> t_v4f32); - ifn!("llvm.fma.v8f32", fn(t_v8f32, t_v8f32, t_v8f32) -> t_v8f32); - ifn!("llvm.fma.v16f32", fn(t_v16f32, t_v16f32, t_v16f32) -> t_v16f32); ifn!("llvm.fma.f64", fn(t_f64, t_f64, t_f64) -> t_f64); - ifn!("llvm.fma.v2f64", fn(t_v2f64, t_v2f64, t_v2f64) -> t_v2f64); - ifn!("llvm.fma.v4f64", fn(t_v4f64, t_v4f64, t_v4f64) -> t_v4f64); - ifn!("llvm.fma.v8f64", fn(t_v8f64, t_v8f64, t_v8f64) -> t_v8f64); ifn!("llvm.fabs.f32", fn(t_f32) -> t_f32); - ifn!("llvm.fabs.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.fabs.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.fabs.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.fabs.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.fabs.f64", fn(t_f64) -> t_f64); - ifn!("llvm.fabs.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.fabs.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.fabs.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.minnum.f32", fn(t_f32, t_f32) -> t_f32); ifn!("llvm.minnum.f64", fn(t_f64, t_f64) -> t_f64); @@ -665,24 +562,10 @@ impl CodegenCx<'b, 'tcx> { ifn!("llvm.maxnum.f64", fn(t_f64, t_f64) -> t_f64); ifn!("llvm.floor.f32", fn(t_f32) -> t_f32); - ifn!("llvm.floor.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.floor.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.floor.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.floor.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.floor.f64", fn(t_f64) -> t_f64); - ifn!("llvm.floor.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.floor.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.floor.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.ceil.f32", fn(t_f32) -> t_f32); - ifn!("llvm.ceil.v2f32", fn(t_v2f32) -> t_v2f32); - ifn!("llvm.ceil.v4f32", fn(t_v4f32) -> t_v4f32); - ifn!("llvm.ceil.v8f32", fn(t_v8f32) -> t_v8f32); - ifn!("llvm.ceil.v16f32", fn(t_v16f32) -> t_v16f32); ifn!("llvm.ceil.f64", fn(t_f64) -> t_f64); - ifn!("llvm.ceil.v2f64", fn(t_v2f64) -> t_v2f64); - ifn!("llvm.ceil.v4f64", fn(t_v4f64) -> t_v4f64); - ifn!("llvm.ceil.v8f64", fn(t_v8f64) -> t_v8f64); ifn!("llvm.trunc.f32", fn(t_f32) -> t_f32); ifn!("llvm.trunc.f64", fn(t_f64) -> t_f64); diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index bf0d499e6c4..81728b8def9 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1009,7 +1009,7 @@ fn generic_simd_intrinsic( } fn simd_simple_float_intrinsic( - name: &str, + name: Symbol, in_elem: &::rustc_middle::ty::TyS<'_>, in_ty: &::rustc_middle::ty::TyS<'_>, in_len: u64, @@ -1036,93 +1036,70 @@ fn generic_simd_intrinsic( } } } - let ety = match in_elem.kind() { - ty::Float(f) if f.bit_width() == 32 => { - if in_len < 2 || in_len > 16 { - return_error!( - "unsupported floating-point vector `{}` with length `{}` \ - out-of-range [2, 16]", - in_ty, - in_len - ); - } - "f32" - } - ty::Float(f) if f.bit_width() == 64 => { - if in_len < 2 || in_len > 8 { + + let (elem_ty_str, elem_ty) = if let ty::Float(f) = in_elem.kind() { + let elem_ty = bx.cx.type_float_from_ty(*f); + match f.bit_width() { + 32 => ("f32", elem_ty), + 64 => ("f64", elem_ty), + _ => { return_error!( - "unsupported floating-point vector `{}` with length `{}` \ - out-of-range [2, 8]", - in_ty, - in_len + "unsupported element type `{}` of floating-point vector `{}`", + f.name_str(), + in_ty ); } - "f64" - } - ty::Float(f) => { - return_error!( - "unsupported element type `{}` of floating-point vector `{}`", - f.name_str(), - in_ty - ); - } - _ => { - return_error!("`{}` is not a floating-point type", in_ty); } + } else { + return_error!("`{}` is not a floating-point type", in_ty); }; - let llvm_name = &format!("llvm.{0}.v{1}{2}", name, in_len, ety); - let intrinsic = bx.get_intrinsic(&llvm_name); - let c = - bx.call(intrinsic, &args.iter().map(|arg| arg.immediate()).collect::<Vec<_>>(), None); + let vec_ty = bx.type_vector(elem_ty, in_len); + + let (intr_name, fn_ty) = match name { + sym::simd_fsqrt => ("sqrt", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_fsin => ("sin", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_fcos => ("cos", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_fabs => ("fabs", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_floor => ("floor", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_ceil => ("ceil", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_fexp => ("exp", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_fexp2 => ("exp2", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_flog10 => ("log10", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_flog2 => ("log2", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_flog => ("log", bx.type_func(&[vec_ty], vec_ty)), + sym::simd_fpowi => ("powi", bx.type_func(&[vec_ty, bx.type_i32()], vec_ty)), + sym::simd_fpow => ("pow", bx.type_func(&[vec_ty, vec_ty], vec_ty)), + sym::simd_fma => ("fma", bx.type_func(&[vec_ty, vec_ty, vec_ty], vec_ty)), + _ => return_error!("unrecognized intrinsic `{}`", name), + }; + + let llvm_name = &format!("llvm.{0}.v{1}{2}", intr_name, in_len, elem_ty_str); + let f = bx.declare_cfn(&llvm_name, fn_ty); + llvm::SetUnnamedAddress(f, llvm::UnnamedAddr::No); + let c = bx.call(f, &args.iter().map(|arg| arg.immediate()).collect::<Vec<_>>(), None); unsafe { llvm::LLVMRustSetHasUnsafeAlgebra(c) }; Ok(c) } - match name { - sym::simd_fsqrt => { - return simd_simple_float_intrinsic("sqrt", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_fsin => { - return simd_simple_float_intrinsic("sin", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_fcos => { - return simd_simple_float_intrinsic("cos", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_fabs => { - return simd_simple_float_intrinsic("fabs", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_floor => { - return simd_simple_float_intrinsic("floor", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_ceil => { - return simd_simple_float_intrinsic("ceil", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_fexp => { - return simd_simple_float_intrinsic("exp", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_fexp2 => { - return simd_simple_float_intrinsic("exp2", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_flog10 => { - return simd_simple_float_intrinsic("log10", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_flog2 => { - return simd_simple_float_intrinsic("log2", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_flog => { - return simd_simple_float_intrinsic("log", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_fpowi => { - return simd_simple_float_intrinsic("powi", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_fpow => { - return simd_simple_float_intrinsic("pow", in_elem, in_ty, in_len, bx, span, args); - } - sym::simd_fma => { - return simd_simple_float_intrinsic("fma", in_elem, in_ty, in_len, bx, span, args); - } - _ => { /* fallthrough */ } + if std::matches!( + name, + sym::simd_fsqrt + | sym::simd_fsin + | sym::simd_fcos + | sym::simd_fabs + | sym::simd_floor + | sym::simd_ceil + | sym::simd_fexp + | sym::simd_fexp2 + | sym::simd_flog10 + | sym::simd_flog2 + | sym::simd_flog + | sym::simd_fpowi + | sym::simd_fpow + | sym::simd_fma + ) { + return simd_simple_float_intrinsic(name, in_elem, in_ty, in_len, bx, span, args); } // FIXME: use: |
