diff options
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/lto.rs | 6 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder.rs | 3 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 17 |
3 files changed, 23 insertions, 3 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index ee46b49a094..9c62244f3c9 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -587,7 +587,7 @@ fn thin_lto( } fn enable_autodiff_settings(ad: &[config::AutoDiff]) { - for &val in ad { + for val in ad { // We intentionally don't use a wildcard, to not forget handling anything new. match val { config::AutoDiff::PrintPerf => { @@ -599,6 +599,10 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) { config::AutoDiff::PrintTA => { llvm::set_print_type(true); } + config::AutoDiff::PrintTAFn(fun) => { + llvm::set_print_type(true); // Enable general type printing + llvm::set_print_type_fun(&fun); // Set specific function to analyze + } config::AutoDiff::Inline => { llvm::set_inline(true); } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 5e9594dd06b..d0aa7320b4b 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -1166,11 +1166,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { (self.extract_value(landing_pad, 0), self.extract_value(landing_pad, 1)) } - fn filter_landing_pad(&mut self, pers_fn: &'ll Value) -> (&'ll Value, &'ll Value) { + fn filter_landing_pad(&mut self, pers_fn: &'ll Value) { let ty = self.type_struct(&[self.type_ptr(), self.type_i32()], false); let landing_pad = self.landing_pad(ty, pers_fn, 1); self.add_clause(landing_pad, self.const_array(self.type_ptr(), &[])); - (self.extract_value(landing_pad, 0), self.extract_value(landing_pad, 1)) } fn resume(&mut self, exn0: &'ll Value, exn1: &'ll Value) { diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 2ad39fc8538..b94716b89d6 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -57,14 +57,19 @@ pub(crate) use self::Enzyme_AD::*; #[cfg(llvm_enzyme)] pub(crate) mod Enzyme_AD { + use std::ffi::{CString, c_char}; + use libc::c_void; + unsafe extern "C" { pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); + pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char); } unsafe extern "C" { static mut EnzymePrintPerf: c_void; static mut EnzymePrintActivity: c_void; static mut EnzymePrintType: c_void; + static mut EnzymeFunctionToAnalyze: c_void; static mut EnzymePrint: c_void; static mut EnzymeStrictAliasing: c_void; static mut looseTypeAnalysis: c_void; @@ -86,6 +91,15 @@ pub(crate) mod Enzyme_AD { EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8); } } + pub(crate) fn set_print_type_fun(fun_name: &str) { + let c_fun_name = CString::new(fun_name).unwrap(); + unsafe { + EnzymeSetCLString( + std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze), + c_fun_name.as_ptr() as *const c_char, + ); + } + } pub(crate) fn set_print(print: bool) { unsafe { EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8); @@ -132,6 +146,9 @@ pub(crate) mod Fallback_AD { pub(crate) fn set_print_type(print: bool) { unimplemented!() } + pub(crate) fn set_print_type_fun(fun_name: &str) { + unimplemented!() + } pub(crate) fn set_print(print: bool) { unimplemented!() } |
