diff options
| author | 许杰友 Jieyou Xu (Joe) <39484203+jieyouxu@users.noreply.github.com> | 2025-07-22 00:54:24 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-22 00:54:24 +0800 |
| commit | 5e3eb2512591df0cef52404f0ea4202f58935a54 (patch) | |
| tree | 69324c15901f5c3341342f35ba348de15b24bea5 | |
| parent | 3f9f20f71dd945fe7d044e274094a53c90788269 (diff) | |
| parent | c068599173f8670caeeda252f115cc28445f7df0 (diff) | |
| download | rust-5e3eb2512591df0cef52404f0ea4202f58935a54.tar.gz rust-5e3eb2512591df0cef52404f0ea4202f58935a54.zip | |
Rollup merge of #142097 - ZuseZ4:offload-host1, r=oli-obk
gpu offload host code generation r? ghost This will generate most of the host side code to use llvm's offload feature. The first PR will only handle automatic mem-transfers to and from the device. So if a user calls a kernel, we will copy inputs back and forth, but we won't do the actual kernel launch. Before merging, we will use LLVM's Info infrastructure to verify that the memcopies match what openmp offloa generates in C++. `LIBOMPTARGET_INFO=-1 ./my_rust_binary` should print that a memcpy to and later from the device is happening. A follow-up PR will generate the actual device-side kernel which will then do computations on the GPU. A third PR will implement manual host2device and device2host functionality, but the goal is to minimize cases where a user has to overwrite our default handling due to performance issues. I'm trying to get a full MVP out first, so this just recognizes GPU functions based on magic names. The final frontend will obviously move this over to use proper macros, like I'm already doing it for the autodiff work. This work will also be compatible with std::autodiff, so one can differentiate GPU kernels. Tracking: - https://github.com/rust-lang/rust/issues/131513
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/lto.rs | 7 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder.rs | 69 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs | 439 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/common.rs | 9 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/context.rs | 18 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/declare.rs | 6 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/lib.rs | 22 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 10 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 8 | ||||
| -rw-r--r-- | compiler/rustc_codegen_ssa/src/back/write.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_interface/src/tests.rs | 9 | ||||
| -rw-r--r-- | compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp | 37 | ||||
| -rw-r--r-- | compiler/rustc_session/src/config.rs | 19 | ||||
| -rw-r--r-- | compiler/rustc_session/src/options.rs | 27 | ||||
| -rw-r--r-- | src/doc/unstable-book/src/compiler-flags/offload.md | 8 | ||||
| m--------- | src/tools/enzyme | 0 | ||||
| -rw-r--r-- | tests/codegen/gpu_offload/gpu_host.rs | 80 |
17 files changed, 754 insertions, 16 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 655e1c95373..84302009da9 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -654,6 +654,7 @@ pub(crate) fn run_pass_manager( // We then run the llvm_optimize function a second time, to optimize the code which we generated // in the enzyme differentiation pass. let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable); + let enable_gpu = config.offload.contains(&config::Offload::Enable); let stage = if thin { write::AutodiffStage::PreAD } else { @@ -668,6 +669,12 @@ pub(crate) fn run_pass_manager( write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?; } + if enable_gpu && !thin { + let cx = + SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size); + crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx); + } + if cfg!(llvm_enzyme) && enable_ad && !thin { let cx = SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size); diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 514923ad6f3..0ade9edb0d2 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -3,6 +3,7 @@ use std::ops::Deref; use std::{iter, ptr}; pub(crate) mod autodiff; +pub(crate) mod gpu_offload; use libc::{c_char, c_uint, size_t}; use rustc_abi as abi; @@ -117,6 +118,74 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> { } bx } + + // The generic builder has less functionality and thus (unlike the other alloca) we can not + // easily jump to the beginning of the function to place our allocas there. We trust the user + // to manually do that. FIXME(offload): improve the genericCx and add more llvm wrappers to + // handle this. + pub(crate) fn direct_alloca(&mut self, ty: &'ll Type, align: Align, name: &str) -> &'ll Value { + let val = unsafe { + let alloca = llvm::LLVMBuildAlloca(self.llbuilder, ty, UNNAMED); + llvm::LLVMSetAlignment(alloca, align.bytes() as c_uint); + // Cast to default addrspace if necessary + llvm::LLVMBuildPointerCast(self.llbuilder, alloca, self.cx.type_ptr(), UNNAMED) + }; + if name != "" { + let name = std::ffi::CString::new(name).unwrap(); + llvm::set_value_name(val, &name.as_bytes()); + } + val + } + + pub(crate) fn inbounds_gep( + &mut self, + ty: &'ll Type, + ptr: &'ll Value, + indices: &[&'ll Value], + ) -> &'ll Value { + unsafe { + llvm::LLVMBuildGEPWithNoWrapFlags( + self.llbuilder, + ty, + ptr, + indices.as_ptr(), + indices.len() as c_uint, + UNNAMED, + GEPNoWrapFlags::InBounds, + ) + } + } + + pub(crate) fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align) -> &'ll Value { + debug!("Store {:?} -> {:?}", val, ptr); + assert_eq!(self.cx.type_kind(self.cx.val_ty(ptr)), TypeKind::Pointer); + unsafe { + let store = llvm::LLVMBuildStore(self.llbuilder, val, ptr); + llvm::LLVMSetAlignment(store, align.bytes() as c_uint); + store + } + } + + pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value { + unsafe { + let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED); + llvm::LLVMSetAlignment(load, align.bytes() as c_uint); + load + } + } + + fn memset(&mut self, ptr: &'ll Value, fill_byte: &'ll Value, size: &'ll Value, align: Align) { + unsafe { + llvm::LLVMRustBuildMemSet( + self.llbuilder, + ptr, + align.bytes() as c_uint, + fill_byte, + size, + false, + ); + } + } } /// Empty string, to be used where LLVM expects an instruction name, indicating diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs new file mode 100644 index 00000000000..1280ab1442a --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -0,0 +1,439 @@ +use std::ffi::CString; + +use llvm::Linkage::*; +use rustc_abi::Align; +use rustc_codegen_ssa::back::write::CodegenContext; +use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; + +use crate::builder::SBuilder; +use crate::common::AsCCharPtr; +use crate::llvm::AttributePlace::Function; +use crate::llvm::{self, Linkage, Type, Value}; +use crate::{LlvmCodegenBackend, SimpleCx, attributes}; + +pub(crate) fn handle_gpu_code<'ll>( + _cgcx: &CodegenContext<LlvmCodegenBackend>, + cx: &'ll SimpleCx<'_>, +) { + // The offload memory transfer type for each kernel + let mut o_types = vec![]; + let mut kernels = vec![]; + let offload_entry_ty = add_tgt_offload_entry(&cx); + for num in 0..9 { + let kernel = cx.get_function(&format!("kernel_{num}")); + if let Some(kernel) = kernel { + o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num)); + kernels.push(kernel); + } + } + + gen_call_handling(&cx, &kernels, &o_types); +} + +// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper: +// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 +// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 +fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { + // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 + let unknown_txt = ";unknown;unknown;0;0;;"; + let c_entry_name = CString::new(unknown_txt).unwrap(); + let c_val = c_entry_name.as_bytes_with_nul(); + let initializer = crate::common::bytes_in_context(cx.llcx, c_val); + let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage); + llvm::set_alignment(at_zero, Align::ONE); + + // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 + let struct_ident_ty = cx.type_named_struct("struct.ident_t"); + let struct_elems = vec![ + cx.get_const_i32(0), + cx.get_const_i32(2), + cx.get_const_i32(0), + cx.get_const_i32(22), + at_zero, + ]; + let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect(); + let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems); + cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false); + let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage); + llvm::set_alignment(at_one, Align::EIGHT); + at_one +} + +pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { + let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry"); + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let ti16 = cx.type_i16(); + // For each kernel to run on the gpu, we will later generate one entry of this type. + // copied from LLVM + // typedef struct { + // uint64_t Reserved; + // uint16_t Version; + // uint16_t Kind; + // uint32_t Flags; Flags associated with the entry (see Target Region Entry Flags) + // void *Address; Address of global symbol within device image (function or global) + // char *SymbolName; + // uint64_t Size; Size of the entry info (0 if it is a function) + // uint64_t Data; + // void *AuxAddr; + // } __tgt_offload_entry; + let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]; + cx.set_struct_body(offload_entry_ty, &entry_elements, false); + offload_entry_ty +} + +fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) { + let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments"); + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let tarr = cx.type_array(ti32, 3); + + // Taken from the LLVM APITypes.h declaration: + //struct KernelArgsTy { + // uint32_t Version = 0; // Version of this struct for ABI compatibility. + // uint32_t NumArgs = 0; // Number of arguments in each input pointer. + // void **ArgBasePtrs = + // nullptr; // Base pointer of each argument (e.g. a struct). + // void **ArgPtrs = nullptr; // Pointer to the argument data. + // int64_t *ArgSizes = nullptr; // Size of the argument data in bytes. + // int64_t *ArgTypes = nullptr; // Type of the data (e.g. to / from). + // void **ArgNames = nullptr; // Name of the data for debugging, possibly null. + // void **ArgMappers = nullptr; // User-defined mappers, possibly null. + // uint64_t Tripcount = + // 0; // Tripcount for the teams / distribute loop, 0 otherwise. + // struct { + // uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause. + // uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA. + // uint64_t Unused : 62; + // } Flags = {0, 0, 0}; + // // The number of teams (for x,y,z dimension). + // uint32_t NumTeams[3] = {0, 0, 0}; + // // The number of threads (for x,y,z dimension). + // uint32_t ThreadLimit[3] = {0, 0, 0}; + // uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested. + //}; + let kernel_elements = + vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32]; + + cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false); + // For now we don't handle kernels, so for now we just add a global dummy + // to make sure that the __tgt_offload_entry is defined and handled correctly. + cx.declare_global("my_struct_global2", kernel_arguments_ty); +} + +fn gen_tgt_data_mappers<'ll>( + cx: &'ll SimpleCx<'_>, +) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) { + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + + let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr]; + let mapper_fn_ty = cx.type_func(&args, cx.type_void()); + let mapper_begin = "__tgt_target_data_begin_mapper"; + let mapper_update = "__tgt_target_data_update_mapper"; + let mapper_end = "__tgt_target_data_end_mapper"; + let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty); + let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty); + let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty); + + let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx); + attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]); + attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]); + attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]); + + (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty) +} + +fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value { + let ti64 = cx.type_i64(); + let mut size_val = Vec::with_capacity(vals.len()); + for &val in vals { + size_val.push(cx.get_const_i64(val)); + } + let initializer = cx.const_array(ti64, &size_val); + add_unnamed_global(cx, name, initializer, PrivateLinkage) +} + +pub(crate) fn add_unnamed_global<'ll>( + cx: &SimpleCx<'ll>, + name: &str, + initializer: &'ll llvm::Value, + l: Linkage, +) -> &'ll llvm::Value { + let llglobal = add_global(cx, name, initializer, l); + llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global); + llglobal +} + +pub(crate) fn add_global<'ll>( + cx: &SimpleCx<'ll>, + name: &str, + initializer: &'ll llvm::Value, + l: Linkage, +) -> &'ll llvm::Value { + let c_name = CString::new(name).unwrap(); + let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name); + llvm::set_global_constant(llglobal, true); + llvm::set_linkage(llglobal, l); + llvm::set_initializer(llglobal, initializer); + llglobal +} + +fn gen_define_handling<'ll>( + cx: &'ll SimpleCx<'_>, + kernel: &'ll llvm::Value, + offload_entry_ty: &'ll llvm::Type, + num: i64, +) -> &'ll llvm::Value { + let types = cx.func_params_types(cx.get_type_of_global(kernel)); + // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or + // reference) types. + let num_ptr_types = types + .iter() + .map(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer)) + .count(); + + // We do not know their size anymore at this level, so hardcode a placeholder. + // A follow-up pr will track these from the frontend, where we still have Rust types. + // Then, we will be able to figure out that e.g. `&[f32;256]` will result in 4*256 bytes. + // I decided that 1024 bytes is a great placeholder value for now. + add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]); + // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2), + // or both to and from the gpu (=3). Other values shouldn't affect us for now. + // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten + // will be 2. For now, everything is 3, until we have our frontend set up. + let o_types = + add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]); + // Next: For each function, generate these three entries. A weak constant, + // the llvm.rodata entry name, and the omp_offloading_entries value + + let name = format!(".kernel_{num}.region_id"); + let initializer = cx.get_const_i8(0); + let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage); + + let c_entry_name = CString::new(format!("kernel_{num}")).unwrap(); + let c_val = c_entry_name.as_bytes_with_nul(); + let offload_entry_name = format!(".offloading.entry_name.{num}"); + + let initializer = crate::common::bytes_in_context(cx.llcx, c_val); + let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage); + llvm::set_alignment(llglobal, Align::ONE); + llvm::set_section(llglobal, c".llvm.rodata.offloading"); + + // Not actively used yet, for calling real kernels + let name = format!(".offloading.entry.kernel_{num}"); + + // See the __tgt_offload_entry documentation above. + let reserved = cx.get_const_i64(0); + let version = cx.get_const_i16(1); + let kind = cx.get_const_i16(1); + let flags = cx.get_const_i32(0); + let size = cx.get_const_i64(0); + let data = cx.get_const_i64(0); + let aux_addr = cx.const_null(cx.type_ptr()); + let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]; + + let initializer = crate::common::named_struct(offload_entry_ty, &elems); + let c_name = CString::new(name).unwrap(); + let llglobal = llvm::add_global(cx.llmod, offload_entry_ty, &c_name); + llvm::set_global_constant(llglobal, true); + llvm::set_linkage(llglobal, WeakAnyLinkage); + llvm::set_initializer(llglobal, initializer); + llvm::set_alignment(llglobal, Align::ONE); + let c_section_name = CString::new(".omp_offloading_entries").unwrap(); + llvm::set_section(llglobal, &c_section_name); + o_types +} + +fn declare_offload_fn<'ll>( + cx: &'ll SimpleCx<'_>, + name: &str, + ty: &'ll llvm::Type, +) -> &'ll llvm::Value { + crate::declare::declare_simple_fn( + cx, + name, + llvm::CallConv::CCallConv, + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + ty, + ) +} + +// For each kernel *call*, we now use some of our previous declared globals to move data to and from +// the gpu. We don't have a proper frontend yet, so we assume that every call to a kernel function +// from main is intended to run on the GPU. For now, we only handle the data transfer part of it. +// If two consecutive kernels use the same memory, we still move it to the host and back to the gpu. +// Since in our frontend users (by default) don't have to specify data transfer, this is something +// we should optimize in the future! We also assume that everything should be copied back and forth, +// but sometimes we can directly zero-allocate on the device and only move back, or if something is +// immutable, we might only copy it to the device, but not back. +// +// Current steps: +// 0. Alloca some variables for the following steps +// 1. set insert point before kernel call. +// 2. generate all the GEPS and stores, to be used in 3) +// 3. generate __tgt_target_data_begin calls to move data to the GPU +// +// unchanged: keep kernel call. Later move the kernel to the GPU +// +// 4. set insert point after kernel call. +// 5. generate all the GEPS and stores, to be used in 6) +// 6. generate __tgt_target_data_end calls to move data from the GPU +fn gen_call_handling<'ll>( + cx: &'ll SimpleCx<'_>, + _kernels: &[&'ll llvm::Value], + o_types: &[&'ll llvm::Value], +) { + // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } + let tptr = cx.type_ptr(); + let ti32 = cx.type_i32(); + let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr]; + let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc"); + cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false); + + gen_tgt_kernel_global(&cx); + let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx); + + let main_fn = cx.get_function("main"); + let Some(main_fn) = main_fn else { return }; + let kernel_name = "kernel_1"; + let call = unsafe { + llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len()) + }; + let Some(kernel_call) = call else { + return; + }; + let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) }; + let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() }; + let mut builder = SBuilder::build(cx, kernel_call_bb); + + let types = cx.func_params_types(cx.get_type_of_global(called)); + let num_args = types.len() as u64; + + // Step 0) + // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } + // %6 = alloca %struct.__tgt_bin_desc, align 8 + unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) }; + + let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc"); + + let ty = cx.type_array(cx.type_ptr(), num_args); + // Baseptr are just the input pointer to the kernel, stored in a local alloca + let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs"); + // Ptrs are the result of a gep into the baseptr, at least for our trivial types. + let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs"); + // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16. + let ty2 = cx.type_array(cx.type_i64(), num_args); + let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes"); + // Now we allocate once per function param, a copy to be passed to one of our maps. + let mut vals = vec![]; + let mut geps = vec![]; + let i32_0 = cx.get_const_i32(0); + for (index, in_ty) in types.iter().enumerate() { + // get function arg, store it into the alloca, and read it. + let p = llvm::get_param(called, index as u32); + let name = llvm::get_value_name(p); + let name = str::from_utf8(&name).unwrap(); + let arg_name = format!("{name}.addr"); + let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name); + + builder.store(p, alloca, Align::EIGHT); + let val = builder.load(in_ty, alloca, Align::EIGHT); + let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]); + vals.push(val); + geps.push(gep); + } + + // Step 1) + unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) }; + builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT); + + let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void()); + let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty); + let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty); + let init_ty = cx.type_func(&[], cx.type_void()); + let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty); + + // call void @__tgt_register_lib(ptr noundef %6) + builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None); + // call void @__tgt_init_all_rtls() + builder.call(init_ty, init_rtls_decl, &[], None); + + for i in 0..num_args { + let idx = cx.get_const_i32(i); + let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]); + builder.store(vals[i as usize], gep1, Align::EIGHT); + let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]); + builder.store(geps[i as usize], gep2, Align::EIGHT); + let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]); + // As mentioned above, we don't use Rust type information yet. So for now we will just + // assume that we have 1024 bytes, 256 f32 values. + // FIXME(offload): write an offload frontend and handle arbitrary types. + builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT); + } + + // For now we have a very simplistic indexing scheme into our + // offload_{baseptrs,ptrs,sizes}. We will probably improve this along with our gpu frontend pr. + fn get_geps<'a, 'll>( + builder: &mut SBuilder<'a, 'll>, + cx: &'ll SimpleCx<'ll>, + ty: &'ll Type, + ty2: &'ll Type, + a1: &'ll Value, + a2: &'ll Value, + a4: &'ll Value, + ) -> (&'ll Value, &'ll Value, &'ll Value) { + let i32_0 = cx.get_const_i32(0); + + let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]); + let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]); + let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]); + (gep1, gep2, gep3) + } + + fn generate_mapper_call<'a, 'll>( + builder: &mut SBuilder<'a, 'll>, + cx: &'ll SimpleCx<'ll>, + geps: (&'ll Value, &'ll Value, &'ll Value), + o_type: &'ll Value, + fn_to_call: &'ll Value, + fn_ty: &'ll Type, + num_args: u64, + s_ident_t: &'ll Value, + ) { + let nullptr = cx.const_null(cx.type_ptr()); + let i64_max = cx.get_const_i64(u64::MAX); + let num_args = cx.get_const_i32(num_args); + let args = + vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr]; + builder.call(fn_ty, fn_to_call, &args, None); + } + + // Step 2) + let s_ident_t = generate_at_one(&cx); + let o = o_types[0]; + let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); + generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t); + + // Step 3) + // Here we will add code for the actual kernel launches in a follow-up PR. + // FIXME(offload): launch kernels + + // Step 4) + unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) }; + + let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); + generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t); + + builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None); + + // With this we generated the following begin and end mappers. We could easily generate the + // update mapper in an update. + // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null) + // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null) + // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null) +} diff --git a/compiler/rustc_codegen_llvm/src/common.rs b/compiler/rustc_codegen_llvm/src/common.rs index f9ab96b5789..f29fefb66f0 100644 --- a/compiler/rustc_codegen_llvm/src/common.rs +++ b/compiler/rustc_codegen_llvm/src/common.rs @@ -118,6 +118,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> { r } } + + pub(crate) fn const_null(&self, t: &'ll Type) -> &'ll Value { + unsafe { llvm::LLVMConstNull(t) } + } } impl<'ll, 'tcx> ConstCodegenMethods for CodegenCx<'ll, 'tcx> { @@ -377,6 +381,11 @@ pub(crate) fn bytes_in_context<'ll>(llcx: &'ll llvm::Context, bytes: &[u8]) -> & } } +pub(crate) fn named_struct<'ll>(ty: &'ll Type, elts: &[&'ll Value]) -> &'ll Value { + let len = c_uint::try_from(elts.len()).expect("LLVMConstStructInContext elements len overflow"); + unsafe { llvm::LLVMConstNamedStruct(ty, elts.as_ptr(), len) } +} + fn struct_in_context<'ll>( llcx: &'ll llvm::Context, elts: &[&'ll Value], diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 34bed2a1d2a..ee77774c688 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -216,7 +216,7 @@ pub(crate) unsafe fn create_module<'ll>( // Ensure the data-layout values hardcoded remain the defaults. { - let tm = crate::back::write::create_informational_target_machine(tcx.sess, false); + let tm = crate::back::write::create_informational_target_machine(sess, false); unsafe { llvm::LLVMRustSetDataLayoutFromTargetMachine(llmod, tm.raw()); } @@ -685,6 +685,22 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> { unsafe { llvm::LLVMConstInt(ty, val, llvm::False) } } + pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value { + self.get_const_int(self.type_i64(), n) + } + + pub(crate) fn get_const_i32(&self, n: u64) -> &'ll Value { + self.get_const_int(self.type_i32(), n) + } + + pub(crate) fn get_const_i16(&self, n: u64) -> &'ll Value { + self.get_const_int(self.type_i16(), n) + } + + pub(crate) fn get_const_i8(&self, n: u64) -> &'ll Value { + self.get_const_int(self.type_i8(), n) + } + pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> { let name = SmallCStr::new(name); unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) } diff --git a/compiler/rustc_codegen_llvm/src/declare.rs b/compiler/rustc_codegen_llvm/src/declare.rs index eb75716d768..960a895a203 100644 --- a/compiler/rustc_codegen_llvm/src/declare.rs +++ b/compiler/rustc_codegen_llvm/src/declare.rs @@ -215,7 +215,9 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { llfn } +} +impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> { /// Declare a global with an intention to define it. /// /// Use this function when you intend to define a global. This function will @@ -234,13 +236,13 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { /// /// Use this function when you intend to define a global without a name. pub(crate) fn define_private_global(&self, ty: &'ll Type) -> &'ll Value { - unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod, ty) } + unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod(), ty) } } /// Gets declared value by name. pub(crate) fn get_declared_value(&self, name: &str) -> Option<&'ll Value> { debug!("get_declared_value(name={:?})", name); - unsafe { llvm::LLVMRustGetNamedValue(self.llmod, name.as_c_char_ptr(), name.len()) } + unsafe { llvm::LLVMRustGetNamedValue(self.llmod(), name.as_c_char_ptr(), name.len()) } } /// Gets defined or externally defined (AvailableExternally linkage) value by diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 6db4e122ad6..aaf21f9ada9 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -412,6 +412,20 @@ impl ModuleLlvm { } } + fn tm_from_cgcx( + cgcx: &CodegenContext<LlvmCodegenBackend>, + name: &str, + dcx: DiagCtxtHandle<'_>, + ) -> Result<OwnedTargetMachine, FatalError> { + let tm_factory_config = TargetMachineFactoryConfig::new(cgcx, name); + match (cgcx.tm_factory)(tm_factory_config) { + Ok(m) => Ok(m), + Err(e) => { + return Err(dcx.emit_almost_fatal(ParseTargetMachineConfig(e))); + } + } + } + fn parse( cgcx: &CodegenContext<LlvmCodegenBackend>, name: &CStr, @@ -421,13 +435,7 @@ impl ModuleLlvm { unsafe { let llcx = llvm::LLVMRustContextCreate(cgcx.fewer_names); let llmod_raw = back::lto::parse_module(llcx, name, buffer, dcx)?; - let tm_factory_config = TargetMachineFactoryConfig::new(cgcx, name.to_str().unwrap()); - let tm = match (cgcx.tm_factory)(tm_factory_config) { - Ok(m) => m, - Err(e) => { - return Err(dcx.emit_almost_fatal(ParseTargetMachineConfig(e))); - } - }; + let tm = ModuleLlvm::tm_from_cgcx(cgcx, name.to_str().unwrap(), dcx)?; Ok(ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) }) } diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index c696b8d8ff2..56d756e52cc 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -4,7 +4,7 @@ use libc::{c_char, c_uint}; use super::MetadataKindId; use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value}; -use crate::llvm::Bool; +use crate::llvm::{Bool, Builder}; #[link(name = "llvm-wrapper", kind = "static")] unsafe extern "C" { @@ -31,6 +31,14 @@ unsafe extern "C" { index: c_uint, kind: AttributeKind, ); + pub(crate) fn LLVMRustPositionBefore<'a>(B: &'a Builder<'_>, I: &'a Value); + pub(crate) fn LLVMRustPositionAfter<'a>(B: &'a Builder<'_>, I: &'a Value); + pub(crate) fn LLVMRustGetFunctionCall( + F: &Value, + name: *const c_char, + NameLen: libc::size_t, + ) -> Option<&Value>; + } unsafe extern "C" { diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 80a0e5c5acc..edfb29dd1be 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1138,6 +1138,11 @@ unsafe extern "C" { Count: c_uint, Packed: Bool, ) -> &'a Value; + pub(crate) fn LLVMConstNamedStruct<'a>( + StructTy: &'a Type, + ConstantVals: *const &'a Value, + Count: c_uint, + ) -> &'a Value; pub(crate) fn LLVMConstVector(ScalarConstantVals: *const &Value, Size: c_uint) -> &Value; // Constant expressions @@ -1217,6 +1222,8 @@ unsafe extern "C" { ) -> &'a BasicBlock; // Operations on instructions + pub(crate) fn LLVMGetInstructionParent(Inst: &Value) -> &BasicBlock; + pub(crate) fn LLVMGetCalledValue(CallInst: &Value) -> Option<&Value>; pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>; pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock; pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>; @@ -2557,6 +2564,7 @@ unsafe extern "C" { pub(crate) fn LLVMRustSetDataLayoutFromTargetMachine<'a>(M: &'a Module, TM: &'a TargetMachine); + pub(crate) fn LLVMRustPositionBuilderPastAllocas<'a>(B: &Builder<'a>, Fn: &'a Value); pub(crate) fn LLVMRustPositionBuilderAtStart<'a>(B: &Builder<'a>, BB: &'a BasicBlock); pub(crate) fn LLVMRustSetModulePICLevel(M: &Module); diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 50a7cba300b..24e0a4eb533 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -120,6 +120,7 @@ pub struct ModuleConfig { pub emit_lifetime_markers: bool, pub llvm_plugins: Vec<String>, pub autodiff: Vec<config::AutoDiff>, + pub offload: Vec<config::Offload>, } impl ModuleConfig { @@ -268,6 +269,7 @@ impl ModuleConfig { emit_lifetime_markers: sess.emit_lifetime_markers(), llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]), autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]), + offload: if_regular!(sess.opts.unstable_opts.offload.clone(), vec![]), } } diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 360b5629e9d..8771bb44050 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -13,10 +13,10 @@ use rustc_session::config::{ CoverageOptions, DebugInfo, DumpMonoStatsFormat, ErrorOutputType, ExternEntry, ExternLocation, Externs, FmtDebug, FunctionReturn, InliningThreshold, Input, InstrumentCoverage, InstrumentXRay, LinkSelfContained, LinkerPluginLto, LocationDetail, LtoCli, MirIncludeSpans, - NextSolverConfig, OomStrategy, Options, OutFileName, OutputType, OutputTypes, PAuthKey, PacRet, - Passes, PatchableFunctionEntry, Polonius, ProcMacroExecutionStrategy, Strip, SwitchWithOptPath, - SymbolManglingVersion, WasiExecModel, build_configuration, build_session_options, - rustc_optgroups, + NextSolverConfig, Offload, OomStrategy, Options, OutFileName, OutputType, OutputTypes, + PAuthKey, PacRet, Passes, PatchableFunctionEntry, Polonius, ProcMacroExecutionStrategy, Strip, + SwitchWithOptPath, SymbolManglingVersion, WasiExecModel, build_configuration, + build_session_options, rustc_optgroups, }; use rustc_session::lint::Level; use rustc_session::search_paths::SearchPath; @@ -833,6 +833,7 @@ fn test_unstable_options_tracking_hash() { tracked!(no_profiler_runtime, true); tracked!(no_trait_vptr, true); tracked!(no_unique_section_names, true); + tracked!(offload, vec![Offload::Enable]); tracked!(on_broken_pipe, OnBrokenPipe::Kill); tracked!(oom, OomStrategy::Panic); tracked!(osx_rpath_install_name, true); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 90aa9188c83..82568ed4ae1 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -1591,12 +1591,49 @@ extern "C" LLVMValueRef LLVMRustBuildMemSet(LLVMBuilderRef B, LLVMValueRef Dst, MaybeAlign(DstAlign), IsVolatile)); } +extern "C" void LLVMRustPositionBuilderPastAllocas(LLVMBuilderRef B, + LLVMValueRef Fn) { + Function *F = unwrap<Function>(Fn); + unwrap(B)->SetInsertPointPastAllocas(F); +} extern "C" void LLVMRustPositionBuilderAtStart(LLVMBuilderRef B, LLVMBasicBlockRef BB) { auto Point = unwrap(BB)->getFirstInsertionPt(); unwrap(B)->SetInsertPoint(unwrap(BB), Point); } +extern "C" void LLVMRustPositionBefore(LLVMBuilderRef B, LLVMValueRef Instr) { + if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) { + unwrap(B)->SetInsertPoint(I); + } +} + +extern "C" void LLVMRustPositionAfter(LLVMBuilderRef B, LLVMValueRef Instr) { + if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) { + auto J = I->getNextNonDebugInstruction(); + unwrap(B)->SetInsertPoint(J); + } +} + +extern "C" LLVMValueRef +LLVMRustGetFunctionCall(LLVMValueRef Fn, const char *Name, size_t NameLen) { + auto targetName = StringRef(Name, NameLen); + Function *F = unwrap<Function>(Fn); + for (auto &BB : *F) { + for (auto &I : BB) { + if (auto *callInst = llvm::dyn_cast<llvm::CallBase>(&I)) { + const llvm::Function *calledFunc = callInst->getCalledFunction(); + if (calledFunc && calledFunc->getName() == targetName) { + // Found a call to the target function + return wrap(callInst); + } + } + } + } + + return nullptr; +} + extern "C" bool LLVMRustConstIntGetZExtValue(LLVMValueRef CV, uint64_t *value) { auto C = unwrap<llvm::ConstantInt>(CV); if (C->getBitWidth() > 64) diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index d6215e1de04..7bea8685724 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -226,6 +226,13 @@ pub enum CoverageLevel { Mcdc, } +// The different settings that the `-Z offload` flag can have. +#[derive(Clone, Copy, PartialEq, Hash, Debug)] +pub enum Offload { + /// Enable the llvm offload pipeline + Enable, +} + /// The different settings that the `-Z autodiff` flag can have. #[derive(Clone, PartialEq, Hash, Debug)] pub enum AutoDiff { @@ -2706,6 +2713,15 @@ pub fn build_session_options(early_dcx: &mut EarlyDiagCtxt, matches: &getopts::M ) } + if !nightly_options::is_unstable_enabled(matches) + && unstable_opts.offload.contains(&Offload::Enable) + { + early_dcx.early_fatal( + "`-Zoffload=Enable` also requires `-Zunstable-options` \ + and a nightly compiler", + ) + } + let target_triple = parse_target_triple(early_dcx, matches); // Ensure `-Z unstable-options` is required when using the unstable `-C link-self-contained` and @@ -3178,7 +3194,7 @@ pub(crate) mod dep_tracking { AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions, CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn, InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail, - LtoCli, MirStripDebugInfo, NextSolverConfig, OomStrategy, OptLevel, OutFileName, + LtoCli, MirStripDebugInfo, NextSolverConfig, Offload, OomStrategy, OptLevel, OutFileName, OutputType, OutputTypes, PatchableFunctionEntry, Polonius, RemapPathScopeComponents, ResolveDocLinks, SourceFileHashAlgorithm, SplitDwarfKind, SwitchWithOptPath, SymbolManglingVersion, WasiExecModel, @@ -3225,6 +3241,7 @@ pub(crate) mod dep_tracking { impl_dep_tracking_hash_via_hash!( (), AutoDiff, + Offload, bool, usize, NonZero<usize>, diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 2bdde2f887a..b33e3815ea4 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -726,6 +726,7 @@ mod desc { pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`"; + pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; pub(crate) const parse_number: &str = "a number"; @@ -1357,6 +1358,27 @@ pub mod parse { } } + pub(crate) fn parse_offload(slot: &mut Vec<Offload>, v: Option<&str>) -> bool { + let Some(v) = v else { + *slot = vec![]; + return true; + }; + let mut v: Vec<&str> = v.split(",").collect(); + v.sort_unstable(); + for &val in v.iter() { + let variant = match val { + "Enable" => Offload::Enable, + _ => { + // FIXME(ZuseZ4): print an error saying which value is not recognized + return false; + } + }; + slot.push(variant); + } + + true + } + pub(crate) fn parse_autodiff(slot: &mut Vec<AutoDiff>, v: Option<&str>) -> bool { let Some(v) = v else { *slot = vec![]; @@ -2401,6 +2423,11 @@ options! { "do not use unique names for text and data sections when -Z function-sections is used"), normalize_docs: bool = (false, parse_bool, [TRACKED], "normalize associated items in rustdoc when generating documentation"), + offload: Vec<crate::config::Offload> = (Vec::new(), parse_offload, [TRACKED], + "a list of offload flags to enable + Mandatory setting: + `=Enable` + Currently the only option available"), on_broken_pipe: OnBrokenPipe = (OnBrokenPipe::Default, parse_on_broken_pipe, [TRACKED], "behavior of std::io::ErrorKind::BrokenPipe (SIGPIPE)"), oom: OomStrategy = (OomStrategy::Abort, parse_oom_strategy, [TRACKED], diff --git a/src/doc/unstable-book/src/compiler-flags/offload.md b/src/doc/unstable-book/src/compiler-flags/offload.md new file mode 100644 index 00000000000..4266e8c11a2 --- /dev/null +++ b/src/doc/unstable-book/src/compiler-flags/offload.md @@ -0,0 +1,8 @@ +# `offload` + +The tracking issue for this feature is: [#131513](https://github.com/rust-lang/rust/issues/131513). + +------------------------ + +This feature will later allow you to run functions on GPUs. It is work in progress. +Set the `-Zoffload=Enable` compiler flag to experiment with it. diff --git a/src/tools/enzyme b/src/tools/enzyme -Subproject b5098d515d5e1bd0f5470553bc0d18da9794ca8 +Subproject 2cccfba93c1650f26f1cf8be8aa875a7c1d23fb diff --git a/tests/codegen/gpu_offload/gpu_host.rs b/tests/codegen/gpu_offload/gpu_host.rs new file mode 100644 index 00000000000..513e27426bc --- /dev/null +++ b/tests/codegen/gpu_offload/gpu_host.rs @@ -0,0 +1,80 @@ +//@ compile-flags: -Zoffload=Enable -Zunstable-options -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// This test is verifying that we generate __tgt_target_data_*_mapper before and after a call to the +// kernel_1. Better documentation to what each global or variable means is available in the gpu +// offlaod code, or the LLVM offload documentation. This code does not launch any GPU kernels yet, +// and will be rewritten once a proper offload frontend has landed. +// +// We currently only handle memory transfer for specific calls to functions named `kernel_{num}`, +// when inside of a function called main. This, too, is a temporary workaround for not having a +// frontend. + +#![no_main] + +#[unsafe(no_mangle)] +fn main() { + let mut x = [3.0; 256]; + kernel_1(&mut x); + core::hint::black_box(&x); +} + +// CHECK: %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr } +// CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 } +// CHECK: %struct.ident_t = type { i32, i32, i32, i32, ptr } +// CHECK: %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } + +// CHECK: @.offload_sizes.1 = private unnamed_addr constant [1 x i64] [i64 1024] +// CHECK: @.offload_maptypes.1 = private unnamed_addr constant [1 x i64] [i64 3] +// CHECK: @.kernel_1.region_id = weak unnamed_addr constant i8 0 +// CHECK: @.offloading.entry_name.1 = internal unnamed_addr constant [9 x i8] c"kernel_1\00", section ".llvm.rodata.offloading", align 1 +// CHECK: @.offloading.entry.kernel_1 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_1.region_id, ptr @.offloading.entry_name.1, i64 0, i64 0, ptr null }, section ".omp_offloading_entries", align 1 +// CHECK: @my_struct_global2 = external global %struct.__tgt_kernel_arguments +// CHECK: @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 +// CHECK: @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 + +// CHECK: Function Attrs: +// CHECK-NEXT: define{{( dso_local)?}} void @main() +// CHECK-NEXT: start: +// CHECK-NEXT: %0 = alloca [8 x i8], align 8 +// CHECK-NEXT: %x = alloca [1024 x i8], align 16 +// CHECK-NEXT: %EmptyDesc = alloca %struct.__tgt_bin_desc, align 8 +// CHECK-NEXT: %.offload_baseptrs = alloca [1 x ptr], align 8 +// CHECK-NEXT: %.offload_ptrs = alloca [1 x ptr], align 8 +// CHECK-NEXT: %.offload_sizes = alloca [1 x i64], align 8 +// CHECK-NEXT: %x.addr = alloca ptr, align 8 +// CHECK-NEXT: store ptr %x, ptr %x.addr, align 8 +// CHECK-NEXT: %1 = load ptr, ptr %x.addr, align 8 +// CHECK-NEXT: %2 = getelementptr inbounds float, ptr %1, i32 0 +// CHECK: call void @llvm.memset.p0.i64(ptr align 8 %EmptyDesc, i8 0, i64 32, i1 false) +// CHECK-NEXT: call void @__tgt_register_lib(ptr %EmptyDesc) +// CHECK-NEXT: call void @__tgt_init_all_rtls() +// CHECK-NEXT: %3 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: store ptr %1, ptr %3, align 8 +// CHECK-NEXT: %4 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: store ptr %2, ptr %4, align 8 +// CHECK-NEXT: %5 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: store i64 1024, ptr %5, align 8 +// CHECK-NEXT: %6 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %7 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: %8 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 1, ptr %6, ptr %7, ptr %8, ptr @.offload_maptypes.1, ptr null, ptr null) +// CHECK-NEXT: call void @kernel_1(ptr noalias noundef nonnull align 4 dereferenceable(1024) %x) +// CHECK-NEXT: %9 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %10 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: %11 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 1, ptr %9, ptr %10, ptr %11, ptr @.offload_maptypes.1, ptr null, ptr null) +// CHECK-NEXT: call void @__tgt_unregister_lib(ptr %EmptyDesc) +// CHECK: store ptr %x, ptr %0, align 8 +// CHECK-NEXT: call void asm sideeffect "", "r,~{memory}"(ptr nonnull %0) +// CHECK: ret void +// CHECK-NEXT: } + +#[unsafe(no_mangle)] +#[inline(never)] +pub fn kernel_1(x: &mut [f32; 256]) { + for i in 0..256 { + x[i] = 21.0; + } +} |
