diff options
| author | Kjetil Kjeka <kjetilkjeka@gmail.com> | 2022-03-07 15:09:28 +0100 |
|---|---|---|
| committer | Kjetil Kjeka <kjetilkjeka@gmail.com> | 2022-04-19 18:03:36 +0200 |
| commit | 352abbaadeab323c1b5a69d4669e052e9a34fb67 (patch) | |
| tree | 5906f2f7fa8e288b9efea73cee854225054f9dd8 | |
| parent | 297273c45b205820a4c055082c71677197a40b55 (diff) | |
| download | rust-352abbaadeab323c1b5a69d4669e052e9a34fb67.tar.gz rust-352abbaadeab323c1b5a69d4669e052e9a34fb67.zip | |
Fix a bug in the ptx-kernel calling convention where structs was passed indirectly
Structs being passed indirectly is suprpising and have a high chance not to work as the device and host usually do not share memory.
| -rw-r--r-- | compiler/rustc_middle/src/ty/layout.rs | 16 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/ty/list.rs | 4 | ||||
| -rw-r--r-- | compiler/rustc_target/src/abi/call/mod.rs | 8 | ||||
| -rw-r--r-- | compiler/rustc_target/src/abi/call/nvptx64.rs | 47 | ||||
| -rw-r--r-- | compiler/rustc_target/src/abi/mod.rs | 32 |
5 files changed, 98 insertions, 9 deletions
diff --git a/compiler/rustc_middle/src/ty/layout.rs b/compiler/rustc_middle/src/ty/layout.rs index 7495449da4c..fe097f05267 100644 --- a/compiler/rustc_middle/src/ty/layout.rs +++ b/compiler/rustc_middle/src/ty/layout.rs @@ -2568,6 +2568,22 @@ where pointee_info } + + fn is_adt(this: TyAndLayout<'tcx>) -> bool { + matches!(this.ty.kind(), ty::Adt(..)) + } + + fn is_never(this: TyAndLayout<'tcx>) -> bool { + this.ty.kind() == &ty::Never + } + + fn is_tuple(this: TyAndLayout<'tcx>) -> bool { + matches!(this.ty.kind(), ty::Tuple(..)) + } + + fn is_unit(this: TyAndLayout<'tcx>) -> bool { + matches!(this.ty.kind(), ty::Tuple(list) if list.len() == 0) + } } impl<'tcx> ty::Instance<'tcx> { diff --git a/compiler/rustc_middle/src/ty/list.rs b/compiler/rustc_middle/src/ty/list.rs index adba7d13159..197dc9205b4 100644 --- a/compiler/rustc_middle/src/ty/list.rs +++ b/compiler/rustc_middle/src/ty/list.rs @@ -61,6 +61,10 @@ impl<T> List<T> { static EMPTY_SLICE: InOrder<usize, MaxAlign> = InOrder(0, MaxAlign); unsafe { &*(&EMPTY_SLICE as *const _ as *const List<T>) } } + + pub fn len(&self) -> usize { + self.len + } } impl<T: Copy> List<T> { diff --git a/compiler/rustc_target/src/abi/call/mod.rs b/compiler/rustc_target/src/abi/call/mod.rs index 34324a58297..e6e98764737 100644 --- a/compiler/rustc_target/src/abi/call/mod.rs +++ b/compiler/rustc_target/src/abi/call/mod.rs @@ -696,7 +696,13 @@ impl<'a, Ty> FnAbi<'a, Ty> { "sparc" => sparc::compute_abi_info(cx, self), "sparc64" => sparc64::compute_abi_info(cx, self), "nvptx" => nvptx::compute_abi_info(self), - "nvptx64" => nvptx64::compute_abi_info(self), + "nvptx64" => { + if cx.target_spec().adjust_abi(abi) == spec::abi::Abi::PtxKernel { + nvptx64::compute_ptx_kernel_abi_info(cx, self) + } else { + nvptx64::compute_abi_info(self) + } + } "hexagon" => hexagon::compute_abi_info(self), "riscv32" | "riscv64" => riscv::compute_abi_info(cx, self), "wasm32" | "wasm64" => { diff --git a/compiler/rustc_target/src/abi/call/nvptx64.rs b/compiler/rustc_target/src/abi/call/nvptx64.rs index 16f331b16d5..fc16f1c97a4 100644 --- a/compiler/rustc_target/src/abi/call/nvptx64.rs +++ b/compiler/rustc_target/src/abi/call/nvptx64.rs @@ -1,21 +1,35 @@ -// Reference: PTX Writer's Guide to Interoperability -// https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability - -use crate::abi::call::{ArgAbi, FnAbi}; +use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform}; +use crate::abi::{HasDataLayout, TyAbiInterface}; fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) { if ret.layout.is_aggregate() && ret.layout.size.bits() > 64 { ret.make_indirect(); - } else { - ret.extend_integer_width_to(64); } } fn classify_arg<Ty>(arg: &mut ArgAbi<'_, Ty>) { if arg.layout.is_aggregate() && arg.layout.size.bits() > 64 { arg.make_indirect(); - } else { - arg.extend_integer_width_to(64); + } +} + +fn classify_arg_kernel<'a, Ty, C>(_cx: &C, arg: &mut ArgAbi<'a, Ty>) +where + Ty: TyAbiInterface<'a, C> + Copy, + C: HasDataLayout, +{ + if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) { + let align_bytes = arg.layout.align.abi.bytes(); + + let unit = match align_bytes { + 1 => Reg::i8(), + 2 => Reg::i16(), + 4 => Reg::i32(), + 8 => Reg::i64(), + 16 => Reg::i128(), + _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"), + }; + arg.cast_to(Uniform { unit, total: Size::from_bytes(2 * align_bytes) }); } } @@ -31,3 +45,20 @@ pub fn compute_abi_info<Ty>(fn_abi: &mut FnAbi<'_, Ty>) { classify_arg(arg); } } + +pub fn compute_ptx_kernel_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>) +where + Ty: TyAbiInterface<'a, C> + Copy, + C: HasDataLayout, +{ + if !fn_abi.ret.layout.is_unit() && !fn_abi.ret.layout.is_never() { + panic!("Kernels should not return anything other than () or !"); + } + + for arg in &mut fn_abi.args { + if arg.is_ignore() { + continue; + } + classify_arg_kernel(cx, arg); + } +} diff --git a/compiler/rustc_target/src/abi/mod.rs b/compiler/rustc_target/src/abi/mod.rs index 4ef86371298..9e6b740f6f9 100644 --- a/compiler/rustc_target/src/abi/mod.rs +++ b/compiler/rustc_target/src/abi/mod.rs @@ -1250,6 +1250,10 @@ pub trait TyAbiInterface<'a, C>: Sized { cx: &C, offset: Size, ) -> Option<PointeeInfo>; + fn is_adt(this: TyAndLayout<'a, Self>) -> bool; + fn is_never(this: TyAndLayout<'a, Self>) -> bool; + fn is_tuple(this: TyAndLayout<'a, Self>) -> bool; + fn is_unit(this: TyAndLayout<'a, Self>) -> bool; } impl<'a, Ty> TyAndLayout<'a, Ty> { @@ -1291,6 +1295,34 @@ impl<'a, Ty> TyAndLayout<'a, Ty> { _ => false, } } + + pub fn is_adt<C>(self) -> bool + where + Ty: TyAbiInterface<'a, C>, + { + Ty::is_adt(self) + } + + pub fn is_never<C>(self) -> bool + where + Ty: TyAbiInterface<'a, C>, + { + Ty::is_never(self) + } + + pub fn is_tuple<C>(self) -> bool + where + Ty: TyAbiInterface<'a, C>, + { + Ty::is_tuple(self) + } + + pub fn is_unit<C>(self) -> bool + where + Ty: TyAbiInterface<'a, C>, + { + Ty::is_unit(self) + } } impl<'a, Ty> TyAndLayout<'a, Ty> { |
