about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMatthias Krüger <476013+matthiaskrgr@users.noreply.github.com>2025-04-24 17:19:44 +0200
committerGitHub <noreply@github.com>2025-04-24 17:19:44 +0200
commitc3f811f02ff3fe52d7ceaa7aa15f44059221a48d (patch)
treeb9200bcf66cdd61ba88242db949d8c5c0b367345
parent27eac4b6d31731f59917d623d726698f9ec9ddc2 (diff)
parentf79a992bffaad56503ef483a9ad830d66dfe3311 (diff)
downloadrust-c3f811f02ff3fe52d7ceaa7aa15f44059221a48d.tar.gz
rust-c3f811f02ff3fe52d7ceaa7aa15f44059221a48d.zip
Rollup merge of #139700 - EnzymeAD:autodiff-flags, r=oli-obk
Autodiff flags

Interestingly, it seems that some other projects have conflicts with exactly the same LLVM optimization passes as autodiff.
At least `LLVMRustOptimize` has exactly the flags that we need to disable problematic opt passes.

This PR enables us to compile code where users differentiate two identical functions in the same module. This has been especially common in test cases, but it's not impossible to encounter in the wild.

It also enables two new flags for testing/debugging. I consider writing an MCP to upgrade PrintPasses to be a standalone -Z flag, since it is *not* the same as `-Z print-llvm-passes`, which IMHO gives less useful output. A discussion can be found here: [#t-compiler/llvm > Print llvm passes. @ 💬](https://rust-lang.zulipchat.com/#narrow/channel/187780-t-compiler.2Fllvm/topic/Print.20llvm.20passes.2E/near/511533038)

Finally, it improves `PrintModBefore` and `PrintModAfter`. They used to work reliable, but now we just schedule enzyme as part of an existing ModulePassManager (MPM). Since Enzyme is last in the MPM scheduling, PrintModBefore became very inaccurate. It used to print the input module, which we gave to the Enzyme and was great to create llvm-ir reproducer. However, lately the MPM would run the whole `default<O3>` pipeline, which heavily modifies the llvm module, before we pass it to Enzyme. That made it impossible to use the flag to create llvm-ir reproducers for Enzyme bugs. We now schedule a PrintModule pass just before Enzyme, solving this problem.

Based on the PrintPass output, it also _seems_ like changing `registerEnzymeAndPassPipeline(PB, true);` to `registerEnzymeAndPassPipeline(PB, false);` has no effect. In theory, the bool should tell Enzyme to schedule some helpful passes in the PassBuilder. However, since it doesn't do anything and I'm not 100% sure anymore on whether we really need it, I'll just disable it for now and postpone investigations.

r? ``@oli-obk``

closes #139471

Tracking:

- https://github.com/rust-lang/rust/issues/124509
-rw-r--r--compiler/rustc_codegen_llvm/src/back/lto.rs40
-rw-r--r--compiler/rustc_codegen_llvm/src/back/write.rs16
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs2
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/ffi.rs3
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp30
-rw-r--r--compiler/rustc_session/src/config.rs4
-rw-r--r--compiler/rustc_session/src/options.rs6
-rw-r--r--tests/codegen/autodiff/identical_fnc.rs45
8 files changed, 123 insertions, 23 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index a8b49e9552c..925898d8173 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -584,12 +584,10 @@ fn thin_lto(
     }
 }
 
-fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
+fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
     for &val in ad {
+        // We intentionally don't use a wildcard, to not forget handling anything new.
         match val {
-            config::AutoDiff::PrintModBefore => {
-                unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
-            }
             config::AutoDiff::PrintPerf => {
                 llvm::set_print_perf(true);
             }
@@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
                 llvm::set_inline(true);
             }
             config::AutoDiff::LooseTypes => {
-                llvm::set_loose_types(false);
+                llvm::set_loose_types(true);
             }
             config::AutoDiff::PrintSteps => {
                 llvm::set_print(true);
             }
-            // We handle this below
+            // We handle this in the PassWrapper.cpp
+            config::AutoDiff::PrintPasses => {}
+            // We handle this in the PassWrapper.cpp
+            config::AutoDiff::PrintModBefore => {}
+            // We handle this in the PassWrapper.cpp
             config::AutoDiff::PrintModAfter => {}
-            // We handle this below
+            // We handle this in the PassWrapper.cpp
             config::AutoDiff::PrintModFinal => {}
             // This is required and already checked
             config::AutoDiff::Enable => {}
+            // We handle this below
+            config::AutoDiff::NoPostopt => {}
         }
     }
     // This helps with handling enums for now.
@@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
     // We then run the llvm_optimize function a second time, to optimize the code which we generated
     // in the enzyme differentiation pass.
     let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
-    let stage =
-        if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
+    let stage = if thin {
+        write::AutodiffStage::PreAD
+    } else {
+        if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
+    };
 
     if enable_ad {
-        enable_autodiff_settings(&config.autodiff, module);
+        enable_autodiff_settings(&config.autodiff);
     }
 
     unsafe {
         write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
     }
 
