diff options
| author | Jana Dönszelmann <jonathan@donsz.nl> | 2025-06-25 22:14:55 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-25 22:14:55 +0200 |
| commit | 69b11c64ebf9798984f8ceb250cb1b755488aaf1 (patch) | |
| tree | e1b74e2accefb185ea36ca3d9478c78ca05f518c /compiler/rustc_codegen_llvm/src | |
| parent | 63c5a84b74091970026e9ff023ff4ba1fe01db44 (diff) | |
| parent | 7b1c89f2b58515bba671de38ba974edb58429973 (diff) | |
| download | rust-69b11c64ebf9798984f8ceb250cb1b755488aaf1.tar.gz rust-69b11c64ebf9798984f8ceb250cb1b755488aaf1.zip | |
Rollup merge of #142809 - KMJ-007:ad-type-analysis-flag, r=ZuseZ4
Add PrintTAFn flag for targeted type analysis printing ## Summary This PR adds a new `PrintTAFn` flag to the `-Z autodiff` option that allows printing type analysis information for a specific function, rather than all functions. ## Changes ### New Flag - Added `PrintTAFn=<function_name>` option to `-Z autodiff` - Usage: `-Z autodiff=Enable,PrintTAFn=my_function_name` ### Implementation Details - **Rust side**: Added `PrintTAFn(String)` variant to `AutoDiff` enum - **Parser**: Updated `parse_autodiff` to handle `PrintTAFn=<function_name>` syntax with proper error handling - **FFI**: Added `set_print_type_fun` function to interface with Enzyme's `FunctionToAnalyze` command line option - **Documentation**: Updated help text and documentation for the new flag ### Files Modified - `compiler/rustc_session/src/config.rs`: Added `PrintTAFn(String)` variant - `compiler/rustc_session/src/options.rs`: Updated parser and help text (now shows `PrintTAFn` in the list) - `compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs`: Added FFI function and static variable - `compiler/rustc_codegen_llvm/src/back/lto.rs`: Added handling for new flag - `src/doc/rustc-dev-guide/src/autodiff/flags.md`: Updated documentation - `src/doc/unstable-book/src/compiler-flags/autodiff.md`: Updated documentation ## Testing The flag can be tested with: ```bash rustc +enzyme -Z autodiff=Enable,PrintTAFn=square test.rs ``` This will print type analysis information only for the function named "square" instead of all functions. ## Error Handling The parser includes proper error handling: - Missing argument: `PrintTAFn` without `=<function_name>` will show an error - Unknown options: Invalid autodiff options will be reported r? ```@ZuseZ4```
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/lto.rs | 6 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 17 |
2 files changed, 22 insertions, 1 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index ee46b49a094..9c62244f3c9 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -587,7 +587,7 @@ fn thin_lto( } fn enable_autodiff_settings(ad: &[config::AutoDiff]) { - for &val in ad { + for val in ad { // We intentionally don't use a wildcard, to not forget handling anything new. match val { config::AutoDiff::PrintPerf => { @@ -599,6 +599,10 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) { config::AutoDiff::PrintTA => { llvm::set_print_type(true); } + config::AutoDiff::PrintTAFn(fun) => { + llvm::set_print_type(true); // Enable general type printing + llvm::set_print_type_fun(&fun); // Set specific function to analyze + } config::AutoDiff::Inline => { llvm::set_inline(true); } diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 2ad39fc8538..b94716b89d6 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -57,14 +57,19 @@ pub(crate) use self::Enzyme_AD::*; #[cfg(llvm_enzyme)] pub(crate) mod Enzyme_AD { + use std::ffi::{CString, c_char}; + use libc::c_void; + unsafe extern "C" { pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); + pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char); } unsafe extern "C" { static mut EnzymePrintPerf: c_void; static mut EnzymePrintActivity: c_void; static mut EnzymePrintType: c_void; + static mut EnzymeFunctionToAnalyze: c_void; static mut EnzymePrint: c_void; static mut EnzymeStrictAliasing: c_void; static mut looseTypeAnalysis: c_void; @@ -86,6 +91,15 @@ pub(crate) mod Enzyme_AD { EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8); } } + pub(crate) fn set_print_type_fun(fun_name: &str) { + let c_fun_name = CString::new(fun_name).unwrap(); + unsafe { + EnzymeSetCLString( + std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze), + c_fun_name.as_ptr() as *const c_char, + ); + } + } pub(crate) fn set_print(print: bool) { unsafe { EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8); @@ -132,6 +146,9 @@ pub(crate) mod Fallback_AD { pub(crate) fn set_print_type(print: bool) { unimplemented!() } + pub(crate) fn set_print_type_fun(fun_name: &str) { + unimplemented!() + } pub(crate) fn set_print(print: bool) { unimplemented!() } |
