about summary refs log tree commit diff
path: root/compiler/rustc_session
diff options
context:
space:
mode:
authorJacob Pratt <jacob@jhpratt.dev>2025-01-31 00:26:30 -0500
committerGitHub <noreply@github.com>2025-01-31 00:26:30 -0500
commitc19c4b91f53dbe9bda18171b493514ed2bbec2b2 (patch)
treedf4f28f749f427d1e391919bc5cb937fe6f247c0 /compiler/rustc_session
parentae9dbf169fa43b45617107d8bd90eced26b1dea2 (diff)
parent1f30517d40a9a8fe3b89479891c7a167adb75cbd (diff)
downloadrust-c19c4b91f53dbe9bda18171b493514ed2bbec2b2.tar.gz
rust-c19c4b91f53dbe9bda18171b493514ed2bbec2b2.zip
Rollup merge of #133429 - EnzymeAD:autodiff-middle, r=oli-obk
Autodiff Upstreaming - rustc_codegen_ssa, rustc_middle

This PR should not be merged until the rustc_codegen_llvm part is merged.
I will also alter it a little based on what get's shaved off from the cg_llvm PR,
and address some of the feedback I received in the other PR (including cleanups).

I am putting it already up to
1) Discuss with `@jieyouxu` if there is more work needed to add tests to this and
2) Pray that there is someone reviewing who can tell me why some of my autodiff invocations get lost.

Re 1: My test require fat-lto. I also modify the compilation pipeline. So if there are any other llvm-ir tests in the same compilation unit then I will likely break them. Luckily there are two groups who currently have the same fat-lto requirement for their GPU code which I have for my autodiff code and both groups have some plans to enable support for thin-lto. Once either that work pans out, I'll copy it over for this feature. I will also work on not changing the optimization pipeline for functions not differentiated, but that will require some thoughts and engineering, so I think it would be good to be able to run the autodiff tests isolated from the rest for now. Can you guide me here please?
For context, here are some of my tests in the samples folder: https://github.com/EnzymeAD/rustbook

Re 2: This is a pretty serious issue, since it effectively prevents publishing libraries making use of autodiff: https://github.com/EnzymeAD/rust/issues/173. For some reason my dummy code persists till the end, so the code which calls autodiff, deletes the dummy, and inserts the code to compute the derivative never gets executed. To me it looks like the rustc_autodiff attribute just get's dropped, but I don't know WHY? Any help would be super appreciated, as rustc queries look a bit voodoo to me.

Tracking:

- https://github.com/rust-lang/rust/issues/124509

r? `@jieyouxu`
Diffstat (limited to 'compiler/rustc_session')
-rw-r--r--compiler/rustc_session/src/config.rs36
-rw-r--r--compiler/rustc_session/src/options.rs49
2 files changed, 84 insertions, 1 deletions
diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs
index 97bd2670aa6..c8a811985d5 100644
--- a/compiler/rustc_session/src/config.rs
+++ b/compiler/rustc_session/src/config.rs
@@ -189,6 +189,39 @@ pub enum CoverageLevel {
     Mcdc,
 }
 
