about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm
diff options
context:
space:
mode:
authorKaran Janthe <karanjanthe@gmail.com>2025-09-04 11:17:34 +0000
committerKaran Janthe <karanjanthe@gmail.com>2025-09-19 04:11:35 +0000
commit4f3f0f48e7b1e61818b2bcbe4451f89bb4f47049 (patch)
tree10c5f51ddd0b6272941abfd2f430fcd9779e6bc7 /compiler/rustc_codegen_llvm
parent574f0b97d6f30cd6cedb165fde13cdec176611b8 (diff)
downloadrust-4f3f0f48e7b1e61818b2bcbe4451f89bb4f47049.tar.gz
rust-4f3f0f48e7b1e61818b2bcbe4451f89bb4f47049.zip
autodiff: fixed test to be more precise for type tree checking
Diffstat (limited to 'compiler/rustc_codegen_llvm')
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs23
-rw-r--r--compiler/rustc_codegen_llvm/src/typetree.rs70
2 files changed, 51 insertions, 42 deletions
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index 1596dc48379..e63043b2122 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -118,6 +118,13 @@ pub(crate) mod Enzyme_AD {
             max_size: i64,
             add_offset: u64,
         );
+        pub(crate) fn EnzymeTypeTreeInsertEq(
+            CTT: CTypeTreeRef,
+            indices: *const i64,
+            len: usize,
+            ct: CConcreteType,
+            ctx: &Context,
+        );
         pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
         pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
     }
@@ -234,6 +241,16 @@ pub(crate) mod Fallback_AD {
         unimplemented!()
     }
 
+    pub(crate) unsafe fn EnzymeTypeTreeInsertEq(
+        CTT: CTypeTreeRef,
+        indices: *const i64,
+        len: usize,
+        ct: CConcreteType,
+        ctx: &Context,
+    ) {
+        unimplemented!()
+    }
+
     pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char {
         unimplemented!()
     }
@@ -312,6 +329,12 @@ impl TypeTree {
 
         self
     }
+
+    pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
+        unsafe {
+            EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
+        }
+    }
 }
 
 impl Clone for TypeTree {
diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs
index 8c0d255bba8..ae6a2da62b5 100644
--- a/compiler/rustc_codegen_llvm/src/typetree.rs
+++ b/compiler/rustc_codegen_llvm/src/typetree.rs
@@ -8,22 +8,24 @@ use {
 
 use crate::llvm::{self, Value};
 
-/// Converts a Rust TypeTree to Enzyme's internal TypeTree format
-///
-/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree)
-/// and converts it to Enzyme's internal C++ TypeTree representation that
-/// Enzyme can understand during differentiation analysis.
 #[cfg(llvm_enzyme)]
 fn to_enzyme_typetree(
     rust_typetree: RustTypeTree,
-    data_layout: &str,
+    _data_layout: &str,
     llcx: &llvm::Context,
 ) -> llvm::TypeTree {
-    // Start with an empty TypeTree
     let mut enzyme_tt = llvm::TypeTree::new();
-
-    // Convert each Type in the Rust TypeTree to Enzyme format
-    for rust_type in rust_typetree.0 {
+    process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
+    enzyme_tt
+}
+#[cfg(llvm_enzyme)]
+fn process_typetree_recursive(
+    enzyme_tt: &mut llvm::TypeTree,
+    rust_typetree: &RustTypeTree,
+    parent_indices: &[i64],
+    llcx: &llvm::Context,
+) {
+    for rust_type in &rust_typetree.0 {
         let concrete_type = match rust_type.kind {
             rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything,
             rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer,
@@ -35,25 +37,27 @@ fn to_enzyme_typetree(
             rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown,
         };
 
-        // Create a TypeTree for this specific type
-        let type_tt = llvm::TypeTree::from_type(concrete_type, llcx);
-
-        // Apply offset if specified
-        let type_tt = if rust_type.offset == -1 {
-            type_tt // -1 means everywhere/no specific offset
+        let mut indices = parent_indices.to_vec();
+        if !parent_indices.is_empty() {
+            if rust_type.offset == -1 {
+                indices.push(-1);
+            } else {
+                indices.push(rust_type.offset as i64);
+            }
+        } else if rust_type.offset == -1 {
+            indices.push(-1);
         } else {
-            // Apply specific offset positioning
-            type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0)
-        };
+            indices.push(rust_type.offset as i64);
+        }
 
-        // Merge this type into the main TypeTree
-        enzyme_tt = enzyme_tt.merge(type_tt);
-    }
+        enzyme_tt.insert(&indices, concrete_type, llcx);
 
-    enzyme_tt
+        if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() {
+            process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
+        }
+    }
 }
 
-// Attaches TypeTree information to LLVM function as enzyme_type attributes.
 #[cfg(llvm_enzyme)]
 pub(crate) fn add_tt<'ll>(
     llmod: &'ll llvm::Module,
@@ -64,28 +68,20 @@ pub(crate) fn add_tt<'ll>(
     let inputs = tt.args;
     let ret_tt: RustTypeTree = tt.ret;
 
-    // Get LLVM data layout string for TypeTree conversion
     let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
     let llvm_data_layout =
         std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
             .expect("got a non-UTF8 data-layout from LLVM");
 
-    // Attribute name that Enzyme recognizes for TypeTree information
     let attr_name = "enzyme_type";
     let c_attr_name = CString::new(attr_name).unwrap();
 
-    // Attach TypeTree attributes to each input parameter
-    // Enzyme uses these to understand parameter memory layouts during differentiation
     for (i, input) in inputs.iter().enumerate() {
         unsafe {
-            // Convert Rust TypeTree to Enzyme's internal format
             let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
-
-            // Serialize TypeTree to string format that Enzyme can parse
             let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
             let c_str = std::ffi::CStr::from_ptr(c_str);
 
-            // Create LLVM string attribute with TypeTree information
             let attr = llvm::LLVMCreateStringAttribute(
                 llcx,
                 c_attr_name.as_ptr(),
@@ -94,17 +90,11 @@ pub(crate) fn add_tt<'ll>(
                 c_str.to_bytes().len() as c_uint,
             );
 
-            // Attach attribute to the specific function parameter
-            // Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments
             attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
-
-            // Free the C string to prevent memory leaks
             llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
         }
     }
 
-    // Attach TypeTree attribute to the return type
-    // Enzyme needs this to understand how to handle return value derivatives
     unsafe {
         let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
         let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
@@ -118,15 +108,11 @@ pub(crate) fn add_tt<'ll>(
             c_str.to_bytes().len() as c_uint,
         );
 
-        // Attach to function return type
         attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
-
-        // Free the C string
         llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr());
     }
 }
 
-// Fallback implementation when Enzyme is not available
 #[cfg(not(llvm_enzyme))]
 pub(crate) fn add_tt<'ll>(
     _llmod: &'ll llvm::Module,