about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/builder
diff options
context:
space:
mode:
authorMatthias Krüger <matthias.krueger@famsik.de>2025-02-10 16:38:23 +0100
committerGitHub <noreply@github.com>2025-02-10 16:38:23 +0100
commit78f5bddd57e8eabd0a71efd5fe59005a7b2a87c1 (patch)
tree0682ab2b0a7d130f9d52774b15efdae826785685 /compiler/rustc_codegen_llvm/src/builder
parent8c04e395952022a451138dc4dbead6dd6ae65203 (diff)
parent061abbc36928cce784c54463c266f4d43d14d419 (diff)
downloadrust-78f5bddd57e8eabd0a71efd5fe59005a7b2a87c1.tar.gz
rust-78f5bddd57e8eabd0a71efd5fe59005a7b2a87c1.zip
Rollup merge of #136419 - EnzymeAD:autodiff-tests, r=onur-ozkan,jieyouxu
adding autodiff tests

I'd like to get started with upstreaming some tests, even though I'm still waiting for an answer on how to best integrate the enzyme pass. Can we therefore temporarily support the -Z llvm-plugins here without too much effort? And in that case, how would that work? I saw you can do remapping, e.g. `rust-src-base`, but I don't think that will give me the path to libEnzyme.so. Do you have another suggestion?

Other than that this test simply checks that the derivative of `x*x` is `2.0 * x`, which in this case is computed as
`%0 = fadd fast double %x.0.val, %x.0.val`
(I'll add a few more tests and move it to an autodiff folder if we can use the -Z flag)

r? ``@jieyouxu``

Locally at least `-Zllvm-plugins=${PWD}/build/x86_64-unknown-linux-gnu/enzyme/build/Enzyme/libEnzyme-19.so` seems to work if I copy the command I get from x.py test and run it manually. However, running x.py test itself fails.

Tracking:

- https://github.com/rust-lang/rust/issues/124509

Zulip discussion: https://rust-lang.zulipchat.com/#narrow/channel/326414-t-infra.2Fbootstrap/topic/Enzyme.20build.20changes
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder')
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs46
1 files changed, 14 insertions, 32 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index dd5e726160d..b2c1088e3fc 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -4,10 +4,9 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
 use rustc_codegen_ssa::ModuleCodegen;
 use rustc_codegen_ssa::back::write::ModuleConfig;
 use rustc_errors::FatalError;
-use rustc_session::config::Lto;
 use tracing::{debug, trace};
 
-use crate::back::write::{llvm_err, llvm_optimize};
+use crate::back::write::llvm_err;
 use crate::builder::SBuilder;
 use crate::context::SimpleCx;
 use crate::declare::declare_simple_fn;
@@ -53,8 +52,6 @@ fn generate_enzyme_call<'ll>(
     let mut ad_name: String = match attrs.mode {
         DiffMode::Forward => "__enzyme_fwddiff",
         DiffMode::Reverse => "__enzyme_autodiff",
-        DiffMode::ForwardFirst => "__enzyme_fwddiff",
-        DiffMode::ReverseFirst => "__enzyme_autodiff",
         _ => panic!("logic bug in autodiff, unrecognized mode"),
     }
     .to_string();
@@ -153,7 +150,7 @@ fn generate_enzyme_call<'ll>(
             _ => {}
         }
 
-        trace!("matching autodiff arguments");
+        debug!("matching autodiff arguments");
         // We now handle the issue that Rust level arguments not always match the llvm-ir level
         // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
         // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
@@ -164,10 +161,10 @@ fn generate_enzyme_call<'ll>(
         let mut activity_pos = 0;
         let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
         while activity_pos < inputs.len() {
-            let activity = inputs[activity_pos as usize];
+            let diff_activity = inputs[activity_pos as usize];
             // Duplicated arguments received a shadow argument, into which enzyme will write the
             // gradient.
-            let (activity, duplicated): (&Metadata, bool) = match activity {
+            let (activity, duplicated): (&Metadata, bool) = match diff_activity {
                 DiffActivity::None => panic!("not a valid input activity"),
                 DiffActivity::Const => (enzyme_const, false),
                 DiffActivity::Active => (enzyme_out, false),
@@ -222,7 +219,15 @@ fn generate_enzyme_call<'ll>(
                     // A duplicated pointer will have the following two outer_fn arguments:
                     // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
                     // (..., metadata! enzyme_dup, ptr, ptr, ...).
-                    assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
+                    if matches!(
+                        diff_activity,
+                        DiffActivity::Duplicated | DiffActivity::DuplicatedOnly
+                    ) {
+                        assert!(
+                            llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
+                        );
+                    }
+                    // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
                     args.push(next_outer_arg);
                     outer_pos += 2;
                     activity_pos += 1;
@@ -277,7 +282,7 @@ pub(crate) fn differentiate<'ll>(
     module: &'ll ModuleCodegen<ModuleLlvm>,
     cgcx: &CodegenContext<LlvmCodegenBackend>,
     diff_items: Vec<AutoDiffItem>,
-    config: &ModuleConfig,
+    _config: &ModuleConfig,
 ) -> Result<(), FatalError> {
     for item in &diff_items {
         trace!("{}", item);
@@ -318,29 +323,6 @@ pub(crate) fn differentiate<'ll>(
 
     // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
 
-    if let Some(opt_level) = config.opt_level {
-        let opt_stage = match cgcx.lto {
-            Lto::Fat => llvm::OptStage::PreLinkFatLTO,
-            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
-            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
-            _ => llvm::OptStage::PreLinkNoLTO,
-        };
-        // This is our second opt call, so now we run all opts,
-        // to make sure we get the best performance.
-        let skip_size_increasing_opts = false;
-        trace!("running Module Optimization after differentiation");
-        unsafe {
-            llvm_optimize(
-                cgcx,
-                diag_handler.handle(),
-                module,
-                config,
-                opt_level,
-                opt_stage,
-                skip_size_increasing_opts,
-            )?
-        };
-    }
     trace!("done with differentiate()");
 
     Ok(())