about summary refs log tree commit diff
diff options
context:
space:
mode:
authorKjetil Kjeka <kjetilkjeka@gmail.com>2022-03-07 15:09:28 +0100
committerKjetil Kjeka <kjetilkjeka@gmail.com>2022-04-19 18:03:36 +0200
commit352abbaadeab323c1b5a69d4669e052e9a34fb67 (patch)
tree5906f2f7fa8e288b9efea73cee854225054f9dd8
parent297273c45b205820a4c055082c71677197a40b55 (diff)
downloadrust-352abbaadeab323c1b5a69d4669e052e9a34fb67.tar.gz
rust-352abbaadeab323c1b5a69d4669e052e9a34fb67.zip
Fix a bug in the ptx-kernel calling convention where structs was passed indirectly
Structs being passed indirectly is suprpising and have a high chance not to work as the device and host usually do not share memory.
-rw-r--r--compiler/rustc_middle/src/ty/layout.rs16
-rw-r--r--compiler/rustc_middle/src/ty/list.rs4
-rw-r--r--compiler/rustc_target/src/abi/call/mod.rs8
-rw-r--r--compiler/rustc_target/src/abi/call/nvptx64.rs47
-rw-r--r--compiler/rustc_target/src/abi/mod.rs32
5 files changed, 98 insertions, 9 deletions
diff --git a/compiler/rustc_middle/src/ty/layout.rs b/compiler/rustc_middle/src/ty/layout.rs
index 7495449da4c..fe097f05267 100644
--- a/compiler/rustc_middle/src/ty/layout.rs
+++ b/compiler/rustc_middle/src/ty/layout.rs
@@ -2568,6 +2568,22 @@ where
 
         pointee_info
     }
+
+    fn is_adt(this: TyAndLayout<'tcx>) -> bool {
+        matches!(this.ty.kind(), ty::Adt(..))
+    }
+
+    fn is_never(this: TyAndLayout<'tcx>) -> bool {
+        this.ty.kind() == &ty::Never
+    }
+
+    fn is_tuple(this: TyAndLayout<'tcx>) -> bool {
+        matches!(this.ty.kind(), ty::Tuple(..))
+    }
+
+    fn is_unit(this: TyAndLayout<'tcx>) -> bool {
+        matches!(this.ty.kind(), ty::Tuple(list) if list.len() == 0)
+    }
 }
 
 impl<'tcx> ty::Instance<'tcx> {
diff --git a/compiler/rustc_middle/src/ty/list.rs b/compiler/rustc_middle/src/ty/list.rs
index adba7d13159..197dc9205b4 100644
--- a/compiler/rustc_middle/src/ty/list.rs
+++ b/compiler/rustc_middle/src/ty/list.rs
@@ -61,6 +61,10 @@ impl<T> List<T> {
         static EMPTY_SLICE: InOrder<usize, MaxAlign> = InOrder(0, MaxAlign);
         unsafe { &*(&EMPTY_SLICE as *const _ as *const List<T>) }
     }
+
+    pub fn len(&self) -> usize {
+        self.len
+    }
 }
 
 impl<T: Copy> List<T> {
diff --git a/compiler/rustc_target/src/abi/call/mod.rs b/compiler/rustc_target/src/abi/call/mod.rs
index 34324a58297..e6e98764737 100644
--- a/compiler/rustc_target/src/abi/call/mod.rs
+++ b/compiler/rustc_target/src/abi/call/mod.rs
@@ -696,7 +696,13 @@ impl<'a, Ty> FnAbi<'a, Ty> {
             "sparc" => sparc::compute_abi_info(cx, self),
             "sparc64" => sparc64::compute_abi_info(cx, self),
             "nvptx" => nvptx::compute_abi_info(self),
-            "nvptx64" => nvptx64::compute_abi_info(self),
+            "nvptx64" => {
+                if cx.target_spec().adjust_abi(abi) == spec::abi::Abi::PtxKernel {
+                    nvptx64::compute_ptx_kernel_abi_info(cx, self)
+                } else {
+                    nvptx64::compute_abi_info(self)
+                }
+            }
             "hexagon" => hexagon::compute_abi_info(self),
             "riscv32" | "riscv64" => riscv::compute_abi_info(cx, self),
             "wasm32" | "wasm64" => {
diff --git a/compiler/rustc_target/src/abi/call/nvptx64.rs b/compiler/rustc_target/src/abi/call/nvptx64.rs
index 16f331b16d5..fc16f1c97a4 100644
--- a/compiler/rustc_target/src/abi/call/nvptx64.rs
+++ b/compiler/rustc_target/src/abi/call/nvptx64.rs
@@ -1,21 +1,35 @@
-// Reference: PTX Writer's Guide to Interoperability
-// https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability
-
-use crate::abi::call::{ArgAbi, FnAbi};
+use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform};
+use crate::abi::{HasDataLayout, TyAbiInterface};
 
 fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
     if ret.layout.is_aggregate() && ret.layout.size.bits() > 64 {
         ret.make_indirect();
-    } else {
-        ret.extend_integer_width_to(64);
     }
 }
 
 fn classify_arg<Ty>(arg: &mut ArgAbi<'_, Ty>) {
     if arg.layout.is_aggregate() && arg.layout.size.bits() > 64 {
         arg.make_indirect();
-    } else {
-        arg.extend_integer_width_to(64);
+    }
+}
+
+fn classify_arg_kernel<'a, Ty, C>(_cx: &C, arg: &mut ArgAbi<'a, Ty>)
+where
+    Ty: TyAbiInterface<'a, C> + Copy,
+    C: HasDataLayout,
+{
+    if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) {
+        let align_bytes = arg.layout.align.abi.bytes();
+
+        let unit = match align_bytes {
+            1 => Reg::i8(),
+            2 => Reg::i16(),
+            4 => Reg::i32(),
+            8 => Reg::i64(),
+            16 => Reg::i128(),
+            _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
+        };
+        arg.cast_to(Uniform { unit, total: Size::from_bytes(2 * align_bytes) });
     }
 }
 
