about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_codegen_gcc/src/builder.rs1
-rw-r--r--compiler/rustc_codegen_gcc/src/intrinsic/mod.rs1
-rw-r--r--compiler/rustc_codegen_llvm/src/abi.rs1
-rw-r--r--compiler/rustc_codegen_llvm/src/builder.rs15
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs1
-rw-r--r--compiler/rustc_codegen_llvm/src/typetree.rs20
-rw-r--r--compiler/rustc_codegen_llvm/src/va_arg.rs1
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/block.rs1
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/intrinsic.rs2
-rw-r--r--compiler/rustc_codegen_ssa/src/mir/statement.rs2
-rw-r--r--compiler/rustc_codegen_ssa/src/traits/builder.rs3
-rw-r--r--compiler/rustc_interface/src/tests.rs1
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs4
-rw-r--r--tests/codegen-llvm/autodiff/typetree.rs2
-rw-r--r--tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check8
-rw-r--r--tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check13
-rw-r--r--tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs36
-rw-r--r--tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs39
-rw-r--r--tests/run-make/autodiff/type-trees/nott-flag/rmake.rs14
-rw-r--r--tests/run-make/autodiff/type-trees/nott-flag/test.rs2
-rw-r--r--tests/ui/autodiff/flag_nott.rs2
21 files changed, 135 insertions, 34 deletions
diff --git a/compiler/rustc_codegen_gcc/src/builder.rs b/compiler/rustc_codegen_gcc/src/builder.rs
index f7a7a3f8c7e..5657620879c 100644
--- a/compiler/rustc_codegen_gcc/src/builder.rs
+++ b/compiler/rustc_codegen_gcc/src/builder.rs
@@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
         _src_align: Align,
         size: RValue<'gcc>,
         flags: MemFlags,
+        _tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
     ) {
         assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
         let size = self.intcast(size, self.type_size_t(), false);
diff --git a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs
index 84fa56cf903..3b897a19835 100644
--- a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs
+++ b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs
@@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
                     scratch_align,
                     bx.const_usize(self.layout.size.bytes()),
                     MemFlags::empty(),
+                    None,
                 );
 
                 bx.lifetime_end(scratch, scratch_size);
diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs
index 11be7041167..61fadad6066 100644
--- a/compiler/rustc_codegen_llvm/src/abi.rs
+++ b/compiler/rustc_codegen_llvm/src/abi.rs
@@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
                     scratch_align,
                     bx.const_usize(copy_bytes),
                     MemFlags::empty(),
+                    None,
                 );
                 bx.lifetime_end(llscratch, scratch_size);
             }
diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs
index 0f17cc9063a..83a9cf620f1 100644
--- a/compiler/rustc_codegen_llvm/src/builder.rs
+++ b/compiler/rustc_codegen_llvm/src/builder.rs
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
 use std::ops::Deref;
 use std::{iter, ptr};
 
+use rustc_ast::expand::typetree::FncTree;
 pub(crate) mod autodiff;
 pub(crate) mod gpu_offload;
 
@@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
         src_align: Align,
         size: &'ll Value,
         flags: MemFlags,
+        tt: Option<FncTree>,
     ) {
         assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
         let size = self.intcast(size, self.type_isize(), false);
         let is_volatile = flags.contains(MemFlags::VOLATILE);
-        unsafe {
+        let memcpy = unsafe {
             llvm::LLVMRustBuildMemCpy(
                 self.llbuilder,
                 dst,
@@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
                 src_align.bytes() as c_uint,
                 size,
                 is_volatile,
-            );
+            )
+        };
+
+        // TypeTree metadata for memcpy is especially important: when Enzyme encounters
+        // a memcpy during autodiff, it needs to know the structure of the data being
+        // copied to properly track derivatives. For example, copying an array of floats
+        // vs. copying a struct with mixed types requires different derivative handling.
+        // The TypeTree tells Enzyme exactly what memory layout to expect.
+        if let Some(tt) = tt {
+            crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
         }
     }
 
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index 12b6ffa3c0b..9dd84ad9a4d 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -25,6 +25,7 @@ pub(crate) enum CConcreteType {
     DT_Half = 3,
     DT_Float = 4,
     DT_Double = 5,
+    // FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600)
     DT_Unknown = 6,
 }
 
diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs
index 434316464e6..8e6272a186b 100644
--- a/compiler/rustc_codegen_llvm/src/typetree.rs
+++ b/compiler/rustc_codegen_llvm/src/typetree.rs
@@ -1,8 +1,11 @@
-use std::ffi::{CString, c_char, c_uint};
-
-use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
+use rustc_ast::expand::typetree::FncTree;
+#[cfg(llvm_enzyme)]
+use {
+    crate::attributes,
+    rustc_ast::expand::typetree::TypeTree as RustTypeTree,
+    std::ffi::{CString, c_char, c_uint},
+};
 
-use crate::attributes;
 use crate::llvm::{self, Value};
 
 /// Converts a Rust TypeTree to Enzyme's internal TypeTree format
@@ -50,15 +53,6 @@ fn to_enzyme_typetree(
     enzyme_tt
 }
 
