about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src
diff options
context:
space:
mode:
authorManuel Drehwald <git@manuel.drehwald.info>2025-02-21 21:51:20 -0500
committerManuel Drehwald <git@manuel.drehwald.info>2025-02-21 21:51:20 -0500
commite2d250c3f63d14e068e92ab3048817af6e1770c2 (patch)
tree1302ad5ff1ff1bfe45c9e30e65dc74b6d97df110 /compiler/rustc_codegen_llvm/src
parent161a4bf6ff3d0b10cd7e5b0984e908e4927d0890 (diff)
downloadrust-e2d250c3f63d14e068e92ab3048817af6e1770c2.tar.gz
rust-e2d250c3f63d14e068e92ab3048817af6e1770c2.zip
update autodiff flags
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
-rw-r--r--compiler/rustc_codegen_llvm/src/back/lto.rs85
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs13
-rw-r--r--compiler/rustc_codegen_llvm/src/errors.rs5
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs94
4 files changed, 169 insertions, 28 deletions
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index 3e25b94961b..99906ea7bce 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -586,6 +586,42 @@ fn thin_lto(
     }
 }
 
+fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
+    for &val in ad {
+        match val {
+            config::AutoDiff::PrintModBefore => {
+                unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
+            }
+            config::AutoDiff::PrintPerf => {
+                llvm::set_print_perf(true);
+            }
+            config::AutoDiff::PrintAA => {
+                llvm::set_print_activity(true);
+            }
+            config::AutoDiff::PrintTA => {
+                llvm::set_print_type(true);
+            }
+            config::AutoDiff::Inline => {
+                llvm::set_inline(true);
+            }
+            config::AutoDiff::LooseTypes => {
+                llvm::set_loose_types(false);
+            }
+            config::AutoDiff::PrintSteps => {
+                llvm::set_print(true);
+            }
+            // We handle this below
+            config::AutoDiff::PrintModAfter => {}
+            // This is required and already checked
+            config::AutoDiff::Enable => {}
+        }
+    }
+    // This helps with handling enums for now.
+    llvm::set_strict_aliasing(false);
+    // FIXME(ZuseZ4): Test this, since it was added a long time ago.
+    llvm::set_rust_rules(true);
+}
+
 pub(crate) fn run_pass_manager(
     cgcx: &CodegenContext<LlvmCodegenBackend>,
     dcx: DiagCtxtHandle<'_>,
@@ -604,34 +640,37 @@ pub(crate) fn run_pass_manager(
     let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
     let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
 
-    // If this rustc version was build with enzyme/autodiff enabled, and if users applied the
-    // `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
-    debug!("running llvm pm opt pipeline");
+    // The PostAD behavior is the same that we would have if no autodiff was used.
+    // It will run the default optimization pipeline. If AD is enabled we select
+    // the DuringAD stage, which will disable vectorization and loop unrolling, and
+    // schedule two autodiff optimization + differentiation passes.
+    // We then run the llvm_optimize function a second time, to optimize the code which we generated
+    // in the enzyme differentiation pass.
+    let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
+    let stage =
+        if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
+
+    if enable_ad {
+        enable_autodiff_settings(&config.autodiff, module);
+    }
+
     unsafe {
-        write::llvm_optimize(
-            cgcx,
-            dcx,
-            module,
-            config,
-            opt_level,
-            opt_stage,
-            write::AutodiffStage::DuringAD,
-        )?;
+        write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
     }
-    // FIXME(ZuseZ4): Make this more granular
-    if cfg!(llvm_enzyme) && !thin {
+
+    if cfg!(llvm_enzyme) && enable_ad {
+        let opt_stage = llvm::OptStage::FatLTO;
+        let stage = write::AutodiffStage::PostAD;
         unsafe {
-            write::llvm_optimize(
-                cgcx,
-                dcx,
-                module,
-                config,
-                opt_level,
-                llvm::OptStage::FatLTO,
-                write::AutodiffStage::PostAD,
-            )?;
+            write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
+        }
+
+        // This is the final IR, so people should be able to inspect the optimized autodiff output.
+        if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
+            unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
         }
     }
+
     debug!("lto done");
     Ok(())
 }
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index b2c1088e3fc..2c7899975e3 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -10,7 +10,7 @@ use crate::back::write::llvm_err;
 use crate::builder::SBuilder;
 use crate::context::SimpleCx;
 use crate::declare::declare_simple_fn;
-use crate::errors::LlvmError;
+use crate::errors::{AutoDiffWithoutEnable, LlvmError};
 use crate::llvm::AttributePlace::Function;
 use crate::llvm::{Metadata, True};
 use crate::value::Value;
@@ -46,9 +46,6 @@ fn generate_enzyme_call<'ll>(
     let output = attrs.ret_activity;
 
     // We have to pick the name depending on whether we want forward or reverse mode autodiff.
-    // FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since
-    // it will handle higher-order derivatives correctly automatically (in theory). Currently
-    // higher-order derivatives fail, so we should debug that before adjusting this code.
     let mut ad_name: String = match attrs.mode {
         DiffMode::Forward => "__enzyme_fwddiff",
         DiffMode::Reverse => "__enzyme_autodiff",
@@ -291,6 +288,14 @@ pub(crate) fn differentiate<'ll>(
     let diag_handler = cgcx.create_dcx();
     let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
 
+    // First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag?
+    if !diff_items.is_empty()
+        && !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
+    {
+        let dcx = cgcx.create_dcx();
+        return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
+    }
+
     // Before dumping the module, we want all the TypeTrees to become part of the module.
     for item in diff_items.iter() {
         let name = item.source.clone();
diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs
index 97f49256165..4c5a78ca74f 100644
--- a/compiler/rustc_codegen_llvm/src/errors.rs
+++ b/compiler/rustc_codegen_llvm/src/errors.rs
@@ -92,10 +92,13 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {
 
 #[derive(Diagnostic)]
 #[diag(codegen_llvm_autodiff_without_lto)]
-#[note]
 pub(crate) struct AutoDiffWithoutLTO;
 
 #[derive(Diagnostic)]
+#[diag(codegen_llvm_autodiff_without_enable)]
+pub(crate) struct AutoDiffWithoutEnable;
+
+#[derive(Diagnostic)]
 #[diag(codegen_llvm_lto_disallowed)]
 pub(crate) struct LtoDisallowed;
 
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index 39bac13a968..daa6696e963 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -35,3 +35,97 @@ pub enum LLVMRustVerifierFailureAction {
     LLVMPrintMessageAction = 1,
     LLVMReturnStatusAction = 2,
 }
+
+#[cfg(llvm_enzyme)]
+pub use self::Enzyme_AD::*;
+
+#[cfg(llvm_enzyme)]
+pub mod Enzyme_AD {
+    use libc::c_void;
+    extern "C" {
+        pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
+    }
+    extern "C" {
+        static mut EnzymePrintPerf: c_void;
+        static mut EnzymePrintActivity: c_void;
+        static mut EnzymePrintType: c_void;
+        static mut EnzymePrint: c_void;
+        static mut EnzymeStrictAliasing: c_void;
+        static mut looseTypeAnalysis: c_void;
+        static mut EnzymeInline: c_void;
+        static mut RustTypeRules: c_void;
+    }
+    pub fn set_print_perf(print: bool) {
+        unsafe {
+            EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8);
+        }
+    }
+    pub fn set_print_activity(print: bool) {
+        unsafe {
+            EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8);
+        }
+    }
+    pub fn set_print_type(print: bool) {
+        unsafe {
+            EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
+        }
+    }
+    pub fn set_print(print: bool) {
+        unsafe {
+            EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
+        }
+    }
+    pub fn set_strict_aliasing(strict: bool) {
+        unsafe {
+            EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8);
+        }
+    }
+    pub fn set_loose_types(loose: bool) {
+        unsafe {
+            EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8);
+        }
+    }
+    pub fn set_inline(val: bool) {
+        unsafe {
+            EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8);
+        }
+    }
+    pub fn set_rust_rules(val: bool) {
+        unsafe {
+            EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8);
+        }
+    }
+}
+
+#[cfg(not(llvm_enzyme))]
+pub use self::Fallback_AD::*;
+
+#[cfg(not(llvm_enzyme))]
+pub mod Fallback_AD {
+    #![allow(unused_variables)]
+
+    pub fn set_inline(val: bool) {
+        unimplemented!()
+    }
+    pub fn set_print_perf(print: bool) {
+        unimplemented!()
+    }
+    pub fn set_print_activity(print: bool) {
+        unimplemented!()
+    }
+    pub fn set_print_type(print: bool) {
+        unimplemented!()
+    }
+    pub fn set_print(print: bool) {
+        unimplemented!()
+    }
+    pub fn set_strict_aliasing(strict: bool) {
+        unimplemented!()
+    }
+    pub fn set_loose_types(loose: bool) {
+        unimplemented!()
+    }
+    pub fn set_rust_rules(val: bool) {
+        unimplemented!()
+    }
+}