diff options
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/lto.rs | 12 | ||||
| -rw-r--r-- | compiler/rustc_session/src/config.rs | 6 | ||||
| -rw-r--r-- | compiler/rustc_session/src/options.rs | 4 |
3 files changed, 17 insertions, 5 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 668795191a2..158275a559b 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -610,6 +610,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen< } // We handle this below config::AutoDiff::PrintModAfter => {} + // We handle this below + config::AutoDiff::PrintModFinal => {} // This is required and already checked config::AutoDiff::Enable => {} } @@ -657,14 +659,20 @@ pub(crate) fn run_pass_manager( } if cfg!(llvm_enzyme) && enable_ad { + // This is the post-autodiff IR, mainly used for testing and educational purposes. + if config.autodiff.contains(&config::AutoDiff::PrintModAfter) { + unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; + } + let opt_stage = llvm::OptStage::FatLTO; let stage = write::AutodiffStage::PostAD; unsafe { write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?; } - // This is the final IR, so people should be able to inspect the optimized autodiff output. - if config.autodiff.contains(&config::AutoDiff::PrintModAfter) { + // This is the final IR, so people should be able to inspect the optimized autodiff output, + // for manual inspection. + if config.autodiff.contains(&config::AutoDiff::PrintModFinal) { unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; } } diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 43b78423c72..39d6a24d081 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -235,10 +235,12 @@ pub enum AutoDiff { PrintPerf, /// Print intermediate IR generation steps PrintSteps, - /// Print the whole module, before running opts. + /// Print the module, before running autodiff. PrintModBefore, - /// Print the module after Enzyme differentiated everything. + /// Print the module after running autodiff. PrintModAfter, + /// Print the module after running autodiff and optimizations. + PrintModFinal, /// Enzyme's loose type debug helper (can cause incorrect gradients!!) /// Usable in cases where Enzyme errors with `can not deduce type of X`. diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index b3be4b611f0..45cce29799f 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -707,7 +707,7 @@ mod desc { pub(crate) const parse_list: &str = "a space-separated list of strings"; pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; - pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `LooseTypes`, `Inline`"; + pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; pub(crate) const parse_number: &str = "a number"; @@ -1355,6 +1355,7 @@ pub mod parse { "PrintSteps" => AutoDiff::PrintSteps, "PrintModBefore" => AutoDiff::PrintModBefore, "PrintModAfter" => AutoDiff::PrintModAfter, + "PrintModFinal" => AutoDiff::PrintModFinal, "LooseTypes" => AutoDiff::LooseTypes, "Inline" => AutoDiff::Inline, _ => { @@ -2088,6 +2089,7 @@ options! { `=PrintSteps` `=PrintModBefore` `=PrintModAfter` + `=PrintModFinal` `=LooseTypes` `=Inline` Multiple options can be combined with commas."), |
