diff options
| author | Manuel Drehwald <git@manuel.drehwald.info> | 2025-02-21 21:51:20 -0500 |
|---|---|---|
| committer | Manuel Drehwald <git@manuel.drehwald.info> | 2025-02-21 21:51:20 -0500 |
| commit | e2d250c3f63d14e068e92ab3048817af6e1770c2 (patch) | |
| tree | 1302ad5ff1ff1bfe45c9e30e65dc74b6d97df110 /compiler/rustc_codegen_llvm/src/back/lto.rs | |
| parent | 161a4bf6ff3d0b10cd7e5b0984e908e4927d0890 (diff) | |
| download | rust-e2d250c3f63d14e068e92ab3048817af6e1770c2.tar.gz rust-e2d250c3f63d14e068e92ab3048817af6e1770c2.zip | |
update autodiff flags
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/back/lto.rs')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/lto.rs | 85 |
1 files changed, 62 insertions, 23 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 3e25b94961b..99906ea7bce 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -586,6 +586,42 @@ fn thin_lto( } } +fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) { + for &val in ad { + match val { + config::AutoDiff::PrintModBefore => { + unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; + } + config::AutoDiff::PrintPerf => { + llvm::set_print_perf(true); + } + config::AutoDiff::PrintAA => { + llvm::set_print_activity(true); + } + config::AutoDiff::PrintTA => { + llvm::set_print_type(true); + } + config::AutoDiff::Inline => { + llvm::set_inline(true); + } + config::AutoDiff::LooseTypes => { + llvm::set_loose_types(false); + } + config::AutoDiff::PrintSteps => { + llvm::set_print(true); + } + // We handle this below + config::AutoDiff::PrintModAfter => {} + // This is required and already checked + config::AutoDiff::Enable => {} + } + } + // This helps with handling enums for now. + llvm::set_strict_aliasing(false); + // FIXME(ZuseZ4): Test this, since it was added a long time ago. + llvm::set_rust_rules(true); +} + pub(crate) fn run_pass_manager( cgcx: &CodegenContext<LlvmCodegenBackend>, dcx: DiagCtxtHandle<'_>, @@ -604,34 +640,37 @@ pub(crate) fn run_pass_manager( let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO }; let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); - // If this rustc version was build with enzyme/autodiff enabled, and if users applied the - // `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time. - debug!("running llvm pm opt pipeline"); + // The PostAD behavior is the same that we would have if no autodiff was used. + // It will run the default optimization pipeline. If AD is enabled we select + // the DuringAD stage, which will disable vectorization and loop unrolling, and + // schedule two autodiff optimization + differentiation passes. + // We then run the llvm_optimize function a second time, to optimize the code which we generated + // in the enzyme differentiation pass. + let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable); + let stage = + if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }; + + if enable_ad { + enable_autodiff_settings(&config.autodiff, module); + } + unsafe { - write::llvm_optimize( - cgcx, - dcx, - module, - config, - opt_level, - opt_stage, - write::AutodiffStage::DuringAD, - )?; + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?; } - // FIXME(ZuseZ4): Make this more granular - if cfg!(llvm_enzyme) && !thin { + + if cfg!(llvm_enzyme) && enable_ad { + let opt_stage = llvm::OptStage::FatLTO; + let stage = write::AutodiffStage::PostAD; unsafe { - write::llvm_optimize( - cgcx, - dcx, - module, - config, - opt_level, - llvm::OptStage::FatLTO, - write::AutodiffStage::PostAD, - )?; + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?; + } + + // This is the final IR, so people should be able to inspect the optimized autodiff output. + if config.autodiff.contains(&config::AutoDiff::PrintModAfter) { + unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; } } + debug!("lto done"); Ok(()) } |
