about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src
diff options
context:
space:
mode:
authorJana Dönszelmann <jonathan@donsz.nl>2025-06-25 22:14:55 +0200
committerGitHub <noreply@github.com>2025-06-25 22:14:55 +0200
commit69b11c64ebf9798984f8ceb250cb1b755488aaf1 (patch)
treee1b74e2accefb185ea36ca3d9478c78ca05f518c /compiler/rustc_codegen_llvm/src
parent63c5a84b74091970026e9ff023ff4ba1fe01db44 (diff)
parent7b1c89f2b58515bba671de38ba974edb58429973 (diff)
downloadrust-69b11c64ebf9798984f8ceb250cb1b755488aaf1.tar.gz
rust-69b11c64ebf9798984f8ceb250cb1b755488aaf1.zip
Rollup merge of #142809 - KMJ-007:ad-type-analysis-flag, r=ZuseZ4
Add PrintTAFn flag for targeted type analysis printing

## Summary
This PR adds a new `PrintTAFn` flag to the `-Z autodiff` option that allows printing type analysis information for a specific function, rather than all functions.

## Changes

### New Flag
- Added `PrintTAFn=<function_name>` option to `-Z autodiff`
- Usage: `-Z autodiff=Enable,PrintTAFn=my_function_name`

### Implementation Details
- **Rust side**: Added `PrintTAFn(String)` variant to `AutoDiff` enum
- **Parser**: Updated `parse_autodiff` to handle `PrintTAFn=<function_name>` syntax with proper error handling
- **FFI**: Added `set_print_type_fun` function to interface with Enzyme's `FunctionToAnalyze` command line option
- **Documentation**: Updated help text and documentation for the new flag

### Files Modified
- `compiler/rustc_session/src/config.rs`: Added `PrintTAFn(String)` variant
- `compiler/rustc_session/src/options.rs`: Updated parser and help text (now shows `PrintTAFn` in the list)
- `compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs`: Added FFI function and static variable
- `compiler/rustc_codegen_llvm/src/back/lto.rs`: Added handling for new flag
- `src/doc/rustc-dev-guide/src/autodiff/flags.md`: Updated documentation
- `src/doc/unstable-book/src/compiler-flags/autodiff.md`: Updated documentation

## Testing
The flag can be tested with:
```bash
rustc +enzyme -Z autodiff=Enable,PrintTAFn=square test.rs
```

This will print type analysis information only for the function named "square" instead of all functions.

## Error Handling
The parser includes proper error handling:
- Missing argument: `PrintTAFn` without `=<function_name>` will show an error
- Unknown options: Invalid autodiff options will be reported

r? ```@ZuseZ4```
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
-rw-r--r--compiler/rustc_codegen_llvm/src/back/lto.rs6
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs17
2 files changed, 22 insertions, 1 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index ee46b49a094..9c62244f3c9 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -587,7 +587,7 @@ fn thin_lto(
 }
 
 fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
-    for &val in ad {
+    for val in ad {
         // We intentionally don't use a wildcard, to not forget handling anything new.
         match val {
             config::AutoDiff::PrintPerf => {
@@ -599,6 +599,10 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
             config::AutoDiff::PrintTA => {
                 llvm::set_print_type(true);
             }
+            config::AutoDiff::PrintTAFn(fun) => {
+                llvm::set_print_type(true); // Enable general type printing
+                llvm::set_print_type_fun(&fun); // Set specific function to analyze
+            }
             config::AutoDiff::Inline => {
                 llvm::set_inline(true);
             }
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index 2ad39fc8538..b94716b89d6 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -57,14 +57,19 @@ pub(crate) use self::Enzyme_AD::*;
 
 #[cfg(llvm_enzyme)]
 pub(crate) mod Enzyme_AD {
+    use std::ffi::{CString, c_char};
+
     use libc::c_void;
+
     unsafe extern "C" {
         pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
+        pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
     }
     unsafe extern "C" {
         static mut EnzymePrintPerf: c_void;
         static mut EnzymePrintActivity: c_void;
         static mut EnzymePrintType: c_void;
+        static mut EnzymeFunctionToAnalyze: c_void;
         static mut EnzymePrint: c_void;
         static mut EnzymeStrictAliasing: c_void;
         static mut looseTypeAnalysis: c_void;
@@ -86,6 +91,15 @@ pub(crate) mod Enzyme_AD {
             EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
         }
     }
+    pub(crate) fn set_print_type_fun(fun_name: &str) {
+        let c_fun_name = CString::new(fun_name).unwrap();
+        unsafe {
+            EnzymeSetCLString(
+                std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze),
+                c_fun_name.as_ptr() as *const c_char,
+            );
+        }
+    }
     pub(crate) fn set_print(print: bool) {
         unsafe {
             EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
@@ -132,6 +146,9 @@ pub(crate) mod Fallback_AD {
     pub(crate) fn set_print_type(print: bool) {
         unimplemented!()
     }
+    pub(crate) fn set_print_type_fun(fun_name: &str) {
+        unimplemented!()
+    }
     pub(crate) fn set_print(print: bool) {
         unimplemented!()
     }