-#[cfg(not(llvm_enzyme))]
-fn to_enzyme_typetree(
-    _rust_typetree: RustTypeTree,
-    _data_layout: &str,
-    _llcx: &llvm::Context,
-) -> ! {
-    unimplemented!("TypeTree conversion not available without llvm_enzyme support")
-}
-
 // Attaches TypeTree information to LLVM function as enzyme_type attributes.
 #[cfg(llvm_enzyme)]
 pub(crate) fn add_tt<'ll>(
diff --git a/compiler/rustc_codegen_llvm/src/va_arg.rs b/compiler/rustc_codegen_llvm/src/va_arg.rs
index ab08125217f..d48c7cf874a 100644
--- a/compiler/rustc_codegen_llvm/src/va_arg.rs
+++ b/compiler/rustc_codegen_llvm/src/va_arg.rs
@@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
             src_align,
             bx.const_u32(layout.layout.size().bytes() as u32),
             MemFlags::empty(),
+            None,
         );
         tmp
     } else {
diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs
index 1b218a0d339..b2dc4fe32b0 100644
--- a/compiler/rustc_codegen_ssa/src/mir/block.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/block.rs
@@ -1626,6 +1626,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
                     align,
                     bx.const_usize(copy_bytes),
                     MemFlags::empty(),
+                    None,
                 );
                 // ...and then load it with the ABI type.
                 llval = load_cast(bx, cast, llscratch, scratch_align);
diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
index 3c667b8e882..befa00c6861 100644
--- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
     if allow_overlap {
         bx.memmove(dst, align, src, align, size, flags);
     } else {
-        bx.memcpy(dst, align, src, align, size, flags);
+        bx.memcpy(dst, align, src, align, size, flags, None);
     }
 }
 
diff --git a/compiler/rustc_codegen_ssa/src/mir/statement.rs b/compiler/rustc_codegen_ssa/src/mir/statement.rs
index f164e0f9123..0a50d7f18db 100644
--- a/compiler/rustc_codegen_ssa/src/mir/statement.rs
+++ b/compiler/rustc_codegen_ssa/src/mir/statement.rs
@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
                 let align = pointee_layout.align;
                 let dst = dst_val.immediate();
                 let src = src_val.immediate();
-                bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
+                bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
             }
             mir::StatementKind::FakeRead(..)
             | mir::StatementKind::Retag { .. }
diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs
index 4a5694e97fa..60296e36e0c 100644
--- a/compiler/rustc_codegen_ssa/src/traits/builder.rs
+++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs
@@ -451,6 +451,7 @@ pub trait BuilderMethods<'a, 'tcx>:
         src_align: Align,
         size: Self::Value,
         flags: MemFlags,
+        tt: Option<rustc_ast::expand::typetree::FncTree>,
     );
     fn memmove(
         &mut self,
@@ -507,7 +508,7 @@ pub trait BuilderMethods<'a, 'tcx>:
             temp.val.store_with_flags(self, dst.with_type(layout), flags);
         } else if !layout.is_zst() {
             let bytes = self.const_usize(layout.size.bytes());
-            self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags);
+            self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
         }
     }
 
diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs
index 837acdadd57..0dc5b5af3ac 100644
--- a/compiler/rustc_interface/src/tests.rs
+++ b/compiler/rustc_interface/src/tests.rs
@@ -765,7 +765,6 @@ fn test_unstable_options_tracking_hash() {
     tracked!(allow_features, Some(vec![String::from("lang_items")]));
     tracked!(always_encode_mir, true);
     tracked!(assume_incomplete_release, true);
-    tracked!(autodiff, vec![AutoDiff::Enable]);
     tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]);
     tracked!(binary_dep_depinfo, true);
     tracked!(box_noalias, false);
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index e7130757f1d..741b5d7fd4e 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -2280,12 +2280,12 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
         let child = typetree_from_ty(tcx, inner_ty);
         return TypeTree(vec![Type {
             offset: -1,
-            size: 8, // TODO(KMJ-007): Get actual pointer size from target
+            size: tcx.data_layout.pointer_size().bytes_usize(),
             kind: Kind::Pointer,
             child,
         }]);
     }
 
-    // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
+    // FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types
     TypeTree::new()
 }
