diff options
| author | Manuel Drehwald <git@manuel.drehwald.info> | 2025-04-12 00:50:41 -0400 |
|---|---|---|
| committer | Manuel Drehwald <git@manuel.drehwald.info> | 2025-04-12 01:36:44 -0400 |
| commit | 75f86e6e2e07c40825e5c7e2f63537efff74a207 (patch) | |
| tree | 858fb7e8db864b68ab378b2440c23a61253efe9e | |
| parent | e643f59f6da3a84f43e75dea99afaa5b041ea6bf (diff) | |
| download | rust-75f86e6e2e07c40825e5c7e2f63537efff74a207.tar.gz rust-75f86e6e2e07c40825e5c7e2f63537efff74a207.zip | |
fix LooseTypes flag and PrintMod behaviour, add debug helper
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/lto.rs | 40 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/write.rs | 6 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 3 | ||||
| -rw-r--r-- | compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp | 30 | ||||
| -rw-r--r-- | compiler/rustc_session/src/config.rs | 4 | ||||
| -rw-r--r-- | compiler/rustc_session/src/options.rs | 6 |
6 files changed, 68 insertions, 21 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index a8b49e9552c..925898d8173 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -584,12 +584,10 @@ fn thin_lto( } } -fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) { +fn enable_autodiff_settings(ad: &[config::AutoDiff]) { for &val in ad { + // We intentionally don't use a wildcard, to not forget handling anything new. match val { - config::AutoDiff::PrintModBefore => { - unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; - } config::AutoDiff::PrintPerf => { llvm::set_print_perf(true); } @@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen< llvm::set_inline(true); } config::AutoDiff::LooseTypes => { - llvm::set_loose_types(false); + llvm::set_loose_types(true); } config::AutoDiff::PrintSteps => { llvm::set_print(true); } - // We handle this below + // We handle this in the PassWrapper.cpp + config::AutoDiff::PrintPasses => {} + // We handle this in the PassWrapper.cpp + config::AutoDiff::PrintModBefore => {} + // We handle this in the PassWrapper.cpp config::AutoDiff::PrintModAfter => {} - // We handle this below + // We handle this in the PassWrapper.cpp config::AutoDiff::PrintModFinal => {} // This is required and already checked config::AutoDiff::Enable => {} + // We handle this below + config::AutoDiff::NoPostopt => {} } } // This helps with handling enums for now. @@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager( // We then run the llvm_optimize function a second time, to optimize the code which we generated // in the enzyme differentiation pass. let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable); - let stage = - if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }; + let stage = if thin { + write::AutodiffStage::PreAD + } else { + if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD } + }; if enable_ad { - enable_autodiff_settings(&config.autodiff, module); + enable_autodiff_settings(&config.autodiff); } unsafe { write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?; } - if cfg!(llvm_enzyme) && enable_ad { - // This is the post-autodiff IR, mainly used for testing and educational purposes. - if config.autodiff.contains(&config::AutoDiff::PrintModAfter) { - unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; - } - + if cfg!(llvm_enzyme) && enable_ad && !thin { let opt_stage = llvm::OptStage::FatLTO; let stage = write::AutodiffStage::PostAD; - unsafe { - write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?; + if !config.autodiff.contains(&config::AutoDiff::NoPostopt) { + unsafe { + write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?; + } } // This is the final IR, so people should be able to inspect the optimized autodiff output, diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 76d431a4975..f60bc052a12 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -565,6 +565,9 @@ pub(crate) unsafe fn llvm_optimize( let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable); let run_enzyme = autodiff_stage == AutodiffStage::DuringAD; + let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore); + let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter); + let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses); let unroll_loops; let vectorize_slp; let vectorize_loop; @@ -663,6 +666,9 @@ pub(crate) unsafe fn llvm_optimize( config.no_builtins, config.emit_lifetime_markers, run_enzyme, + print_before_enzyme, + print_after_enzyme, + print_passes, sanitizer_options.as_ref(), pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()), pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()), diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 9ff04f72903..ffb490dcdc2 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2454,6 +2454,9 @@ unsafe extern "C" { DisableSimplifyLibCalls: bool, EmitLifetimeMarkers: bool, RunEnzyme: bool, + PrintBeforeEnzyme: bool, + PrintAfterEnzyme: bool, + PrintPasses: bool, SanitizerOptions: Option<&SanitizerOptions>, PGOGenPath: *const c_char, PGOUsePath: *const c_char, diff --git a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp index e02c80c50b1..8bee051dd4c 100644 --- a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp @@ -14,6 +14,7 @@ #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Verifier.h" +#include "llvm/IRPrinter/IRPrintingPasses.h" #include "llvm/LTO/LTO.h" #include "llvm/MC/MCSubtargetInfo.h" #include "llvm/MC/TargetRegistry.h" @@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize( bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO, bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops, bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls, - bool EmitLifetimeMarkers, bool RunEnzyme, + bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme, + bool PrintAfterEnzyme, bool PrintPasses, LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage, const char *InstrProfileOutput, const char *PGOSampleUsePath, @@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize( // now load "-enzyme" pass: #ifdef ENZYME if (RunEnzyme) { - registerEnzymeAndPassPipeline(PB, true); + + if (PrintBeforeEnzyme) { + // Handle the Rust flag `-Zautodiff=PrintModBefore`. + std::string Banner = "Module before EnzymeNewPM"; + MPM.addPass(PrintModulePass(outs(), Banner, true, false)); + } + + registerEnzymeAndPassPipeline(PB, false); if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) { std::string ErrMsg = toString(std::move(Err)); LLVMRustSetLastError(ErrMsg.c_str()); return LLVMRustResult::Failure; } + + if (PrintAfterEnzyme) { + // Handle the Rust flag `-Zautodiff=PrintModAfter`. + std::string Banner = "Module after EnzymeNewPM"; + MPM.addPass(PrintModulePass(outs(), Banner, true, false)); + } } #endif + if (PrintPasses) { + // Print all passes from the PM: + std::string Pipeline; + raw_string_ostream SOS(Pipeline); + MPM.printPipeline(SOS, [&PIC](StringRef ClassName) { + auto PassName = PIC.getPassNameForClassName(ClassName); + return PassName.empty() ? ClassName : PassName; + }); + outs() << Pipeline; + outs() << "\n"; + } // Upgrade all calls to old intrinsics first. for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;) diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 56b3fe2ab4c..cc6c814e76e 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -244,6 +244,10 @@ pub enum AutoDiff { /// Print the module after running autodiff and optimizations. PrintModFinal, + /// Print all passes scheduled by LLVM + PrintPasses, + /// Disable extra opt run after running autodiff + NoPostopt, /// Enzyme's loose type debug helper (can cause incorrect gradients!!) /// Usable in cases where Enzyme errors with `can not deduce type of X`. LooseTypes, diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index c70f1500d39..2531b0c9d42 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -711,7 +711,7 @@ mod desc { pub(crate) const parse_list: &str = "a space-separated list of strings"; pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; - pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`"; + pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; pub(crate) const parse_number: &str = "a number"; @@ -1360,6 +1360,8 @@ pub mod parse { "PrintModBefore" => AutoDiff::PrintModBefore, "PrintModAfter" => AutoDiff::PrintModAfter, "PrintModFinal" => AutoDiff::PrintModFinal, + "NoPostopt" => AutoDiff::NoPostopt, + "PrintPasses" => AutoDiff::PrintPasses, "LooseTypes" => AutoDiff::LooseTypes, "Inline" => AutoDiff::Inline, _ => { @@ -2095,6 +2097,8 @@ options! { `=PrintModBefore` `=PrintModAfter` `=PrintModFinal` + `=PrintPasses`, + `=NoPostopt` `=LooseTypes` `=Inline` Multiple options can be combined with commas."), |
