about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs')
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs20
1 files changed, 18 insertions, 2 deletions
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index 2ad39fc8538..c696b8d8ff2 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -1,4 +1,3 @@
-#![allow(non_camel_case_types)]
 #![expect(dead_code)]
 
 use libc::{c_char, c_uint};
@@ -40,7 +39,7 @@ unsafe extern "C" {
     pub(crate) fn LLVMDumpValue(V: &Value);
     pub(crate) fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;
     pub(crate) fn LLVMGetReturnType(T: &Type) -> &Type;
-    pub(crate) fn LLVMGetParams(Fnc: &Value, parms: *mut &Value);
+    pub(crate) fn LLVMGetParams(Fnc: &Value, params: *mut &Value);
     pub(crate) fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>;
 }
 
@@ -57,14 +56,19 @@ pub(crate) use self::Enzyme_AD::*;
 
 #[cfg(llvm_enzyme)]
 pub(crate) mod Enzyme_AD {
+    use std::ffi::{CString, c_char};
+
     use libc::c_void;
+
     unsafe extern "C" {
         pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
+        pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
     }
     unsafe extern "C" {
         static mut EnzymePrintPerf: c_void;
         static mut EnzymePrintActivity: c_void;
         static mut EnzymePrintType: c_void;
+        static mut EnzymeFunctionToAnalyze: c_void;
         static mut EnzymePrint: c_void;
         static mut EnzymeStrictAliasing: c_void;
         static mut looseTypeAnalysis: c_void;
@@ -86,6 +90,15 @@ pub(crate) mod Enzyme_AD {
             EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
         }
     }
+    pub(crate) fn set_print_type_fun(fun_name: &str) {
+        let c_fun_name = CString::new(fun_name).unwrap();
+        unsafe {
+            EnzymeSetCLString(
+                std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze),
+                c_fun_name.as_ptr() as *const c_char,
+            );
+        }
+    }
     pub(crate) fn set_print(print: bool) {
         unsafe {
             EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
@@ -132,6 +145,9 @@ pub(crate) mod Fallback_AD {
     pub(crate) fn set_print_type(print: bool) {
         unimplemented!()
     }
+    pub(crate) fn set_print_type_fun(fun_name: &str) {
+        unimplemented!()
+    }
     pub(crate) fn set_print(print: bool) {
         unimplemented!()
     }