about summary refs log tree commit diff
diff options
context:
space:
mode:
authorGuillaume Gomez <guillaume1.gomez@gmail.com>2022-04-26 13:22:27 +0200
committerGitHub <noreply@github.com>2022-04-26 13:22:27 +0200
commitfe49981ea0069f6b89e4511d5de92b7027d9b688 (patch)
tree506d49735cd47ce97bc37e7fbe885cc85c77aa0d
parenteaf8beb3f32b089b426a435794214ee0eaed8fec (diff)
parent5bf5acc324f56fad271a97a93b7976d7d3aa67fb (diff)
downloadrust-fe49981ea0069f6b89e4511d5de92b7027d9b688.tar.gz
rust-fe49981ea0069f6b89e4511d5de92b7027d9b688.zip
Rollup merge of #94703 - kjetilkjeka:nvptx-kernel-args-abi2, r=nagisa
Fix codegen bug in "ptx-kernel" abi related to arg passing

I found a codegen bug in the nvptx abi related to that args are passed as ptrs ([see comment](https://github.com/rust-lang/rust/issues/38788#issuecomment-1048999928)), this is not as specified in the [ptx-interoperability doc](https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/) or how C/C++ does it. It will also almost always fail in practice since device/host uses different memory spaces for most hardware.

This PR fixes the bug and add tests for passing structs to ptx kernels.

I observed that all nvptx assembly tests had been marked as [ignore a long time ago](https://github.com/rust-lang/rust/pull/59752#issuecomment-501713428). I'm not sure if the new one should be marked as ignore, it passed on my computer but it might fail if ptx-linker is missing on the server? I guess this is outside scope for this PR and should be looked at in a different issue/PR.

I only fixed the nvptx64-nvidia-cuda target and not the potential code paths for the non-existing 32bit target. Even though 32bit nvptx is not a supported target there are still some code under the hood supporting codegen for 32 bit ptx. I was advised to create an MCP to find out if this code should be removed or updated.

Perhaps ``@RDambrosio016`` would have interest in taking a quick look at this.
-rw-r--r--compiler/rustc_middle/src/ty/layout.rs16
-rw-r--r--compiler/rustc_middle/src/ty/list.rs4
-rw-r--r--compiler/rustc_target/src/abi/call/mod.rs8
-rw-r--r--compiler/rustc_target/src/abi/call/nvptx64.rs47
-rw-r--r--compiler/rustc_target/src/abi/mod.rs32
-rw-r--r--src/test/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs254
6 files changed, 352 insertions, 9 deletions
diff --git a/compiler/rustc_middle/src/ty/layout.rs b/compiler/rustc_middle/src/ty/layout.rs
index 7cf2984a63f..cd4b23fca39 100644
--- a/compiler/rustc_middle/src/ty/layout.rs
+++ b/compiler/rustc_middle/src/ty/layout.rs
@@ -2592,6 +2592,22 @@ where
 
         pointee_info
     }
+
+    fn is_adt(this: TyAndLayout<'tcx>) -> bool {
+        matches!(this.ty.kind(), ty::Adt(..))
+    }
+
+    fn is_never(this: TyAndLayout<'tcx>) -> bool {
+        this.ty.kind() == &ty::Never
+    }
+
+    fn is_tuple(this: TyAndLayout<'tcx>) -> bool {
+        matches!(this.ty.kind(), ty::Tuple(..))
+    }
+
+    fn is_unit(this: TyAndLayout<'tcx>) -> bool {
+        matches!(this.ty.kind(), ty::Tuple(list) if list.len() == 0)
+    }
 }
 
 impl<'tcx> ty::Instance<'tcx> {
diff --git a/compiler/rustc_middle/src/ty/list.rs b/compiler/rustc_middle/src/ty/list.rs
index adba7d13159..197dc9205b4 100644
--- a/compiler/rustc_middle/src/ty/list.rs
+++ b/compiler/rustc_middle/src/ty/list.rs
@@ -61,6 +61,10 @@ impl<T> List<T> {
         static EMPTY_SLICE: InOrder<usize, MaxAlign> = InOrder(0, MaxAlign);
         unsafe { &*(&EMPTY_SLICE as *const _ as *const List<T>) }
     }
+
+    pub fn len(&self) -> usize {
+        self.len
+    }
 }
 
 impl<T: Copy> List<T> {
diff --git a/compiler/rustc_target/src/abi/call/mod.rs b/compiler/rustc_target/src/abi/call/mod.rs
index ce564d1455b..afce10ff1cb 100644
--- a/compiler/rustc_target/src/abi/call/mod.rs
+++ b/compiler/rustc_target/src/abi/call/mod.rs
@@ -696,7 +696,13 @@ impl<'a, Ty> FnAbi<'a, Ty> {
             "sparc" => sparc::compute_abi_info(cx, self),
             "sparc64" => sparc64::compute_abi_info(cx, self),
             "nvptx" => nvptx::compute_abi_info(self),
-            "nvptx64" => nvptx64::compute_abi_info(self),
+            "nvptx64" => {
+                if cx.target_spec().adjust_abi(abi) == spec::abi::Abi::PtxKernel {
+                    nvptx64::compute_ptx_kernel_abi_info(cx, self)
+                } else {
+                    nvptx64::compute_abi_info(self)
+                }
+            }
             "hexagon" => hexagon::compute_abi_info(self),
             "riscv32" | "riscv64" => riscv::compute_abi_info(cx, self),
             "wasm32" | "wasm64" => {
diff --git a/compiler/rustc_target/src/abi/call/nvptx64.rs b/compiler/rustc_target/src/abi/call/nvptx64.rs
index 16f331b16d5..fc16f1c97a4 100644
--- a/compiler/rustc_target/src/abi/call/nvptx64.rs
+++ b/compiler/rustc_target/src/abi/call/nvptx64.rs
@@ -1,21 +1,35 @@
-// Reference: PTX Writer's Guide to Interoperability
-// https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability
-
-use crate::abi::call::{ArgAbi, FnAbi};
+use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform};
+use crate::abi::{HasDataLayout, TyAbiInterface};
 
 fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
     if ret.layout.is_aggregate() && ret.layout.size.bits() > 64 {
         ret.make_indirect();
-    } else {
-        ret.extend_integer_width_to(64);
     }
 }
 
 fn classify_arg<Ty>(arg: &mut ArgAbi<'_, Ty>) {
     if arg.layout.is_aggregate() && arg.layout.size.bits() > 64 {
         arg.make_indirect();
-    } else {
-        arg.extend_integer_width_to(64);
+    }
+}
+
+fn classify_arg_kernel<'a, Ty, C>(_cx: &C, arg: &mut ArgAbi<'a, Ty>)
+where
+    Ty: TyAbiInterface<'a, C> + Copy,
+    C: HasDataLayout,
+{
+    if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) {
+        let align_bytes = arg.layout.align.abi.bytes();
+
+        let unit = match align_bytes {
+            1 => Reg::i8(),
+            2 => Reg::i16(),
+            4 => Reg::i32(),
+            8 => Reg::i64(),
+            16 => Reg::i128(),
+            _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
+        };
+        arg.cast_to(Uniform { unit, total: Size::from_bytes(2 * align_bytes) });
     }
 }
 
