about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
authorKaran Janthe <karanjanthe@gmail.com>2025-09-12 06:11:18 +0000
committerKaran Janthe <karanjanthe@gmail.com>2025-09-19 05:42:27 +0000
commit3ba5f19182bf7144c54cbbd0b7af3d4fe76b5317 (patch)
tree68eff90e616c4a68d470c5d38668b14d5bd2095d /compiler
parent4520926bb527bd43edbf0de84c2b0c6a9c5fc5ce (diff)
downloadrust-3ba5f19182bf7144c54cbbd0b7af3d4fe76b5317.tar.gz
rust-3ba5f19182bf7144c54cbbd0b7af3d4fe76b5317.zip
autodiff: typetree recursive depth query from enzyme with fallback
Signed-off-by: Karan Janthe <karanjanthe@gmail.com>
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs1
-rw-r--r--compiler/rustc_codegen_llvm/src/typetree.rs10
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp12
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs21
4 files changed, 24 insertions, 20 deletions
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index b604f5139c8..e63043b2122 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -127,7 +127,6 @@ pub(crate) mod Enzyme_AD {
         );
         pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char;
         pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char);
-        pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
     }
 
     unsafe extern "C" {
diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs
index 1a54884f6c5..7e263503700 100644
--- a/compiler/rustc_codegen_llvm/src/typetree.rs
+++ b/compiler/rustc_codegen_llvm/src/typetree.rs
@@ -1,5 +1,5 @@
 use rustc_ast::expand::typetree::FncTree;
-#[cfg(llvm_enzyme)]
+#[cfg(feature = "llvm_enzyme")]
 use {
     crate::attributes,
     rustc_ast::expand::typetree::TypeTree as RustTypeTree,
@@ -8,7 +8,7 @@ use {
 
 use crate::llvm::{self, Value};
 
-#[cfg(llvm_enzyme)]
+#[cfg(feature = "llvm_enzyme")]
 fn to_enzyme_typetree(
     rust_typetree: RustTypeTree,
     _data_layout: &str,
@@ -18,7 +18,7 @@ fn to_enzyme_typetree(
     process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
     enzyme_tt
 }
-#[cfg(llvm_enzyme)]
+#[cfg(feature = "llvm_enzyme")]
 fn process_typetree_recursive(
     enzyme_tt: &mut llvm::TypeTree,
     rust_typetree: &RustTypeTree,
@@ -56,7 +56,7 @@ fn process_typetree_recursive(
     }
 }
 
-#[cfg(llvm_enzyme)]
+#[cfg(feature = "llvm_enzyme")]
 pub(crate) fn add_tt<'ll>(
     llmod: &'ll llvm::Module,
     llcx: &'ll llvm::Context,
@@ -111,7 +111,7 @@ pub(crate) fn add_tt<'ll>(
     }
 }
 
-#[cfg(not(llvm_enzyme))]
+#[cfg(not(feature = "llvm_enzyme"))]
 pub(crate) fn add_tt<'ll>(
     _llmod: &'ll llvm::Module,
     _llcx: &'ll llvm::Context,
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index 64151962321..c1a924a87e4 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -1847,3 +1847,15 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
   MD.NoHWAddress = true;
   GV.setSanitizerMetadata(MD);
 }
+
+#ifdef ENZYME
+extern "C" {
+extern llvm::cl::opt<unsigned> EnzymeMaxTypeDepth;
+}
+
+extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; }
+#else
+extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() {
+  return 6; // Default fallback depth
+}
+#endif
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 7ca2355947a..ce4de6b95e0 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -63,7 +63,7 @@ pub use rustc_type_ir::solve::SizedTraitKind;
 pub use rustc_type_ir::*;
 #[allow(hidden_glob_reexports, unused_imports)]
 use rustc_type_ir::{InferCtxtLike, Interner};
-use tracing::{debug, instrument};
+use tracing::{debug, instrument, trace};
 pub use vtable::*;
 use {rustc_ast as ast, rustc_hir as hir};
 
@@ -2256,6 +2256,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
     typetree_from_ty_inner(tcx, ty, 0, &mut visited)
 }
 
+/// Maximum recursion depth for TypeTree generation to prevent stack overflow
+/// from pathological deeply nested types. Combined with cycle detection.
+const MAX_TYPETREE_DEPTH: usize = 6;
+
 /// Internal recursive function for TypeTree generation with cycle detection and depth limiting.
 fn typetree_from_ty_inner<'tcx>(
     tcx: TyCtxt<'tcx>,
@@ -2263,19 +2267,8 @@ fn typetree_from_ty_inner<'tcx>(
     depth: usize,
     visited: &mut Vec<Ty<'tcx>>,
 ) -> TypeTree {
-    #[cfg(llvm_enzyme)]
-    {
-        unsafe extern "C" {
-            fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint;
-        }
-        let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize;
-        if depth > max_depth {
-            return TypeTree::new();
-        }
-    }
-
-    #[cfg(not(llvm_enzyme))]
-    if depth > 6 {
+    if depth >= MAX_TYPETREE_DEPTH {
+        trace!("typetree depth limit {} reached for type: {}", MAX_TYPETREE_DEPTH, ty);
         return TypeTree::new();
     }