diff options
| author | Manuel Drehwald <git@manuel.drehwald.info> | 2025-02-21 21:51:20 -0500 | 
|---|---|---|
| committer | Manuel Drehwald <git@manuel.drehwald.info> | 2025-02-21 21:51:20 -0500 | 
| commit | e2d250c3f63d14e068e92ab3048817af6e1770c2 (patch) | |
| tree | 1302ad5ff1ff1bfe45c9e30e65dc74b6d97df110 /compiler/rustc_codegen_llvm/src | |
| parent | 161a4bf6ff3d0b10cd7e5b0984e908e4927d0890 (diff) | |
| download | rust-e2d250c3f63d14e068e92ab3048817af6e1770c2.tar.gz rust-e2d250c3f63d14e068e92ab3048817af6e1770c2.zip | |
update autodiff flags
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/lto.rs | 85 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder/autodiff.rs | 13 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/errors.rs | 5 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 94 | 
4 files changed, 169 insertions, 28 deletions
| diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 3e25b94961b..99906ea7bce 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -586,6 +586,42 @@ fn thin_lto( } } +fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) { + for &val in ad { + match val { + config::AutoDiff::PrintModBefore => { + unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; + } + config::AutoDiff::PrintPerf => { + llvm::set_print_perf(true); + } + config::AutoDiff::PrintAA => { + llvm::set_print_activity(true); + } + config::AutoDiff::PrintTA => { + llvm::set_print_type(true); + } + config::AutoDiff::Inline => { + llvm::set_inline(true); + } + config::AutoDiff::LooseTypes => { + llvm::set_loose_types(false); + } + config::AutoDiff::PrintSteps => { + llvm::set_print(true); + } + // We handle this below + config::AutoDiff::PrintModAfter => {} + // This is required and already checked + config::AutoDiff::Enable => {} + } + } + // This helps with handling enums for now. + llvm::set_strict_aliasing(false); + // FIXME(ZuseZ4): Test this, since it was added a long time ago. + llvm::set_rust_rules(true); +} + pub(crate) fn run_pass_manager( cgcx: &CodegenContext<LlvmCodegenBackend>, dcx: DiagCtxtHandle<'_>, @@ -604,34 +640,37 @@ pub(crate) fn run_pass_manager( let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO }; let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); - // If this rustc version was build with enzyme/autodiff enabled, and if users applied the - // `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time. - debug!("running llvm pm opt pipeline"); + // The PostAD behavior is the same that we would have if no autodiff was used. + // It will run the default optimization pipeline. If AD is enabled we select + // the DuringAD stage, which will disable vectorization and loop unrolling, and + // schedule two autodiff optimization + differentiation passes. + // We then run the llvm_optimize function a second time, to optimize the code which we generated + // in the enzyme differentiation pass. + let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable); + let stage = + if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }; + + if enable_ad { + enable_autodiff_settings(&config.autodiff, module); + } + unsafe { - write::llvm_optimize( - cgcx, - dcx, - module, - config, - opt_level, - opt_stage, - write::AutodiffStage::DuringAD, - )?; + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?; } - // FIXME(ZuseZ4): Make this more granular - if cfg!(llvm_enzyme) && !thin { + + if cfg!(llvm_enzyme) && enable_ad { + let opt_stage = llvm::OptStage::FatLTO; + let stage = write::AutodiffStage::PostAD; unsafe { - write::llvm_optimize( - cgcx, - dcx, - module, - config, - opt_level, - llvm::OptStage::FatLTO, - write::AutodiffStage::PostAD, - )?; + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?; + } + + // This is the final IR, so people should be able to inspect the optimized autodiff output. + if config.autodiff.contains(&config::AutoDiff::PrintModAfter) { + unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) }; } } + debug!("lto done"); Ok(()) } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index b2c1088e3fc..2c7899975e3 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -10,7 +10,7 @@ use crate::back::write::llvm_err; use crate::builder::SBuilder; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; -use crate::errors::LlvmError; +use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; use crate::llvm::{Metadata, True}; use crate::value::Value; @@ -46,9 +46,6 @@ fn generate_enzyme_call<'ll>( let output = attrs.ret_activity; // We have to pick the name depending on whether we want forward or reverse mode autodiff. - // FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since - // it will handle higher-order derivatives correctly automatically (in theory). Currently - // higher-order derivatives fail, so we should debug that before adjusting this code. let mut ad_name: String = match attrs.mode { DiffMode::Forward => "__enzyme_fwddiff", DiffMode::Reverse => "__enzyme_autodiff", @@ -291,6 +288,14 @@ pub(crate) fn differentiate<'ll>( let diag_handler = cgcx.create_dcx(); let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx }; + // First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag? + if !diff_items.is_empty() + && !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) + { + let dcx = cgcx.create_dcx(); + return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable)); + } + // Before dumping the module, we want all the TypeTrees to become part of the module. for item in diff_items.iter() { let name = item.source.clone(); diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 97f49256165..4c5a78ca74f 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -92,10 +92,13 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> { #[derive(Diagnostic)] #[diag(codegen_llvm_autodiff_without_lto)] -#[note] pub(crate) struct AutoDiffWithoutLTO; #[derive(Diagnostic)] +#[diag(codegen_llvm_autodiff_without_enable)] +pub(crate) struct AutoDiffWithoutEnable; + +#[derive(Diagnostic)] #[diag(codegen_llvm_lto_disallowed)] pub(crate) struct LtoDisallowed; diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 39bac13a968..daa6696e963 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -35,3 +35,97 @@ pub enum LLVMRustVerifierFailureAction { LLVMPrintMessageAction = 1, LLVMReturnStatusAction = 2, } + +#[cfg(llvm_enzyme)] +pub use self::Enzyme_AD::*; + +#[cfg(llvm_enzyme)] +pub mod Enzyme_AD { + use libc::c_void; + extern "C" { + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); + } + extern "C" { + static mut EnzymePrintPerf: c_void; + static mut EnzymePrintActivity: c_void; + static mut EnzymePrintType: c_void; + static mut EnzymePrint: c_void; + static mut EnzymeStrictAliasing: c_void; + static mut looseTypeAnalysis: c_void; + static mut EnzymeInline: c_void; + static mut RustTypeRules: c_void; + } + pub fn set_print_perf(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8); + } + } + pub fn set_print_activity(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8); + } + } + pub fn set_print_type(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8); + } + } + pub fn set_print(print: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8); + } + } + pub fn set_strict_aliasing(strict: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8); + } + } + pub fn set_loose_types(loose: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8); + } + } + pub fn set_inline(val: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8); + } + } + pub fn set_rust_rules(val: bool) { + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8); + } + } +} + +#[cfg(not(llvm_enzyme))] +pub use self::Fallback_AD::*; + +#[cfg(not(llvm_enzyme))] +pub mod Fallback_AD { + #![allow(unused_variables)] + + pub fn set_inline(val: bool) { + unimplemented!() + } + pub fn set_print_perf(print: bool) { + unimplemented!() + } + pub fn set_print_activity(print: bool) { + unimplemented!() + } + pub fn set_print_type(print: bool) { + unimplemented!() + } + pub fn set_print(print: bool) { + unimplemented!() + } + pub fn set_strict_aliasing(strict: bool) { + unimplemented!() + } + pub fn set_loose_types(loose: bool) { + unimplemented!() + } + pub fn set_rust_rules(val: bool) { + unimplemented!() + } +} | 