-    if cfg!(llvm_enzyme) && enable_ad {
-        // This is the post-autodiff IR, mainly used for testing and educational purposes.
-        if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
-            unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
-        }
-
+    if cfg!(llvm_enzyme) && enable_ad && !thin {
         let opt_stage = llvm::OptStage::FatLTO;
         let stage = write::AutodiffStage::PostAD;
-        unsafe {
-            write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
+        if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
+            unsafe {
+                write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
+            }
         }
 
         // This is the final IR, so people should be able to inspect the optimized autodiff output,
diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs
index 18d221d232e..4ac77c8f7f1 100644
--- a/compiler/rustc_codegen_llvm/src/back/write.rs
+++ b/compiler/rustc_codegen_llvm/src/back/write.rs
@@ -572,6 +572,10 @@ pub(crate) unsafe fn llvm_optimize(
 
     let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
     let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
+    let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore);
+    let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter);
+    let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses);
+    let merge_functions;
     let unroll_loops;
     let vectorize_slp;
     let vectorize_loop;
@@ -579,13 +583,20 @@ pub(crate) unsafe fn llvm_optimize(
     // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
     // optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
     // We therefore have two calls to llvm_optimize, if autodiff is used.
+    //
+    // We also must disable merge_functions, since autodiff placeholder/dummy bodies tend to be
+    // identical. We run opts before AD, so there is a chance that LLVM will merge our dummies.
+    // In that case, we lack some dummy bodies and can't replace them with the real AD code anymore.
+    // We then would need to abort compilation. This was especially common in test cases.
     if consider_ad && autodiff_stage != AutodiffStage::PostAD {
+        merge_functions = false;
         unroll_loops = false;
         vectorize_slp = false;
         vectorize_loop = false;
     } else {
         unroll_loops =
             opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+        merge_functions = config.merge_functions;
         vectorize_slp = config.vectorize_slp;
         vectorize_loop = config.vectorize_loop;
     }
@@ -663,13 +674,16 @@ pub(crate) unsafe fn llvm_optimize(
             thin_lto_buffer,
             config.emit_thin_lto,
             config.emit_thin_lto_summary,
-            config.merge_functions,
+            merge_functions,
             unroll_loops,
             vectorize_slp,
             vectorize_loop,
             config.no_builtins,
             config.emit_lifetime_markers,
             run_enzyme,
+            print_before_enzyme,
+            print_after_enzyme,
+            print_passes,
             sanitizer_options.as_ref(),
             pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
             pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index e7c071d05aa..0147bd5a665 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -473,7 +473,7 @@ pub(crate) fn differentiate<'ll>(
         return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
     }
 
-    // Before dumping the module, we want all the TypeTrees to become part of the module.
+    // Here we replace the placeholder code with the actual autodiff code, which calls Enzyme.
     for item in diff_items.iter() {
         let name = item.source.clone();
         let fn_def: Option<&llvm::Value> = cx.get_function(&name);
diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
index 9ff04f72903..ffb490dcdc2 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
@@ -2454,6 +2454,9 @@ unsafe extern "C" {
         DisableSimplifyLibCalls: bool,
         EmitLifetimeMarkers: bool,
         RunEnzyme: bool,
+        PrintBeforeEnzyme: bool,
+        PrintAfterEnzyme: bool,
+        PrintPasses: bool,
         SanitizerOptions: Option<&SanitizerOptions>,
         PGOGenPath: *const c_char,
         PGOUsePath: *const c_char,
diff --git a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
index e02c80c50b1..8bee051dd4c 100644
--- a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
@@ -14,6 +14,7 @@
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Verifier.h"
+#include "llvm/IRPrinter/IRPrintingPasses.h"
 #include "llvm/LTO/LTO.h"
 #include "llvm/MC/MCSubtargetInfo.h"
 #include "llvm/MC/TargetRegistry.h"
@@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
     bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
     bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
     bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
-    bool EmitLifetimeMarkers, bool RunEnzyme,
+    bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
+    bool PrintAfterEnzyme, bool PrintPasses,
     LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
     const char *PGOUsePath, bool InstrumentCoverage,
     const char *InstrProfileOutput, const char *PGOSampleUsePath,
@@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize(
   // now load "-enzyme" pass:
 #ifdef ENZYME
   if (RunEnzyme) {
-    registerEnzymeAndPassPipeline(PB, true);
+
+    if (PrintBeforeEnzyme) {
+      // Handle the Rust flag `-Zautodiff=PrintModBefore`.
+      std::string Banner = "Module before EnzymeNewPM";
+      MPM.addPass(PrintModulePass(outs(), Banner, true, false));
+    }
+
+    registerEnzymeAndPassPipeline(PB, false);
     if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
       std::string ErrMsg = toString(std::move(Err));
       LLVMRustSetLastError(ErrMsg.c_str());
       return LLVMRustResult::Failure;
     }
+
+    if (PrintAfterEnzyme) {
+      // Handle the Rust flag `-Zautodiff=PrintModAfter`.
+      std::string Banner = "Module after EnzymeNewPM";
+      MPM.addPass(PrintModulePass(outs(), Banner, true, false));
+    }
   }
 #endif
+  if (PrintPasses) {
+    // Print all passes from the PM:
+    std::string Pipeline;
+    raw_string_ostream SOS(Pipeline);
+    MPM.printPipeline(SOS, [&PIC](StringRef ClassName) {
+      auto PassName = PIC.getPassNameForClassName(ClassName);
+      return PassName.empty() ? ClassName : PassName;
+    });
+    outs() << Pipeline;
+    outs() << "\n";
+  }
 
   // Upgrade all calls to old intrinsics first.
   for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)
diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs
index bc92b95ce71..02c164a706c 100644
--- a/compiler/rustc_session/src/config.rs
+++ b/compiler/rustc_session/src/config.rs
@@ -246,6 +246,10 @@ pub enum AutoDiff {
     /// Print the module after running autodiff and optimizations.
     PrintModFinal,
 
+    /// Print all passes scheduled by LLVM
+    PrintPasses,
+    /// Disable extra opt run after running autodiff
+    NoPostopt,
     /// Enzyme's loose type debug helper (can cause incorrect gradients!!)
     /// Usable in cases where Enzyme errors with `can not deduce type of X`.
     LooseTypes,
diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs
index 36eee5f3086..5f4695fb184 100644
--- a/compiler/rustc_session/src/options.rs
+++ b/compiler/rustc_session/src/options.rs
@@ -711,7 +711,7 @@ mod desc {
     pub(crate) const parse_list: &str = "a space-separated list of strings";
     pub(crate) const parse_list_with_polarity: &str =
         "a comma-separated list of strings, with elements beginning with + or -";
-    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
+    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
     pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
     pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
     pub(crate) const parse_number: &str = "a number";
@@ -1360,6 +1360,8 @@ pub mod parse {
                 "PrintModBefore" => AutoDiff::PrintModBefore,
                 "PrintModAfter" => AutoDiff::PrintModAfter,
                 "PrintModFinal" => AutoDiff::PrintModFinal,
+                "NoPostopt" => AutoDiff::NoPostopt,
+                "PrintPasses" => AutoDiff::PrintPasses,
                 "LooseTypes" => AutoDiff::LooseTypes,
                 "Inline" => AutoDiff::Inline,
                 _ => {
@@ -2098,6 +2100,8 @@ options! {
         `=PrintModBefore`
         `=PrintModAfter`
         `=PrintModFinal`
+        `=PrintPasses`,
+        `=NoPostopt`
         `=LooseTypes`
         `=Inline`
         Multiple options can be combined with commas."),
diff --git a/tests/codegen/autodiff/identical_fnc.rs b/tests/codegen/autodiff/identical_fnc.rs
new file mode 100644
index 00000000000..1c3277f52b4
--- /dev/null
+++ b/tests/codegen/autodiff/identical_fnc.rs
@@ -0,0 +1,45 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+//
+// Each autodiff invocation creates a new placeholder function, which we will replace on llvm-ir
+// level. If a user tries to differentiate two identical functions within the same compilation unit,
+// then LLVM might merge them in release mode before AD. In that case we can't rewrite one of the
+// merged placeholder function anymore, and compilation would fail. We prevent this by disabling
+// LLVM's merge_function pass before AD. Here we implicetely test that our solution keeps working.
+// We also explicetly test that we keep running merge_function after AD, by checking for two
+// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
+#![feature(autodiff)]
+
+use std::autodiff::autodiff;
+
+#[autodiff(d_square, Reverse, Duplicated, Active)]
+fn square(x: &f64) -> f64 {
+    x * x
+}
+
+#[autodiff(d_square2, Reverse, Duplicated, Active)]
+fn square2(x: &f64) -> f64 {
+    x * x
+}
+
+// CHECK:; identical_fnc::main
+// CHECK-NEXT:; Function Attrs:
+// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
+// CHECK-NEXT:start:
+// CHECK-NOT:br
+// CHECK-NOT:ret
+// CHECK:; call identical_fnc::d_square
+// CHECK-NEXT:  call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
+// CHECK-NEXT:; call identical_fnc::d_square
+// CHECK-NEXT:  call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
+
+fn main() {
+    let x = std::hint::black_box(3.0);
+    let mut dx1 = std::hint::black_box(1.0);
+    let mut dx2 = std::hint::black_box(1.0);
+    let _ = d_square(&x, &mut dx1, 1.0);
+    let _ = d_square2(&x, &mut dx2, 1.0);
+    assert_eq!(dx1, 6.0);
+    assert_eq!(dx2, 6.0);
+}