about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_codegen_llvm/src/back/lto.rs40
-rw-r--r--compiler/rustc_codegen_llvm/src/back/write.rs16
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs2
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/ffi.rs3
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp30
-rw-r--r--compiler/rustc_session/src/config.rs4
-rw-r--r--compiler/rustc_session/src/options.rs6
-rw-r--r--tests/codegen/autodiff/identical_fnc.rs45
8 files changed, 123 insertions, 23 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index a8b49e9552c..925898d8173 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -584,12 +584,10 @@ fn thin_lto(
     }
 }
 
-fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
+fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
     for &val in ad {
+        // We intentionally don't use a wildcard, to not forget handling anything new.
         match val {
-            config::AutoDiff::PrintModBefore => {
-                unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
-            }
             config::AutoDiff::PrintPerf => {
                 llvm::set_print_perf(true);
             }
@@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
                 llvm::set_inline(true);
             }
             config::AutoDiff::LooseTypes => {
-                llvm::set_loose_types(false);
+                llvm::set_loose_types(true);
             }
             config::AutoDiff::PrintSteps => {
                 llvm::set_print(true);
             }
-            // We handle this below
+            // We handle this in the PassWrapper.cpp
+            config::AutoDiff::PrintPasses => {}
+            // We handle this in the PassWrapper.cpp
+            config::AutoDiff::PrintModBefore => {}
+            // We handle this in the PassWrapper.cpp
             config::AutoDiff::PrintModAfter => {}
-            // We handle this below
+            // We handle this in the PassWrapper.cpp
             config::AutoDiff::PrintModFinal => {}
             // This is required and already checked
             config::AutoDiff::Enable => {}
+            // We handle this below
+            config::AutoDiff::NoPostopt => {}
         }
     }
     // This helps with handling enums for now.
@@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
     // We then run the llvm_optimize function a second time, to optimize the code which we generated
     // in the enzyme differentiation pass.
     let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
-    let stage =
-        if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
+    let stage = if thin {
+        write::AutodiffStage::PreAD
+    } else {
+        if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
+    };
 
     if enable_ad {
-        enable_autodiff_settings(&config.autodiff, module);
+        enable_autodiff_settings(&config.autodiff);
     }
 
     unsafe {
         write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
     }
 
-    if cfg!(llvm_enzyme) && enable_ad {
-        // This is the post-autodiff IR, mainly used for testing and educational purposes.
-        if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
-            unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
-        }
-
+    if cfg!(llvm_enzyme) && enable_ad && !thin {
         let opt_stage = llvm::OptStage::FatLTO;
         let stage = write::AutodiffStage::PostAD;
-        unsafe {
-            write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
+        if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
+            unsafe {
+                write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
+            }
         }
 
         // This is the final IR, so people should be able to inspect the optimized autodiff output,
diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs
index 18d221d232e..4ac77c8f7f1 100644
--- a/compiler/rustc_codegen_llvm/src/back/write.rs
+++ b/compiler/rustc_codegen_llvm/src/back/write.rs
@@ -572,6 +572,10 @@ pub(crate) unsafe fn llvm_optimize(
 
     let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
     let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
+    let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore);
+    let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter);
+    let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses);
+    let merge_functions;
     let unroll_loops;
     let vectorize_slp;
     let vectorize_loop;