+/// The different settings that the `-Z autodiff` flag can have.
+#[derive(Clone, Copy, PartialEq, Hash, Debug)]
+pub enum AutoDiff {
+    /// Print TypeAnalysis information
+    PrintTA,
+    /// Print ActivityAnalysis Information
+    PrintAA,
+    /// Print Performance Warnings from Enzyme
+    PrintPerf,
+    /// Combines the three print flags above.
+    Print,
+    /// Print the whole module, before running opts.
+    PrintModBefore,
+    /// Print the whole module just before we pass it to Enzyme.
+    /// For Debug purpose, prefer the OPT flag below
+    PrintModAfterOpts,
+    /// Print the module after Enzyme differentiated everything.
+    PrintModAfterEnzyme,
+
+    /// Enzyme's loose type debug helper (can cause incorrect gradients)
+    LooseTypes,
+
+    /// More flags
+    NoModOptAfter,
+    /// Tell Enzyme to run LLVM Opts on each function it generated. By default off,
+    /// since we already optimize the whole module after Enzyme is done.
+    EnableFncOpt,
+    NoVecUnroll,
+    RuntimeActivity,
+    /// Runs Enzyme specific Inlining
+    Inline,
+}
+
 /// Settings for `-Z instrument-xray` flag.
 #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
 pub struct InstrumentXRay {
@@ -2902,7 +2935,7 @@ pub(crate) mod dep_tracking {
     };
 
     use super::{
-        BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
+        AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
         CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn,
         InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail,
         LtoCli, MirStripDebugInfo, NextSolverConfig, OomStrategy, OptLevel, OutFileName,
@@ -2950,6 +2983,7 @@ pub(crate) mod dep_tracking {
     }
 
     impl_dep_tracking_hash_via_hash!(
+        AutoDiff,
         bool,
         usize,
         NonZero<usize>,
diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs
index c7af40a8ad4..e99f907c578 100644
--- a/compiler/rustc_session/src/options.rs
+++ b/compiler/rustc_session/src/options.rs
@@ -398,6 +398,7 @@ mod desc {
     pub(crate) const parse_list: &str = "a space-separated list of strings";
     pub(crate) const parse_list_with_polarity: &str =
         "a comma-separated list of strings, with elements beginning with + or -";
+    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Print`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfterOpts`, `PrintModAfterEnzyme`, `LooseTypes`, `NoModOptAfter`, `EnableFncOpt`, `NoVecUnroll`, `Inline`";
     pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
     pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
     pub(crate) const parse_number: &str = "a number";
@@ -1029,6 +1030,38 @@ pub mod parse {
         }
     }
 
+    pub(crate) fn parse_autodiff(slot: &mut Vec<AutoDiff>, v: Option<&str>) -> bool {
+        let Some(v) = v else {
+            *slot = vec![];
+            return true;
+        };
+        let mut v: Vec<&str> = v.split(",").collect();
+        v.sort_unstable();
+        for &val in v.iter() {
+            let variant = match val {
+                "PrintTA" => AutoDiff::PrintTA,
+                "PrintAA" => AutoDiff::PrintAA,
+                "PrintPerf" => AutoDiff::PrintPerf,
+                "Print" => AutoDiff::Print,
+                "PrintModBefore" => AutoDiff::PrintModBefore,
+                "PrintModAfterOpts" => AutoDiff::PrintModAfterOpts,
+                "PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme,
+                "LooseTypes" => AutoDiff::LooseTypes,
+                "NoModOptAfter" => AutoDiff::NoModOptAfter,
+                "EnableFncOpt" => AutoDiff::EnableFncOpt,
+                "NoVecUnroll" => AutoDiff::NoVecUnroll,
+                "Inline" => AutoDiff::Inline,
+                _ => {
+                    // FIXME(ZuseZ4): print an error saying which value is not recognized
+                    return false;
+                }
+            };
+            slot.push(variant);
+        }
+
+        true
+    }
+
     pub(crate) fn parse_instrument_coverage(
         slot: &mut InstrumentCoverage,
         v: Option<&str>,
@@ -1736,6 +1769,22 @@ options! {
          either `loaded` or `not-loaded`."),
     assume_incomplete_release: bool = (false, parse_bool, [TRACKED],
         "make cfg(version) treat the current version as incomplete (default: no)"),
+    autodiff: Vec<crate::config::AutoDiff> = (Vec::new(), parse_autodiff, [TRACKED],
+        "a list of optional autodiff flags to enable
+         Optional extra settings:
+         `=PrintTA`
+         `=PrintAA`
+         `=PrintPerf`
+         `=Print`
+         `=PrintModBefore`
+         `=PrintModAfterOpts`
+         `=PrintModAfterEnzyme`
+         `=LooseTypes`
+         `=NoModOptAfter`
+         `=EnableFncOpt`
+         `=NoVecUnroll`
+         `=Inline`
+         Multiple options can be combined with commas."),
     #[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]
     binary_dep_depinfo: bool = (false, parse_bool, [TRACKED],
         "include artifacts (sysroot, crate dependencies) used during compilation in dep-info \