about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_codegen_llvm/src/back/lto.rs6
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs17
-rw-r--r--compiler/rustc_session/src/config.rs4
-rw-r--r--compiler/rustc_session/src/options.rs17
-rw-r--r--src/doc/rustc-dev-guide/src/autodiff/flags.md1
-rw-r--r--src/doc/unstable-book/src/compiler-flags/autodiff.md1
m---------src/tools/enzyme0
7 files changed, 42 insertions, 4 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index ee46b49a094..9c62244f3c9 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -587,7 +587,7 @@ fn thin_lto(
 }
 
 fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
-    for &val in ad {
+    for val in ad {
         // We intentionally don't use a wildcard, to not forget handling anything new.
         match val {
             config::AutoDiff::PrintPerf => {
@@ -599,6 +599,10 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
             config::AutoDiff::PrintTA => {
                 llvm::set_print_type(true);
             }
+            config::AutoDiff::PrintTAFn(fun) => {
+                llvm::set_print_type(true); // Enable general type printing
+                llvm::set_print_type_fun(&fun); // Set specific function to analyze
+            }
             config::AutoDiff::Inline => {
                 llvm::set_inline(true);
             }
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index 2ad39fc8538..b94716b89d6 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -57,14 +57,19 @@ pub(crate) use self::Enzyme_AD::*;
 
 #[cfg(llvm_enzyme)]
 pub(crate) mod Enzyme_AD {
+    use std::ffi::{CString, c_char};
+
     use libc::c_void;
+
     unsafe extern "C" {
         pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
+        pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char);
     }
     unsafe extern "C" {
         static mut EnzymePrintPerf: c_void;
         static mut EnzymePrintActivity: c_void;
         static mut EnzymePrintType: c_void;
+        static mut EnzymeFunctionToAnalyze: c_void;
         static mut EnzymePrint: c_void;
         static mut EnzymeStrictAliasing: c_void;
         static mut looseTypeAnalysis: c_void;
@@ -86,6 +91,15 @@ pub(crate) mod Enzyme_AD {
             EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
         }
     }
+    pub(crate) fn set_print_type_fun(fun_name: &str) {
+        let c_fun_name = CString::new(fun_name).unwrap();
+        unsafe {
+            EnzymeSetCLString(
+                std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze),
+                c_fun_name.as_ptr() as *const c_char,
+            );
+        }
+    }
     pub(crate) fn set_print(print: bool) {
         unsafe {
             EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
@@ -132,6 +146,9 @@ pub(crate) mod Fallback_AD {
     pub(crate) fn set_print_type(print: bool) {
         unimplemented!()
     }
+    pub(crate) fn set_print_type_fun(fun_name: &str) {
+        unimplemented!()
+    }
     pub(crate) fn set_print(print: bool) {
         unimplemented!()
     }
diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs
index c62e4ac30ea..73bb0471c22 100644
--- a/compiler/rustc_session/src/config.rs
+++ b/compiler/rustc_session/src/config.rs
@@ -227,13 +227,15 @@ pub enum CoverageLevel {
 }
 
 /// The different settings that the `-Z autodiff` flag can have.
-#[derive(Clone, Copy, PartialEq, Hash, Debug)]
+#[derive(Clone, PartialEq, Hash, Debug)]
 pub enum AutoDiff {
     /// Enable the autodiff opt pipeline
     Enable,
 
     /// Print TypeAnalysis information
     PrintTA,
+    /// Print TypeAnalysis information for a specific function
+    PrintTAFn(String),
     /// Print ActivityAnalysis Information
     PrintAA,
     /// Print Performance Warnings from Enzyme
diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs
index 232531dc673..ecd82c0cc01 100644
--- a/compiler/rustc_session/src/options.rs
+++ b/compiler/rustc_session/src/options.rs
@@ -725,7 +725,7 @@ mod desc {
     pub(crate) const parse_list: &str = "a space-separated list of strings";
     pub(crate) const parse_list_with_polarity: &str =
         "a comma-separated list of strings, with elements beginning with + or -";
-    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
+    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
     pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
     pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
     pub(crate) const parse_number: &str = "a number";
@@ -1365,9 +1365,22 @@ pub mod parse {
         let mut v: Vec<&str> = v.split(",").collect();
         v.sort_unstable();
         for &val in v.iter() {
-            let variant = match val {
+            // Split each entry on '=' if it has an argument
+            let (key, arg) = match val.split_once('=') {
+                Some((k, a)) => (k, Some(a)),
+                None => (val, None),
+            };
+
+            let variant = match key {
                 "Enable" => AutoDiff::Enable,
                 "PrintTA" => AutoDiff::PrintTA,
+                "PrintTAFn" => {
+                    if let Some(fun) = arg {
+                        AutoDiff::PrintTAFn(fun.to_string())
+                    } else {
+                        return false;
+                    }
+                }
                 "PrintAA" => AutoDiff::PrintAA,
                 "PrintPerf" => AutoDiff::PrintPerf,
                 "PrintSteps" => AutoDiff::PrintSteps,
diff --git a/src/doc/rustc-dev-guide/src/autodiff/flags.md b/src/doc/rustc-dev-guide/src/autodiff/flags.md
index 65287d9ba4c..efbb9ea3497 100644
--- a/src/doc/rustc-dev-guide/src/autodiff/flags.md
+++ b/src/doc/rustc-dev-guide/src/autodiff/flags.md
@@ -6,6 +6,7 @@ To support you while debugging or profiling, we have added support for an experi
 
 ```text
 PrintTA // Print TypeAnalysis information
+PrintTAFn // Print TypeAnalysis information for a specific function
 PrintAA // Print ActivityAnalysis information
 Print // Print differentiated functions while they are being generated and optimized
 PrintPerf // Print AD related Performance warnings
diff --git a/src/doc/unstable-book/src/compiler-flags/autodiff.md b/src/doc/unstable-book/src/compiler-flags/autodiff.md
index 95c188d1f3b..28d2ece1468 100644
--- a/src/doc/unstable-book/src/compiler-flags/autodiff.md
+++ b/src/doc/unstable-book/src/compiler-flags/autodiff.md
@@ -10,6 +10,7 @@ Multiple options can be separated with a comma. Valid options are:
 
 `Enable` - Required flag to enable autodiff
 `PrintTA` - print Type Analysis Information
+`PrintTAFn` - print Type Analysis Information for a specific function
 `PrintAA` - print Activity Analysis Information
 `PrintPerf` - print Performance Warnings from Enzyme
 `PrintSteps` - prints all intermediate transformations
diff --git a/src/tools/enzyme b/src/tools/enzyme
-Subproject a35f4f773118ccfbd8d05102eb12a34097b1ee5
+Subproject b5098d515d5e1bd0f5470553bc0d18da9794ca8