diff options
Diffstat (limited to 'compiler/rustc_codegen_llvm')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/abi.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder.rs | 15 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/typetree.rs | 20 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/va_arg.rs | 1 | 
5 files changed, 23 insertions, 15 deletions
| diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs index 11be7041167..61fadad6066 100644 --- a/compiler/rustc_codegen_llvm/src/abi.rs +++ b/compiler/rustc_codegen_llvm/src/abi.rs @@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(copy_bytes), MemFlags::empty(), + None, ); bx.lifetime_end(llscratch, scratch_size); } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 0f17cc9063a..83a9cf620f1 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow}; use std::ops::Deref; use std::{iter, ptr}; +use rustc_ast::expand::typetree::FncTree; pub(crate) mod autodiff; pub(crate) mod gpu_offload; @@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option<FncTree>, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memcpy = unsafe { llvm::LLVMRustBuildMemCpy( self.llbuilder, dst, @@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + + // TypeTree metadata for memcpy is especially important: when Enzyme encounters + // a memcpy during autodiff, it needs to know the structure of the data being + // copied to properly track derivatives. For example, copying an array of floats + // vs. copying a struct with mixed types requires different derivative handling. + // The TypeTree tells Enzyme exactly what memory layout to expect. + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 12b6ffa3c0b..9dd84ad9a4d 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -25,6 +25,7 @@ pub(crate) enum CConcreteType { DT_Half = 3, DT_Float = 4, DT_Double = 5, + // FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600) DT_Unknown = 6, } diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 434316464e6..8e6272a186b 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -1,8 +1,11 @@ -use std::ffi::{CString, c_char, c_uint}; - -use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; +use rustc_ast::expand::typetree::FncTree; +#[cfg(llvm_enzyme)] +use { + crate::attributes, + rustc_ast::expand::typetree::TypeTree as RustTypeTree, + std::ffi::{CString, c_char, c_uint}, +}; -use crate::attributes; use crate::llvm::{self, Value}; /// Converts a Rust TypeTree to Enzyme's internal TypeTree format @@ -50,15 +53,6 @@ fn to_enzyme_typetree( enzyme_tt } -#[cfg(not(llvm_enzyme))] -fn to_enzyme_typetree( - _rust_typetree: RustTypeTree, - _data_layout: &str, - _llcx: &llvm::Context, -) -> ! { - unimplemented!("TypeTree conversion not available without llvm_enzyme support") -} - // Attaches TypeTree information to LLVM function as enzyme_type attributes. #[cfg(llvm_enzyme)] pub(crate) fn add_tt<'ll>( diff --git a/compiler/rustc_codegen_llvm/src/va_arg.rs b/compiler/rustc_codegen_llvm/src/va_arg.rs index ab08125217f..d48c7cf874a 100644 --- a/compiler/rustc_codegen_llvm/src/va_arg.rs +++ b/compiler/rustc_codegen_llvm/src/va_arg.rs @@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>( src_align, bx.const_u32(layout.layout.size().bytes() as u32), MemFlags::empty(), + None, ); tmp } else { | 
