about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
diff options
context:
space:
mode:
authorBoxy <rust@boxyuwu.dev>2025-02-25 21:27:44 +0000
committerBoxy <rust@boxyuwu.dev>2025-02-25 21:27:44 +0000
commitd9683df7c2f6d4141b1321e27635d2ce3167eaa4 (patch)
treedce0d46d1b7d624ec9b9b09b2c1854f6245a5ff4 /compiler/rustc_codegen_llvm/src/builder/autodiff.rs
parent46392d1661540e256fd9573d8f06c2784a58c983 (diff)
parent4ecd70ddd1039a3954056c1071e40278048476fa (diff)
downloadrust-d9683df7c2f6d4141b1321e27635d2ce3167eaa4.tar.gz
rust-d9683df7c2f6d4141b1321e27635d2ce3167eaa4.zip
Merge from rustc
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder/autodiff.rs')
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs59
1 files changed, 23 insertions, 36 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index dd5e726160d..2c7899975e3 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -4,14 +4,13 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
 use rustc_codegen_ssa::ModuleCodegen;
 use rustc_codegen_ssa::back::write::ModuleConfig;
 use rustc_errors::FatalError;
-use rustc_session::config::Lto;
 use tracing::{debug, trace};
 
-use crate::back::write::{llvm_err, llvm_optimize};
+use crate::back::write::llvm_err;
 use crate::builder::SBuilder;
 use crate::context::SimpleCx;
 use crate::declare::declare_simple_fn;
-use crate::errors::LlvmError;
+use crate::errors::{AutoDiffWithoutEnable, LlvmError};
 use crate::llvm::AttributePlace::Function;
 use crate::llvm::{Metadata, True};
 use crate::value::Value;
@@ -47,14 +46,9 @@ fn generate_enzyme_call<'ll>(
     let output = attrs.ret_activity;
 
     // We have to pick the name depending on whether we want forward or reverse mode autodiff.
-    // FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since
-    // it will handle higher-order derivatives correctly automatically (in theory). Currently
-    // higher-order derivatives fail, so we should debug that before adjusting this code.
     let mut ad_name: String = match attrs.mode {
         DiffMode::Forward => "__enzyme_fwddiff",
         DiffMode::Reverse => "__enzyme_autodiff",
-        DiffMode::ForwardFirst => "__enzyme_fwddiff",
-        DiffMode::ReverseFirst => "__enzyme_autodiff",
         _ => panic!("logic bug in autodiff, unrecognized mode"),
     }
     .to_string();
@@ -153,7 +147,7 @@ fn generate_enzyme_call<'ll>(
             _ => {}
         }
 
-        trace!("matching autodiff arguments");
+        debug!("matching autodiff arguments");
         // We now handle the issue that Rust level arguments not always match the llvm-ir level
         // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
         // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
@@ -164,10 +158,10 @@ fn generate_enzyme_call<'ll>(
         let mut activity_pos = 0;
         let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
         while activity_pos < inputs.len() {
-            let activity = inputs[activity_pos as usize];
+            let diff_activity = inputs[activity_pos as usize];
             // Duplicated arguments received a shadow argument, into which enzyme will write the
             // gradient.
-            let (activity, duplicated): (&Metadata, bool) = match activity {
+            let (activity, duplicated): (&Metadata, bool) = match diff_activity {
                 DiffActivity::None => panic!("not a valid input activity"),
                 DiffActivity::Const => (enzyme_const, false),
                 DiffActivity::Active => (enzyme_out, false),
@@ -222,7 +216,15 @@ fn generate_enzyme_call<'ll>(
                     // A duplicated pointer will have the following two outer_fn arguments:
                     // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
                     // (..., metadata! enzyme_dup, ptr, ptr, ...).
-                    assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
+                    if matches!(
+                        diff_activity,
+                        DiffActivity::Duplicated | DiffActivity::DuplicatedOnly
+                    ) {
+                        assert!(
+                            llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
+                        );
+                    }
+                    // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
                     args.push(next_outer_arg);
                     outer_pos += 2;
                     activity_pos += 1;
@@ -277,7 +279,7 @@ pub(crate) fn differentiate<'ll>(
     module: &'ll ModuleCodegen<ModuleLlvm>,
     cgcx: &CodegenContext<LlvmCodegenBackend>,
     diff_items: Vec<AutoDiffItem>,
-    config: &ModuleConfig,
+    _config: &ModuleConfig,
 ) -> Result<(), FatalError> {
     for item in &diff_items {
         trace!("{}", item);
@@ -286,6 +288,14 @@ pub(crate) fn differentiate<'ll>(
     let diag_handler = cgcx.create_dcx();
     let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
 
+    // First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag?
+    if !diff_items.is_empty()
+        && !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
+    {
+        let dcx = cgcx.create_dcx();
+        return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
+    }
+
     // Before dumping the module, we want all the TypeTrees to become part of the module.
     for item in diff_items.iter() {
         let name = item.source.clone();
@@ -318,29 +328,6 @@ pub(crate) fn differentiate<'ll>(
 
     // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
 
-    if let Some(opt_level) = config.opt_level {
-        let opt_stage = match cgcx.lto {
-            Lto::Fat => llvm::OptStage::PreLinkFatLTO,
-            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
-            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
-            _ => llvm::OptStage::PreLinkNoLTO,
-        };
-        // This is our second opt call, so now we run all opts,
-        // to make sure we get the best performance.
-        let skip_size_increasing_opts = false;
-        trace!("running Module Optimization after differentiation");
-        unsafe {
-            llvm_optimize(
-                cgcx,
-                diag_handler.handle(),
-                module,
-                config,
-                opt_level,
-                opt_stage,
-                skip_size_increasing_opts,
-            )?
-        };
-    }
     trace!("done with differentiate()");
 
     Ok(())