about summary refs log tree commit diff
diff options
context:
space:
mode:
authorKaran Janthe <karanjanthe@gmail.com>2025-09-11 07:30:35 +0000
committerKaran Janthe <karanjanthe@gmail.com>2025-09-19 04:11:35 +0000
commit4520926bb527bd43edbf0de84c2b0c6a9c5fc5ce (patch)
tree5bc2f8d378284735c25a81d8e1fe18b0d8a4978c
parent4f3f0f48e7b1e61818b2bcbe4451f89bb4f47049 (diff)
downloadrust-4520926bb527bd43edbf0de84c2b0c6a9c5fc5ce.tar.gz
rust-4520926bb527bd43edbf0de84c2b0c6a9c5fc5ce.zip
autodiff: recurion added for typetree
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs1
-rw-r--r--compiler/rustc_codegen_llvm/src/typetree.rs10
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs76
-rw-r--r--tests/run-make/autodiff/type-trees/nott-flag/with_tt.check2
-rw-r--r--tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check3
-rw-r--r--tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs9
-rw-r--r--tests/run-make/autodiff/type-trees/recursion-typetree/test.rs100
-rw-r--r--tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check2
-rw-r--r--tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check2
-rw-r--r--tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check2
-rw-r--r--tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check2
-rw-r--r--tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check2
12 files changed, 191 insertions, 20 deletions
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index e63043b2122..b604f5139c8 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -127,6 +127,7 @@ pub(crate) mod Enzyme_AD {
         );
         pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
         pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
+        pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
     }
 
     unsafe extern "C" {
diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs
index ae6a2da62b5..1a54884f6c5 100644
--- a/compiler/rustc_codegen_llvm/src/typetree.rs
+++ b/compiler/rustc_codegen_llvm/src/typetree.rs
@@ -39,11 +39,7 @@ fn process_typetree_recursive(
 
         let mut indices = parent_indices.to_vec();
         if !parent_indices.is_empty() {
-            if rust_type.offset == -1 {
-                indices.push(-1);
-            } else {
-                indices.push(rust_type.offset as i64);
-            }
+            indices.push(rust_type.offset as i64);
         } else if rust_type.offset == -1 {
             indices.push(-1);
         } else {
@@ -52,7 +48,9 @@ fn process_typetree_recursive(
 
         enzyme_tt.insert(&indices, concrete_type, llcx);
 
-        if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() {
+        if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer
+            && !rust_type.child.0.is_empty()
+        {
             process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
         }
     }
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 581f20e8492..7ca2355947a 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -2252,6 +2252,61 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
 /// Generate TypeTree for a specific type.
 /// This function analyzes a Rust type and creates appropriate TypeTree metadata.
 pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
+    let mut visited = Vec::new();
+    typetree_from_ty_inner(tcx, ty, 0, &mut visited)
+}
+
+/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
+fn typetree_from_ty_inner<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    ty: Ty<'tcx>,
+    depth: usize,
+    visited: &mut Vec<Ty<'tcx>>,
+) -> TypeTree {
+    #[cfg(llvm_enzyme)]
+    {
+        unsafe extern "C" {
+            fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
+        }
+        let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
+        if depth > max_depth {
+            return TypeTree::new();
+        }
+    }
+
+    #[cfg(not(llvm_enzyme))]
+    if depth > 6 {
+        return TypeTree::new();
+    }
+
+    if visited.contains(&ty) {
+        return TypeTree::new();
+    }
+
+    visited.push(ty);
+    let result = typetree_from_ty_impl(tcx, ty, depth, visited);
+    visited.pop();
+    result
+}
+
+/// Implementation of TypeTree generation logic.
+fn typetree_from_ty_impl<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    ty: Ty<'tcx>,
+    depth: usize,
+    visited: &mut Vec<Ty<'tcx>>,
+) -> TypeTree {
+    typetree_from_ty_impl_inner(tcx, ty, depth, visited, false)
+}
+
+/// Internal implementation with context about whether this is for a reference target.
+fn typetree_from_ty_impl_inner<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    ty: Ty<'tcx>,
+    depth: usize,
+    visited: &mut Vec<Ty<'tcx>>,
+    is_reference_target: bool,
+) -> TypeTree {
     if ty.is_scalar() {
         let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
             (Kind::Integer, ty.primitive_size(tcx).bytes_usize())
@@ -2267,7 +2322,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
             (Kind::Integer, 0)
         };
 
-        return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]);
+        // Use offset 0 for scalars that are direct targets of references (like &f64)
+        // Use offset -1 for scalars used directly (like function return types)
+        let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 };
+        return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]);
     }
 
     if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
