about summary refs log tree commit diff
path: root/compiler/rustc_middle
diff options
context:
space:
mode:
authorKaran Janthe <karanjanthe@gmail.com>2025-09-11 07:30:35 +0000
committerKaran Janthe <karanjanthe@gmail.com>2025-09-19 04:11:35 +0000
commit4520926bb527bd43edbf0de84c2b0c6a9c5fc5ce (patch)
tree5bc2f8d378284735c25a81d8e1fe18b0d8a4978c /compiler/rustc_middle
parent4f3f0f48e7b1e61818b2bcbe4451f89bb4f47049 (diff)
downloadrust-4520926bb527bd43edbf0de84c2b0c6a9c5fc5ce.tar.gz
rust-4520926bb527bd43edbf0de84c2b0c6a9c5fc5ce.zip
autodiff: recurion added for typetree
Diffstat (limited to 'compiler/rustc_middle')
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs76
1 files changed, 68 insertions, 8 deletions
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 581f20e8492..7ca2355947a 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -2252,6 +2252,61 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
 /// Generate TypeTree for a specific type.
 /// This function analyzes a Rust type and creates appropriate TypeTree metadata.
 pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
+    let mut visited = Vec::new();
+    typetree_from_ty_inner(tcx, ty, 0, &mut visited)
+}
+
+/// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
+fn typetree_from_ty_inner<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    ty: Ty<'tcx>,
+    depth: usize,
+    visited: &mut Vec<Ty<'tcx>>,
+) -> TypeTree {
+    #[cfg(llvm_enzyme)]
+    {
+        unsafe extern "C" {
+            fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
+        }
+        let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
+        if depth > max_depth {
+            return TypeTree::new();
+        }
+    }
+
+    #[cfg(not(llvm_enzyme))]
+    if depth > 6 {
+        return TypeTree::new();
+    }
+
+    if visited.contains(&ty) {
+        return TypeTree::new();
+    }
+
+    visited.push(ty);
+    let result = typetree_from_ty_impl(tcx, ty, depth, visited);
+    visited.pop();
+    result
+}
+
+/// Implementation of TypeTree generation logic.
+fn typetree_from_ty_impl<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    ty: Ty<'tcx>,
+    depth: usize,
+    visited: &mut Vec<Ty<'tcx>>,
+) -> TypeTree {
+    typetree_from_ty_impl_inner(tcx, ty, depth, visited, false)
+}
+
+/// Internal implementation with context about whether this is for a reference target.
+fn typetree_from_ty_impl_inner<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    ty: Ty<'tcx>,
+    depth: usize,
+    visited: &mut Vec<Ty<'tcx>>,
+    is_reference_target: bool,
+) -> TypeTree {
     if ty.is_scalar() {
         let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
             (Kind::Integer, ty.primitive_size(tcx).bytes_usize())
@@ -2267,7 +2322,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
             (Kind::Integer, 0)
         };
 
-        return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]);
+        // Use offset 0 for scalars that are direct targets of references (like &f64)
+        // Use offset -1 for scalars used directly (like function return types)
+        let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 };
+        return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]);
     }
 
     if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
@@ -2277,7 +2335,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
             return TypeTree::new();
         };
 
-        let child = typetree_from_ty(tcx, inner_ty);
+        let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true);
         return TypeTree(vec![Type {
             offset: -1,
             size: tcx.data_layout.pointer_size().bytes_usize(),
@@ -2292,9 +2350,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
             if len == 0 {
                 return TypeTree::new();
             }
-
-            let element_tree = typetree_from_ty(tcx, *element_ty);
-
+            let element_tree =
+                typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
             let mut types = Vec::new();
             for elem_type in &element_tree.0 {
                 types.push(Type {
@@ -2311,7 +2368,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
 
     if ty.is_slice() {
         if let ty::Slice(element_ty) = ty.kind() {
-            let element_tree = typetree_from_ty(tcx, *element_ty);
+            let element_tree =
+                typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false);
             return element_tree;
         }
     }
@@ -2325,7 +2383,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
         let mut current_offset = 0;
 
         for tuple_ty in tuple_types.iter() {
-            let element_tree = typetree_from_ty(tcx, tuple_ty);
+            let element_tree =
+                typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false);
 
             let element_layout = tcx
                 .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
@@ -2361,7 +2420,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
 
                 for (field_idx, field_def) in adt_def.all_fields().enumerate() {
                     let field_ty = field_def.ty(tcx, args);
-                    let field_tree = typetree_from_ty(tcx, field_ty);
+                    let field_tree =
+                        typetree_from_ty_impl_inner(tcx, field_ty, depth + 1, visited, false);
 
                     let field_offset = layout.fields.offset(field_idx).bytes_usize();