diff --git a/tests/codegen-llvm/autodiff/typetree.rs b/tests/codegen-llvm/autodiff/typetree.rs
index 3ad38d581b9..1cb0c2fb68b 100644
--- a/tests/codegen-llvm/autodiff/typetree.rs
+++ b/tests/codegen-llvm/autodiff/typetree.rs
@@ -30,4 +30,4 @@ fn main() {
     let output_ = d_simple(&x, &mut df_dx, 1.0);
     assert_eq!(output, output_);
     assert_eq!(2.0, df_dx);
-}
\ No newline at end of file
+}
diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check
new file mode 100644
index 00000000000..3a59f06dda1
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check
@@ -0,0 +1,8 @@
+; Check that enzyme_type attributes are present in the LLVM IR function definition
+; This verifies our TypeTree system correctly attaches metadata for Enzyme
+
+CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"
+
+; Check that llvm.memcpy exists (either call or declare)
+CHECK: {{(call|declare).*}}@llvm.memcpy
+
diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check
new file mode 100644
index 00000000000..ae70830297a
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check
@@ -0,0 +1,13 @@
+CHECK: force_memcpy
+
+CHECK: @llvm.memcpy.p0.p0.i64
+
+CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer}:{}
+
+CHECK-DAG: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double}
+
+CHECK-DAG: load double{{.*}}: {[-1]:Float@double}
+
+CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double}
+
+CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double}
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs
new file mode 100644
index 00000000000..3c1029190c8
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs
@@ -0,0 +1,36 @@
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+use std::ptr;
+
+#[inline(never)]
+fn force_memcpy(src: *const f64, dst: *mut f64, count: usize) {
+    unsafe {
+        ptr::copy_nonoverlapping(src, dst, count);
+    }
+}
+
+#[autodiff_reverse(d_test_memcpy, Duplicated, Active)]
+#[no_mangle]
+fn test_memcpy(input: &[f64; 128]) -> f64 {
+    let mut local_data = [0.0f64; 128];
+
+    // Use a separate function to prevent inlining and optimization
+    force_memcpy(input.as_ptr(), local_data.as_mut_ptr(), 128);
+
+    // Sum only first few elements to keep the computation simple
+    local_data[0] * local_data[0]
+        + local_data[1] * local_data[1]
+        + local_data[2] * local_data[2]
+        + local_data[3] * local_data[3]
+}
+
+fn main() {
+    let input = [1.0; 128];
+    let mut d_input = [0.0; 128];
+    let result = test_memcpy(&input);
+    let result_d = d_test_memcpy(&input, &mut d_input, 1.0);
+
+    assert_eq!(result, result_d);
+    println!("Memcpy test passed: result = {}", result);
+}
diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs
new file mode 100644
index 00000000000..b4c650330fe
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs
@@ -0,0 +1,39 @@
+//@ needs-enzyme
+//@ ignore-cross-compile
+
+use run_make_support::{llvm_filecheck, rfs, rustc};
+
+fn main() {
+    // First, compile to LLVM IR to check for enzyme_type attributes
+    let _ir_output = rustc()
+        .input("memcpy.rs")
+        .arg("-Zautodiff=Enable")
+        .arg("-Zautodiff=NoPostopt")
+        .opt_level("0")
+        .arg("--emit=llvm-ir")
+        .arg("-o")
+        .arg("main.ll")
+        .run();
+
+    // Then compile with TypeTree analysis output for the existing checks
+    let output = rustc()
+        .input("memcpy.rs")
+        .arg("-Zautodiff=Enable,PrintTAFn=test_memcpy")
+        .arg("-Zautodiff=NoPostopt")
+        .opt_level("3")
+        .arg("-Clto=fat")
+        .arg("-g")
+        .run();
+
+    let stdout = output.stdout_utf8();
+    let stderr = output.stderr_utf8();
+    let ir_content = rfs::read_to_string("main.ll");
+
+    rfs::write("memcpy.stdout", &stdout);
+    rfs::write("memcpy.stderr", &stderr);
+    rfs::write("main.ir", &ir_content);
+
+    llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run();
+
+    llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run();
+}
diff --git a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs
index bab863ca9ff..de540b990ca 100644
--- a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs
+++ b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs
@@ -23,14 +23,8 @@ fn main() {
         .run();
 
     // Verify NoTT version does NOT have enzyme_type attributes
-    llvm_filecheck()
-        .patterns("nott.check")
-        .stdin_buf(rfs::read("nott.ll"))
-        .run();
-    
+    llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.ll")).run();
+
     // Verify TypeTree version DOES have enzyme_type attributes
-    llvm_filecheck()
-        .patterns("with_tt.check")
-        .stdin_buf(rfs::read("with_tt.ll"))
-        .run();
-}
\ No newline at end of file
+    llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.ll")).run();
+}
diff --git a/tests/run-make/autodiff/type-trees/nott-flag/test.rs b/tests/run-make/autodiff/type-trees/nott-flag/test.rs
index 5c634eea035..de3549c37c6 100644
--- a/tests/run-make/autodiff/type-trees/nott-flag/test.rs
+++ b/tests/run-make/autodiff/type-trees/nott-flag/test.rs
@@ -12,4 +12,4 @@ fn main() {
     let x = 2.0;
     let mut dx = 0.0;
     let _result = d_square(&x, &mut dx, 1.0);
-}
\ No newline at end of file
+}
diff --git a/tests/ui/autodiff/flag_nott.rs b/tests/ui/autodiff/flag_nott.rs
index 7a97d892cd8..faa9949fe81 100644
--- a/tests/ui/autodiff/flag_nott.rs
+++ b/tests/ui/autodiff/flag_nott.rs
@@ -16,4 +16,4 @@ fn main() {
     let x = 2.0;
     let mut dx = 0.0;
     let result = d_square(&x, &mut dx, 1.0);
-}
\ No newline at end of file
+}