@@ -2277,7 +2335,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
             return TypeTree::new();
         };
 
-        let child = typetree_from_ty(tcx, inner_ty);
+        let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true);
         return TypeTree(vec![Type {
             offset: -1,
             size: tcx.data_layout.pointer_size().bytes_usize(),
@@ -2292,9 +2350,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
             if len == 0 {
                 return TypeTree::new();
             }
-
-            let element_tree = typetree_from_ty(tcx, *element_ty);
-
+            let element_tree =
+                typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
             let mut types = Vec::new();
             for elem_type in &element_tree.0 {
                 types.push(Type {
@@ -2311,7 +2368,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
 
     if ty.is_slice() {
         if let ty::Slice(element_ty) = ty.kind() {
-            let element_tree = typetree_from_ty(tcx, *element_ty);
+            let element_tree =
+                typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
             return element_tree;
         }
     }
@@ -2325,7 +2383,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
         let mut current_offset = 0;
 
         for tuple_ty in tuple_types.iter() {
-            let element_tree = typetree_from_ty(tcx, tuple_ty);
+            let element_tree =
+                typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false);
 
             let element_layout = tcx
                 .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
@@ -2361,7 +2420,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
 
                 for (field_idx, field_def) in adt_def.all_fields().enumerate() {
                     let field_ty = field_def.ty(tcx, args);
-                    let field_tree = typetree_from_ty(tcx, field_ty);
+                    let field_tree =
+                        typetree_from_ty_impl_inner(tcx, field_ty, depth + 1, visited, false);
 
                     let field_offset = layout.fields.offset(field_idx).bytes_usize();
 
diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check
index 3c02003c882..0b4c9119179 100644
--- a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check
+++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check
@@ -1,4 +1,4 @@
 // Check that enzyme_type attributes are present when TypeTree is enabled
 // This verifies our TypeTree metadata attachment is working
 
-CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
\ No newline at end of file
+CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check b/tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check
new file mode 100644
index 00000000000..1960e7b816c
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check
@@ -0,0 +1,3 @@
+CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_deep{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
+CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_graph{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Integer, [-1,16]:Integer, [-1,24]:Float@double}"
+CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_node{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs
new file mode 100644
index 00000000000..78718f3a215
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs
@@ -0,0 +1,9 @@
+//@ needs-enzyme
+//@ ignore-cross-compile
+
+use run_make_support::{llvm_filecheck, rfs, rustc};
+
+fn main() {
+    rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
+    llvm_filecheck().patterns("recursion.check").stdin_buf(rfs::read("test.ll")).run();
+}
diff --git a/tests/run-make/autodiff/type-trees/recursion-typetree/test.rs b/tests/run-make/autodiff/type-trees/recursion-typetree/test.rs
new file mode 100644
index 00000000000..9d40bec1bf1
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/recursion-typetree/test.rs
@@ -0,0 +1,100 @@
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+// Self-referential struct to test recursion detection
+#[derive(Clone)]
+struct Node {
+    value: f64,
+    next: Option<Box<Node>>,
+}
+
+// Mutually recursive structs to test cycle detection
+#[derive(Clone)]
+struct GraphNodeA {
+    value: f64,
+    connections: Vec<GraphNodeB>,
+}
+
+#[derive(Clone)]
+struct GraphNodeB {
+    weight: f64,
+    target: Option<Box<GraphNodeA>>,
+}
+
+#[autodiff_reverse(d_test_node, Duplicated, Active)]
+#[no_mangle]
+fn test_node(node: &Node) -> f64 {
+    node.value * 2.0
+}
+
+#[autodiff_reverse(d_test_graph, Duplicated, Active)]
+#[no_mangle]
+fn test_graph(a: &GraphNodeA) -> f64 {
+    a.value * 3.0
+}
+
+// Simple depth test - deeply nested but not circular
+#[derive(Clone)]
+struct Level1 {
+    val: f64,
+    next: Option<Box<Level2>>,
+}
+#[derive(Clone)]
+struct Level2 {
+    val: f64,
+    next: Option<Box<Level3>>,
+}
+#[derive(Clone)]
+struct Level3 {
+    val: f64,
+    next: Option<Box<Level4>>,
+}
+#[derive(Clone)]
+struct Level4 {
+    val: f64,
+    next: Option<Box<Level5>>,
+}
+#[derive(Clone)]
+struct Level5 {
+    val: f64,
+    next: Option<Box<Level6>>,
+}
+#[derive(Clone)]
+struct Level6 {
+    val: f64,
+    next: Option<Box<Level7>>,
+}
+#[derive(Clone)]
+struct Level7 {
+    val: f64,
+    next: Option<Box<Level8>>,
+}
+#[derive(Clone)]
+struct Level8 {
+    val: f64,
+}
+
+#[autodiff_reverse(d_test_deep, Duplicated, Active)]
+#[no_mangle]
+fn test_deep(deep: &Level1) -> f64 {
+    deep.val * 4.0
+}
+
+fn main() {
+    let node = Node { value: 1.0, next: None };
+
+    let graph = GraphNodeA { value: 2.0, connections: vec![] };
+
+    let deep = Level1 { val: 5.0, next: None };
+
+    let mut d_node = Node { value: 0.0, next: None };
+
+    let mut d_graph = GraphNodeA { value: 0.0, connections: vec![] };
+
+    let mut d_deep = Level1 { val: 0.0, next: None };
+
+    let _result1 = d_test_node(&node, &mut d_node, 1.0);
+    let _result2 = d_test_graph(&graph, &mut d_graph, 1.0);
+    let _result3 = d_test_deep(&deep, &mut d_deep, 1.0);
+}
diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check
index 733e46aa45a..23db64eea52 100644
--- a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check
+++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check
@@ -1,4 +1,4 @@
 ; Check that f128 TypeTree metadata is correctly generated
 ; Should show Float@fp128 for f128 values and Pointer for references
 
-CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}"
\ No newline at end of file
+CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@fp128}"
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check
index 9caca26b5cf..9adff68d36f 100644
--- a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check
+++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check
@@ -1,4 +1,4 @@
 ; Check that f16 TypeTree metadata is correctly generated
 ; Should show Float@half for f16 values and Pointer for references
 
-CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@half}"
\ No newline at end of file
+CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@half}"
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check
index ec12ba6b234..176630f57e8 100644
--- a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check
+++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check
@@ -1,4 +1,4 @@
 ; Check that f32 TypeTree metadata is correctly generated
 ; Should show Float@float for f32 values and Pointer for references
 
-CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"
\ No newline at end of file
+CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}"
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check
index f1af270824d..929cd379694 100644
--- a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check
+++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check
@@ -1,4 +1,4 @@
 ; Check that f64 TypeTree metadata is correctly generated  
 ; Should show Float@double for f64 values and Pointer for references
 
-CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}"
\ No newline at end of file
+CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}"
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check
index fd9a94be810..dee4aa5bbb6 100644
--- a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check
+++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check
@@ -1,4 +1,4 @@
 ; Check that i32 TypeTree metadata is correctly generated
 ; Should show Integer for i32 values and Pointer for references
 
-CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}"
\ No newline at end of file
+CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer}"
\ No newline at end of file