@@ -579,13 +583,20 @@ pub(crate) unsafe fn llvm_optimize(
     // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
     // optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
     // We therefore have two calls to llvm_optimize, if autodiff is used.
+    //
+    // We also must disable merge_functions, since autodiff placeholder/dummy bodies tend to be
+    // identical. We run opts before AD, so there is a chance that LLVM will merge our dummies.
+    // In that case, we lack some dummy bodies and can't replace them with the real AD code anymore.
+    // We then would need to abort compilation. This was especially common in test cases.
     if consider_ad && autodiff_stage != AutodiffStage::PostAD {
+        merge_functions = false;
         unroll_loops = false;
         vectorize_slp = false;
         vectorize_loop = false;
     } else {
         unroll_loops =
             opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+        merge_functions = config.merge_functions;
         vectorize_slp = config.vectorize_slp;
         vectorize_loop = config.vectorize_loop;
     }
@@ -663,13 +674,16 @@ pub(crate) unsafe fn llvm_optimize(
             thin_lto_buffer,
             config.emit_thin_lto,
             config.emit_thin_lto_summary,
-            config.merge_functions,
+            merge_functions,
             unroll_loops,
             vectorize_slp,
             vectorize_loop,
             config.no_builtins,
             config.emit_lifetime_markers,
             run_enzyme,
+            print_before_enzyme,
+            print_after_enzyme,
+            print_passes,
             sanitizer_options.as_ref(),
             pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
             pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index e7c071d05aa..0147bd5a665 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -473,7 +473,7 @@ pub(crate) fn differentiate<'ll>(
         return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
     }
 
-    // Before dumping the module, we want all the TypeTrees to become part of the module.
+    // Here we replace the placeholder code with the actual autodiff code, which calls Enzyme.
     for item in diff_items.iter() {
         let name = item.source.clone();
         let fn_def: Option<&llvm::Value> = cx.get_function(&name);
diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
index 9ff04f72903..ffb490dcdc2 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs
@@ -2454,6 +2454,9 @@ unsafe extern "C" {
         DisableSimplifyLibCalls: bool,
         EmitLifetimeMarkers: bool,
         RunEnzyme: bool,
+        PrintBeforeEnzyme: bool,
+        PrintAfterEnzyme: bool,
+        PrintPasses: bool,
         SanitizerOptions: Option<&SanitizerOptions>,
         PGOGenPath: *const c_char,
         PGOUsePath: *const c_char,
diff --git a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
index e02c80c50b1..8bee051dd4c 100644
--- a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
@@ -14,6 +14,7 @@
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Verifier.h"
+#include "llvm/IRPrinter/IRPrintingPasses.h"
 #include "llvm/LTO/LTO.h"
 #include "llvm/MC/MCSubtargetInfo.h"
 #include "llvm/MC/TargetRegistry.h"
@@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
     bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
     bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
     bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
-    bool EmitLifetimeMarkers, bool RunEnzyme,
+    bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
+    bool PrintAfterEnzyme, bool PrintPasses,
     LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
     const char *PGOUsePath, bool InstrumentCoverage,
     const char *InstrProfileOutput, const char *PGOSampleUsePath,
@@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize(
   // now load "-enzyme" pass:
 #ifdef ENZYME
   if (RunEnzyme) {
-    registerEnzymeAndPassPipeline(PB, true);
+
+    if (PrintBeforeEnzyme) {
+      // Handle the Rust flag `-Zautodiff=PrintModBefore`.
+      std::string Banner = "Module before EnzymeNewPM";
+      MPM.addPass(PrintModulePass(outs(), Banner, true, false));
+    }
+
+    registerEnzymeAndPassPipeline(PB, false);
     if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
       std::string ErrMsg = toString(std::move(Err));
       LLVMRustSetLastError(ErrMsg.c_str());
       return LLVMRustResult::Failure;
     }
+
+    if (PrintAfterEnzyme) {
+      // Handle the Rust flag `-Zautodiff=PrintModAfter`.
+      std::string Banner = "Module after EnzymeNewPM";
+      MPM.addPass(PrintModulePass(outs(), Banner, true, false));
+    }
   }
 #endif
+  if (PrintPasses) {
+    // Print all passes from the PM:
+    std::string Pipeline;
+    raw_string_ostream SOS(Pipeline);
+    MPM.printPipeline(SOS, [&PIC](StringRef ClassName) {
+      auto PassName = PIC.getPassNameForClassName(ClassName);
+      return PassName.empty() ? ClassName : PassName;
+    });
+    outs() << Pipeline;
+    outs() << "\n";
+  }
 
   // Upgrade all calls to old intrinsics first.
   for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)
diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs
index bc92b95ce71..02c164a706c 100644
--- a/compiler/rustc_session/src/config.rs
+++ b/compiler/rustc_session/src/config.rs
@@ -246,6 +246,10 @@ pub enum AutoDiff {
     /// Print the module after running autodiff and optimizations.
     PrintModFinal,
 
+    /// Print all passes scheduled by LLVM
+    PrintPasses,
+    /// Disable extra opt run after running autodiff
+    NoPostopt,
     /// Enzyme's loose type debug helper (can cause incorrect gradients!!)
     /// Usable in cases where Enzyme errors with `can not deduce type of X`.
     LooseTypes,
diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs
index 36eee5f3086..5f4695fb184 100644
--- a/compiler/rustc_session/src/options.rs
+++ b/compiler/rustc_session/src/options.rs
@@ -711,7 +711,7 @@ mod desc {
     pub(crate) const parse_list: &str = "a space-separated list of strings";
     pub(crate) const parse_list_with_polarity: &str =
         "a comma-separated list of strings, with elements beginning with + or -";
-    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
+    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
     pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
     pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
     pub(crate) const parse_number: &str = "a number";
@@ -1360,6 +1360,8 @@ pub mod parse {
                 "PrintModBefore" => AutoDiff::PrintModBefore,
                 "PrintModAfter" => AutoDiff::PrintModAfter,
                 "PrintModFinal" => AutoDiff::PrintModFinal,
+                "NoPostopt" => AutoDiff::NoPostopt,
+                "PrintPasses" => AutoDiff::PrintPasses,
                 "LooseTypes" => AutoDiff::LooseTypes,
                 "Inline" => AutoDiff::Inline,
                 _ => {
@@ -2098,6 +2100,8 @@ options! {
         `=PrintModBefore`
         `=PrintModAfter`
         `=PrintModFinal`
+        `=PrintPasses`,
+        `=NoPostopt`
         `=LooseTypes`
         `=Inline`
         Multiple options can be combined with commas."),
diff --git a/tests/codegen/autodiff/identical_fnc.rs b/tests/codegen/autodiff/identical_fnc.rs
new file mode 100644
index 00000000000..1c3277f52b4
--- /dev/null
+++ b/tests/codegen/autodiff/identical_fnc.rs
@@ -0,0 +1,45 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+//
+// Each autodiff invocation creates a new placeholder function, which we will replace on llvm-ir
+// level. If a user tries to differentiate two identical functions within the same compilation unit,
+// then LLVM might merge them in release mode before AD. In that case we can't rewrite one of the
+// merged placeholder function anymore, and compilation would fail. We prevent this by disabling
+// LLVM's merge_function pass before AD. Here we implicetely test that our solution keeps working.
+// We also explicetly test that we keep running merge_function after AD, by checking for two
+// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
+#![feature(autodiff)]
+
+use std::autodiff::autodiff;
+
+#[autodiff(d_square, Reverse, Duplicated, Active)]
+fn square(x: &f64) -> f64 {
+    x * x
+}
+
+#[autodiff(d_square2, Reverse, Duplicated, Active)]
+fn square2(x: &f64) -> f64 {
+    x * x
+}
+
+// CHECK:; identical_fnc::main
+// CHECK-NEXT:; Function Attrs:
+// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
+// CHECK-NEXT:start:
+// CHECK-NOT:br
+// CHECK-NOT:ret
+// CHECK:; call identical_fnc::d_square
+// CHECK-NEXT:  call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
+// CHECK-NEXT:; call identical_fnc::d_square
+// CHECK-NEXT:  call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
+
+fn main() {
+    let x = std::hint::black_box(3.0);
+    let mut dx1 = std::hint::black_box(1.0);
+    let mut dx2 = std::hint::black_box(1.0);
+    let _ = d_square(&x, &mut dx1, 1.0);
+    let _ = d_square2(&x, &mut dx2, 1.0);
+    assert_eq!(dx1, 6.0);
+    assert_eq!(dx2, 6.0);
+}