diff options
| author | Karan Janthe <karanjanthe@gmail.com> | 2025-08-23 23:10:48 +0000 |
|---|---|---|
| committer | Karan Janthe <karanjanthe@gmail.com> | 2025-09-19 04:02:20 +0000 |
| commit | 664e83b3e76b51fb5192e74a64eef3bc5bbd4e32 (patch) | |
| tree | e777129c436cdcb8fa1cfaec619a6c8afc586395 /compiler | |
| parent | 5d3ebc3804299503387b7ea1427f1619d410c2b2 (diff) | |
| download | rust-664e83b3e76b51fb5192e74a64eef3bc5bbd4e32.tar.gz rust-664e83b3e76b51fb5192e74a64eef3bc5bbd4e32.zip | |
added typetree support for memcpy
Diffstat (limited to 'compiler')
| -rw-r--r-- | compiler/rustc_codegen_gcc/src/builder.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_codegen_gcc/src/intrinsic/mod.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/abi.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder.rs | 15 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/typetree.rs | 20 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/va_arg.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_codegen_ssa/src/mir/block.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_codegen_ssa/src/mir/intrinsic.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_codegen_ssa/src/mir/statement.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_codegen_ssa/src/traits/builder.rs | 3 | ||||
| -rw-r--r-- | compiler/rustc_interface/src/tests.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/ty/mod.rs | 4 |
13 files changed, 32 insertions, 21 deletions
diff --git a/compiler/rustc_codegen_gcc/src/builder.rs b/compiler/rustc_codegen_gcc/src/builder.rs index f7a7a3f8c7e..5657620879c 100644 --- a/compiler/rustc_codegen_gcc/src/builder.rs +++ b/compiler/rustc_codegen_gcc/src/builder.rs @@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> { _src_align: Align, size: RValue<'gcc>, flags: MemFlags, + _tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_size_t(), false); diff --git a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs index 84fa56cf903..3b897a19835 100644 --- a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs +++ b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs @@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(self.layout.size.bytes()), MemFlags::empty(), + None, ); bx.lifetime_end(scratch, scratch_size); diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs index 11be7041167..61fadad6066 100644 --- a/compiler/rustc_codegen_llvm/src/abi.rs +++ b/compiler/rustc_codegen_llvm/src/abi.rs @@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(copy_bytes), MemFlags::empty(), + None, ); bx.lifetime_end(llscratch, scratch_size); } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 0f17cc9063a..83a9cf620f1 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow}; use std::ops::Deref; use std::{iter, ptr}; +use rustc_ast::expand::typetree::FncTree; pub(crate) mod autodiff; pub(crate) mod gpu_offload; @@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option<FncTree>, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memcpy = unsafe { llvm::LLVMRustBuildMemCpy( self.llbuilder, dst, @@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + + // TypeTree metadata for memcpy is especially important: when Enzyme encounters + // a memcpy during autodiff, it needs to know the structure of the data being + // copied to properly track derivatives. For example, copying an array of floats + // vs. copying a struct with mixed types requires different derivative handling. + // The TypeTree tells Enzyme exactly what memory layout to expect. + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 12b6ffa3c0b..9dd84ad9a4d 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -25,6 +25,7 @@ pub(crate) enum CConcreteType { DT_Half = 3, DT_Float = 4, DT_Double = 5, + // FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600) DT_Unknown = 6, } diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 434316464e6..8e6272a186b 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -1,8 +1,11 @@ -use std::ffi::{CString, c_char, c_uint}; - -use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; +use rustc_ast::expand::typetree::FncTree; +#[cfg(llvm_enzyme)] +use { + crate::attributes, + rustc_ast::expand::typetree::TypeTree as RustTypeTree, + std::ffi::{CString, c_char, c_uint}, +}; -use crate::attributes; use crate::llvm::{self, Value}; /// Converts a Rust TypeTree to Enzyme's internal TypeTree format @@ -50,15 +53,6 @@ fn to_enzyme_typetree( enzyme_tt } -#[cfg(not(llvm_enzyme))] -fn to_enzyme_typetree( - _rust_typetree: RustTypeTree, - _data_layout: &str, - _llcx: &llvm::Context, -) -> ! { - unimplemented!("TypeTree conversion not available without llvm_enzyme support") -} - // Attaches TypeTree information to LLVM function as enzyme_type attributes. #[cfg(llvm_enzyme)] pub(crate) fn add_tt<'ll>( diff --git a/compiler/rustc_codegen_llvm/src/va_arg.rs b/compiler/rustc_codegen_llvm/src/va_arg.rs index ab08125217f..d48c7cf874a 100644 --- a/compiler/rustc_codegen_llvm/src/va_arg.rs +++ b/compiler/rustc_codegen_llvm/src/va_arg.rs @@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>( src_align, bx.const_u32(layout.layout.size().bytes() as u32), MemFlags::empty(), + None, ); tmp } else { diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index 1b218a0d339..b2dc4fe32b0 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -1626,6 +1626,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { align, bx.const_usize(copy_bytes), MemFlags::empty(), + None, ); // ...and then load it with the ABI type. llval = load_cast(bx, cast, llscratch, scratch_align); diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index 3c667b8e882..befa00c6861 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( if allow_overlap { bx.memmove(dst, align, src, align, size, flags); } else { - bx.memcpy(dst, align, src, align, size, flags); + bx.memcpy(dst, align, src, align, size, flags, None); } } diff --git a/compiler/rustc_codegen_ssa/src/mir/statement.rs b/compiler/rustc_codegen_ssa/src/mir/statement.rs index f164e0f9123..0a50d7f18db 100644 --- a/compiler/rustc_codegen_ssa/src/mir/statement.rs +++ b/compiler/rustc_codegen_ssa/src/mir/statement.rs @@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let align = pointee_layout.align; let dst = dst_val.immediate(); let src = src_val.immediate(); - bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty()); + bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None); } mir::StatementKind::FakeRead(..) | mir::StatementKind::Retag { .. } diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 4a5694e97fa..60296e36e0c 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -451,6 +451,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option<rustc_ast::expand::typetree::FncTree>, ); fn memmove( &mut self, @@ -507,7 +508,7 @@ pub trait BuilderMethods<'a, 'tcx>: temp.val.store_with_flags(self, dst.with_type(layout), flags); } else if !layout.is_zst() { let bytes = self.const_usize(layout.size.bytes()); - self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags); + self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None); } } diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 837acdadd57..0dc5b5af3ac 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -765,7 +765,6 @@ fn test_unstable_options_tracking_hash() { tracked!(allow_features, Some(vec![String::from("lang_items")])); tracked!(always_encode_mir, true); tracked!(assume_incomplete_release, true); - tracked!(autodiff, vec![AutoDiff::Enable]); tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]); tracked!(binary_dep_depinfo, true); tracked!(box_noalias, false); diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index e7130757f1d..741b5d7fd4e 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2280,12 +2280,12 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { let child = typetree_from_ty(tcx, inner_ty); return TypeTree(vec![Type { offset: -1, - size: 8, // TODO(KMJ-007): Get actual pointer size from target + size: tcx.data_layout.pointer_size().bytes_usize(), kind: Kind::Pointer, child, }]); } - // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types + // FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types TypeTree::new() } |