@@ -31,3 +45,20 @@ pub fn compute_abi_info<Ty>(fn_abi: &mut FnAbi<'_, Ty>) {
         classify_arg(arg);
     }
 }
+
+pub fn compute_ptx_kernel_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>)
+where
+    Ty: TyAbiInterface<'a, C> + Copy,
+    C: HasDataLayout,
+{
+    if !fn_abi.ret.layout.is_unit() && !fn_abi.ret.layout.is_never() {
+        panic!("Kernels should not return anything other than () or !");
+    }
+
+    for arg in &mut fn_abi.args {
+        if arg.is_ignore() {
+            continue;
+        }
+        classify_arg_kernel(cx, arg);
+    }
+}
diff --git a/compiler/rustc_target/src/abi/mod.rs b/compiler/rustc_target/src/abi/mod.rs
index 169167f69bf..0e8fd9cc93f 100644
--- a/compiler/rustc_target/src/abi/mod.rs
+++ b/compiler/rustc_target/src/abi/mod.rs
@@ -1355,6 +1355,10 @@ pub trait TyAbiInterface<'a, C>: Sized {
         cx: &C,
         offset: Size,
     ) -> Option<PointeeInfo>;
+    fn is_adt(this: TyAndLayout<'a, Self>) -> bool;
+    fn is_never(this: TyAndLayout<'a, Self>) -> bool;
+    fn is_tuple(this: TyAndLayout<'a, Self>) -> bool;
+    fn is_unit(this: TyAndLayout<'a, Self>) -> bool;
 }
 
 impl<'a, Ty> TyAndLayout<'a, Ty> {
@@ -1396,6 +1400,34 @@ impl<'a, Ty> TyAndLayout<'a, Ty> {
             _ => false,
         }
     }
+
+    pub fn is_adt<C>(self) -> bool
+    where
+        Ty: TyAbiInterface<'a, C>,
+    {
+        Ty::is_adt(self)
+    }
+
+    pub fn is_never<C>(self) -> bool
+    where
+        Ty: TyAbiInterface<'a, C>,
+    {
+        Ty::is_never(self)
+    }
+
+    pub fn is_tuple<C>(self) -> bool
+    where
+        Ty: TyAbiInterface<'a, C>,
+    {
+        Ty::is_tuple(self)
+    }
+
+    pub fn is_unit<C>(self) -> bool
+    where
+        Ty: TyAbiInterface<'a, C>,
+    {
+        Ty::is_unit(self)
+    }
 }
 
 impl<'a, Ty> TyAndLayout<'a, Ty> {
diff --git a/src/test/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs b/src/test/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs
new file mode 100644
index 00000000000..5bf44f949fd
--- /dev/null
+++ b/src/test/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs
@@ -0,0 +1,254 @@
+// assembly-output: ptx-linker
+// compile-flags: --crate-type cdylib -C target-cpu=sm_86
+// only-nvptx64
+// ignore-nvptx64
+
+// The following ABI tests are made with nvcc 11.6 does.
+//
+// The PTX ABI stability is tied to major versions of the PTX ISA
+// These tests assume major version 7
+//
+//
+// The following correspondence between types are assumed:
+// u<N> - uint<N>_t
+// i<N> - int<N>_t
+// [T, N] - std::array<T, N>
+// &T - T const*
+// &mut T - T*
+
+// CHECK: .version 7
+
+#![feature(abi_ptx, lang_items, no_core)]
+#![no_core]
+
+#[lang = "sized"]
+trait Sized {}
+#[lang = "copy"]
+trait Copy {}
+
+#[repr(C)]
+pub struct SingleU8 {
+    f: u8,
+}
+
+#[repr(C)]
+pub struct DoubleU8 {
+    f: u8,
+    g: u8,
+}
+
+#[repr(C)]
+pub struct TripleU8 {
+    f: u8,
+    g: u8,
+    h: u8,
+}
+
+#[repr(C)]
+pub struct TripleU16 {
+    f: u16,
+    g: u16,
+    h: u16,
+}
+#[repr(C)]
+pub struct TripleU32 {
+    f: u32,
+    g: u32,
+    h: u32,
+}
+#[repr(C)]
+pub struct TripleU64 {
+    f: u64,
+    g: u64,
+    h: u64,
+}
+
+#[repr(C)]
+pub struct DoubleFloat {
+    f: f32,
+    g: f32,
+}
+
+#[repr(C)]
+pub struct TripleFloat {
+    f: f32,
+    g: f32,
+    h: f32,
+}
+
+#[repr(C)]
+pub struct TripleDouble {
+    f: f64,
+    g: f64,
+    h: f64,
+}
+
+#[repr(C)]
+pub struct ManyIntegers {
+    f: u8,
+    g: u16,
+    h: u32,
+    i: u64,
+}
+
+#[repr(C)]
+pub struct ManyNumerics {
+    f: u8,
+    g: u16,
+    h: u32,
+    i: u64,
+    j: f32,
+    k: f64,
+}
+
+// CHECK: .visible .entry f_u8_arg(
+// CHECK: .param .u8 f_u8_arg_param_0
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_u8_arg(_a: u8) {}
+
+// CHECK: .visible .entry f_u16_arg(
+// CHECK: .param .u16 f_u16_arg_param_0
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_u16_arg(_a: u16) {}
+
+// CHECK: .visible .entry f_u32_arg(
+// CHECK: .param .u32 f_u32_arg_param_0
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_u32_arg(_a: u32) {}
+
+// CHECK: .visible .entry f_u64_arg(
+// CHECK: .param .u64 f_u64_arg_param_0
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_u64_arg(_a: u64) {}
+
+// CHECK: .visible .entry f_u128_arg(
+// CHECK: .param .align 16 .b8 f_u128_arg_param_0[16]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_u128_arg(_a: u128) {}
+
+// CHECK: .visible .entry f_i8_arg(
+// CHECK: .param .u8 f_i8_arg_param_0
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_i8_arg(_a: i8) {}
+
+// CHECK: .visible .entry f_i16_arg(
+// CHECK: .param .u16 f_i16_arg_param_0
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_i16_arg(_a: i16) {}
+
+// CHECK: .visible .entry f_i32_arg(
+// CHECK: .param .u32 f_i32_arg_param_0
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_i32_arg(_a: i32) {}
+
+// CHECK: .visible .entry f_i64_arg(
+// CHECK: .param .u64 f_i64_arg_param_0
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_i64_arg(_a: i64) {}
+
+// CHECK: .visible .entry f_i128_arg(
+// CHECK: .param .align 16 .b8 f_i128_arg_param_0[16]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_i128_arg(_a: i128) {}
+
+// CHECK: .visible .entry f_f32_arg(
+// CHECK: .param .f32 f_f32_arg_param_0
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_f32_arg(_a: f32) {}
+
+// CHECK: .visible .entry f_f64_arg(
+// CHECK: .param .f64 f_f64_arg_param_0
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_f64_arg(_a: f64) {}
+
+// CHECK: .visible .entry f_single_u8_arg(
+// CHECK: .param .align 1 .b8 f_single_u8_arg_param_0[1]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_single_u8_arg(_a: SingleU8) {}
+
+// CHECK: .visible .entry f_double_u8_arg(
+// CHECK: .param .align 1 .b8 f_double_u8_arg_param_0[2]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_double_u8_arg(_a: DoubleU8) {}
+
+// CHECK: .visible .entry f_triple_u8_arg(
+// CHECK: .param .align 1 .b8 f_triple_u8_arg_param_0[3]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_triple_u8_arg(_a: TripleU8) {}
+
+// CHECK: .visible .entry f_triple_u16_arg(
+// CHECK: .param .align 2 .b8 f_triple_u16_arg_param_0[6]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_triple_u16_arg(_a: TripleU16) {}
+
+// CHECK: .visible .entry f_triple_u32_arg(
+// CHECK: .param .align 4 .b8 f_triple_u32_arg_param_0[12]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_triple_u32_arg(_a: TripleU32) {}
+
+// CHECK: .visible .entry f_triple_u64_arg(
+// CHECK: .param .align 8 .b8 f_triple_u64_arg_param_0[24]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_triple_u64_arg(_a: TripleU64) {}
+
+// CHECK: .visible .entry f_many_integers_arg(
+// CHECK: .param .align 8 .b8 f_many_integers_arg_param_0[16]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_many_integers_arg(_a: ManyIntegers) {}
+
+// CHECK: .visible .entry f_double_float_arg(
+// CHECK: .param .align 4 .b8 f_double_float_arg_param_0[8]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_double_float_arg(_a: DoubleFloat) {}
+
+// CHECK: .visible .entry f_triple_float_arg(
+// CHECK: .param .align 4 .b8 f_triple_float_arg_param_0[12]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_triple_float_arg(_a: TripleFloat) {}
+
+// CHECK: .visible .entry f_triple_double_arg(
+// CHECK: .param .align 8 .b8 f_triple_double_arg_param_0[24]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_triple_double_arg(_a: TripleDouble) {}
+
+// CHECK: .visible .entry f_many_numerics_arg(
+// CHECK: .param .align 8 .b8 f_many_numerics_arg_param_0[32]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_many_numerics_arg(_a: ManyNumerics) {}
+
+// CHECK: .visible .entry f_byte_array_arg(
+// CHECK: .param .align 1 .b8 f_byte_array_arg_param_0[5]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_byte_array_arg(_a: [u8; 5]) {}
+
+// CHECK: .visible .entry f_float_array_arg(
+// CHECK: .param .align 4 .b8 f_float_array_arg_param_0[20]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_float_array_arg(_a: [f32; 5]) {}
+
+// CHECK: .visible .entry f_u128_array_arg(
+// CHECK: .param .align 16 .b8 f_u128_array_arg_param_0[80]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_u128_array_arg(_a: [u128; 5]) {}
+
+// CHECK: .visible .entry f_u32_slice_arg(
+// CHECK: .param .u64 f_u32_slice_arg_param_0
+// CHECK: .param .u64 f_u32_slice_arg_param_1
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_u32_slice_arg(_a: &[u32]) {}
+
+// CHECK: .visible .entry f_tuple_u8_u8_arg(
+// CHECK: .param .align 1 .b8 f_tuple_u8_u8_arg_param_0[2]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_arg(_a: (u8, u8)) {}
+
+// CHECK: .visible .entry f_tuple_u32_u32_arg(
+// CHECK: .param .align 4 .b8 f_tuple_u32_u32_arg_param_0[8]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_tuple_u32_u32_arg(_a: (u32, u32)) {}
+
+
+// CHECK: .visible .entry f_tuple_u8_u8_u32_arg(
+// CHECK: .param .align 4 .b8 f_tuple_u8_u8_u32_arg_param_0[8]
+#[no_mangle]
+pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_u32_arg(_a: (u8, u8, u32)) {}