@@ -31,3 +45,20 @@ pub fn compute_abi_info<Ty>(fn_abi: &mut FnAbi<'_, Ty>) {
         classify_arg(arg);
     }
 }
+
+pub fn compute_ptx_kernel_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>)
+where
+    Ty: TyAbiInterface<'a, C> + Copy,
+    C: HasDataLayout,
+{
+    if !fn_abi.ret.layout.is_unit() && !fn_abi.ret.layout.is_never() {
+        panic!("Kernels should not return anything other than () or !");
+    }
+
+    for arg in &mut fn_abi.args {
+        if arg.is_ignore() {
+            continue;
+        }
+        classify_arg_kernel(cx, arg);
+    }
+}
diff --git a/compiler/rustc_target/src/abi/mod.rs b/compiler/rustc_target/src/abi/mod.rs
index 4ef86371298..9e6b740f6f9 100644
--- a/compiler/rustc_target/src/abi/mod.rs
+++ b/compiler/rustc_target/src/abi/mod.rs
@@ -1250,6 +1250,10 @@ pub trait TyAbiInterface<'a, C>: Sized {
         cx: &C,
         offset: Size,
     ) -> Option<PointeeInfo>;
+    fn is_adt(this: TyAndLayout<'a, Self>) -> bool;
+    fn is_never(this: TyAndLayout<'a, Self>) -> bool;
+    fn is_tuple(this: TyAndLayout<'a, Self>) -> bool;
+    fn is_unit(this: TyAndLayout<'a, Self>) -> bool;
 }
 
 impl<'a, Ty> TyAndLayout<'a, Ty> {
@@ -1291,6 +1295,34 @@ impl<'a, Ty> TyAndLayout<'a, Ty> {
             _ => false,
         }
     }
+
+    pub fn is_adt<C>(self) -> bool
+    where
+        Ty: TyAbiInterface<'a, C>,
+    {
+        Ty::is_adt(self)
+    }
+
+    pub fn is_never<C>(self) -> bool
+    where
+        Ty: TyAbiInterface<'a, C>,
+    {
+        Ty::is_never(self)
+    }
+
+    pub fn is_tuple<C>(self) -> bool
+    where
+        Ty: TyAbiInterface<'a, C>,
+    {
+        Ty::is_tuple(self)
+    }
+
+    pub fn is_unit<C>(self) -> bool
+    where
+        Ty: TyAbiInterface<'a, C>,
+    {
+        Ty::is_unit(self)
+    }
 }
 
 impl<'a, Ty> TyAndLayout<'a, Ty> {