about summary refs log tree commit diff
diff options
context:
space:
mode:
authorManuel Drehwald <git@manuel.drehwald.info>2025-04-05 03:10:19 -0400
committerManuel Drehwald <git@manuel.drehwald.info>2025-04-16 17:13:31 -0400
commita68ae0cbc1d923b69ef690f32b651ecf583ff27f (patch)
tree528a3092d9c80690934bcdf9653fdfd1df9bbd30
parentd4f880f8ce832cd7560bb2f1ebc34f967055ffd7 (diff)
downloadrust-a68ae0cbc1d923b69ef690f32b651ecf583ff27f.tar.gz
rust-a68ae0cbc1d923b69ef690f32b651ecf583ff27f.zip
working dupv and dupvonly for fwd mode
-rw-r--r--compiler/rustc_ast/src/expand/autodiff_attrs.rs40
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs23
-rw-r--r--compiler/rustc_codegen_llvm/src/builder.rs2
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs38
-rw-r--r--compiler/rustc_monomorphize/src/partitioning/autodiff.rs32
5 files changed, 107 insertions, 28 deletions
diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
index 13a7c5a1805..2f918faaf75 100644
--- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs
+++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
@@ -50,8 +50,16 @@ pub enum DiffActivity {
     /// with it.
     Dual,
     /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
+    /// with it. It expects the shadow argument to be `width` times larger than the original
+    /// input/output.
+    Dualv,
+    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
     /// with it. Drop the code which updates the original input/output for maximum performance.
     DualOnly,
+    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
+    /// with it. Drop the code which updates the original input/output for maximum performance.
+    /// It expects the shadow argument to be `width` times larger than the original input/output.
+    DualvOnly,
     /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
     Duplicated,
     /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
@@ -59,7 +67,15 @@ pub enum DiffActivity {
     DuplicatedOnly,
     /// All Integers must be Const, but these are used to mark the integer which represents the
     /// length of a slice/vec. This is used for safety checks on slices.
-    FakeActivitySize,
+    /// The integer (if given) specifies the size of the slice element in bytes.
+    FakeActivitySize(Option<u32>),
+}
+
+impl DiffActivity {
+    pub fn is_dual_or_const(&self) -> bool {
+        use DiffActivity::*;
+        matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)
+    }
 }
 /// We generate one of these structs for each `#[autodiff(...)]` attribute.
 #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -131,11 +147,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
     match mode {
         DiffMode::Error => false,
         DiffMode::Source => false,
-        DiffMode::Forward => {
-            activity == DiffActivity::Dual
-                || activity == DiffActivity::DualOnly
-                || activity == DiffActivity::Const
-        }
+        DiffMode::Forward => activity.is_dual_or_const(),
         DiffMode::Reverse => {
             activity == DiffActivity::Const
                 || activity == DiffActivity::Active
@@ -153,10 +165,8 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
 pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
     use DiffActivity::*;
     // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
-    if matches!(activity, Const) {
-        return true;
-    }
-    if matches!(activity, Dual | DualOnly) {
+    // Dual variants also support all types.
+    if activity.is_dual_or_const() {
         return true;
     }
     // FIXME(ZuseZ4) We should make this more robust to also
@@ -172,9 +182,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
     return match mode {
         DiffMode::Error => false,
         DiffMode::Source => false,
-        DiffMode::Forward => {
-            matches!(activity, Dual | DualOnly | Const)
-        }
+        DiffMode::Forward => activity.is_dual_or_const(),
         DiffMode::Reverse => {
             matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
         }
@@ -189,10 +197,12 @@ impl Display for DiffActivity {
             DiffActivity::Active => write!(f, "Active"),
             DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
             DiffActivity::Dual => write!(f, "Dual"),
+            DiffActivity::Dualv => write!(f, "Dualv"),
             DiffActivity::DualOnly => write!(f, "DualOnly"),
+            DiffActivity::DualvOnly => write!(f, "DualvOnly"),
             DiffActivity::Duplicated => write!(f, "Duplicated"),
             DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
-            DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
+            DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s),
         }
     }
 }
@@ -220,7 +230,9 @@ impl FromStr for DiffActivity {
             "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
             "Const" => Ok(DiffActivity::Const),
             "Dual" => Ok(DiffActivity::Dual),
+            "Dualv" => Ok(DiffActivity::Dualv),
             "DualOnly" => Ok(DiffActivity::DualOnly),
+            "DualvOnly" => Ok(DiffActivity::DualvOnly),
             "Duplicated" => Ok(DiffActivity::Duplicated),
             "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
             _ => Err(()),
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 351413dea49..44ee78ebe68 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -799,8 +799,19 @@ mod llvm_enzyme {
                         d_inputs.push(shadow_arg.clone());
                     }
                 }
-                DiffActivity::Dual | DiffActivity::DualOnly => {
-                    for i in 0..x.width {
+                DiffActivity::Dual
+                | DiffActivity::DualOnly
+                | DiffActivity::Dualv
+                | DiffActivity::DualvOnly => {
+                    // the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
+                    // Enzyme to not expect N arguments, but one argument (which is instead larger).
+                    let iterations =
+                        if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
+                            1
+                        } else {
+                            x.width
+                        };
+                    for i in 0..iterations {
                         let mut shadow_arg = arg.clone();
                         let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
                             ident.name
@@ -823,7 +834,7 @@ mod llvm_enzyme {
                 DiffActivity::Const => {
                     // Nothing to do here.
                 }
-                DiffActivity::None | DiffActivity::FakeActivitySize => {
+                DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
                     panic!("Should not happen");
                 }
             }
@@ -887,8 +898,8 @@ mod llvm_enzyme {
                 }
             };
 
-            if let DiffActivity::Dual = x.ret_activity {
-                let kind = if x.width == 1 {
+            if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
+                let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
                     // Dual can only be used for f32/f64 ret.
                     // In that case we return now a tuple with two floats.
                     TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
@@ -903,7 +914,7 @@ mod llvm_enzyme {
                 let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
                 d_decl.output = FnRetTy::Ty(ty);
             }
-            if let DiffActivity::DualOnly = x.ret_activity {
+            if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
                 // No need to change the return type,
                 // we will just return the shadow in place of the primal return.
                 // However, if we have a width > 1, then we don't return -> T, but -> [T; width]
diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs
index 35134e9f5a0..27f7f95f100 100644
--- a/compiler/rustc_codegen_llvm/src/builder.rs
+++ b/compiler/rustc_codegen_llvm/src/builder.rs
@@ -123,7 +123,7 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
 /// Empty string, to be used where LLVM expects an instruction name, indicating
 /// that the instruction is to be left unnamed (i.e. numbered, in textual IR).
 // FIXME(eddyb) pass `&CStr` directly to FFI once it's a thin pointer.
-const UNNAMED: *const c_char = c"".as_ptr();
+pub(crate) const UNNAMED: *const c_char = c"".as_ptr();
 
 impl<'ll, CX: Borrow<SCx<'ll>>> BackendTypes for GenericBuilder<'_, 'll, CX> {
     type Value = <GenericCx<'ll, CX> as BackendTypes>::Value;
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 5e7ef27143b..e7c071d05aa 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -10,7 +10,7 @@ use rustc_middle::bug;
 use tracing::{debug, trace};
 
 use crate::back::write::llvm_err;
-use crate::builder::SBuilder;
+use crate::builder::{SBuilder, UNNAMED};
 use crate::context::SimpleCx;
 use crate::declare::declare_simple_fn;
 use crate::errors::{AutoDiffWithoutEnable, LlvmError};
@@ -51,6 +51,7 @@ fn has_sret(fnc: &Value) -> bool {
 // using iterators and peek()?
 fn match_args_from_caller_to_enzyme<'ll>(
     cx: &SimpleCx<'ll>,
+    builder: &SBuilder<'ll, 'll>,
     width: u32,
     args: &mut Vec<&'ll llvm::Value>,
     inputs: &[DiffActivity],
@@ -78,7 +79,9 @@ fn match_args_from_caller_to_enzyme<'ll>(
     let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
     let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
     let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
+    let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
     let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
+    let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();
 
     while activity_pos < inputs.len() {
         let diff_activity = inputs[activity_pos as usize];
@@ -90,13 +93,34 @@ fn match_args_from_caller_to_enzyme<'ll>(
             DiffActivity::Active => (enzyme_out, false),
             DiffActivity::ActiveOnly => (enzyme_out, false),
             DiffActivity::Dual => (enzyme_dup, true),
+            DiffActivity::Dualv => (enzyme_dupv, true),
             DiffActivity::DualOnly => (enzyme_dupnoneed, true),
+            DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
             DiffActivity::Duplicated => (enzyme_dup, true),
             DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
-            DiffActivity::FakeActivitySize => (enzyme_const, false),
+            DiffActivity::FakeActivitySize(_) => (enzyme_const, false),
         };
         let outer_arg = outer_args[outer_pos];
         args.push(cx.get_metadata_value(activity));
+        if matches!(diff_activity, DiffActivity::Dualv) {
+            let next_outer_arg = outer_args[outer_pos + 1];
+            let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
+                DiffActivity::FakeActivitySize(Some(s)) => s.into(),
+                _ => bug!("incorrect Dualv handling recognized."),
+            };
+            // stride: sizeof(T) * n_elems.
+            // n_elems is the next integer.
+            // Now we multiply `4 * next_outer_arg` to get the stride.
+            let mul = unsafe {
+                llvm::LLVMBuildMul(
+                    builder.llbuilder,
+                    cx.get_const_i64(elem_bytes_size),
+                    next_outer_arg,
+                    UNNAMED,
+                )
+            };
+            args.push(mul);
+        }
         args.push(outer_arg);
         if duplicated {
             // We know that duplicated args by construction have a following argument,
@@ -114,7 +138,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
                 } else {
                     let next_activity = inputs[activity_pos + 1];
                     // We analyze the MIR types and add this dummy activity if we visit a slice.
-                    next_activity == DiffActivity::FakeActivitySize
+                    matches!(next_activity, DiffActivity::FakeActivitySize(_))
                 }
             };
             if slice {
@@ -125,7 +149,10 @@ fn match_args_from_caller_to_enzyme<'ll>(
                 // int2 >= int1, which means the shadow vector is large enough to store the gradient.
                 assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
 
-                for i in 0..(width as usize) {
+                let iterations =
+                    if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };
+
+                for i in 0..iterations {
                     let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
                     let next_outer_ty2 = cx.val_ty(next_outer_arg2);
                     assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
@@ -136,7 +163,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
                 }
                 args.push(cx.get_metadata_value(enzyme_const));
                 args.push(next_outer_arg);
-                outer_pos += 2 + 2 * width as usize;
+                outer_pos += 2 + 2 * iterations;
                 activity_pos += 2;
             } else {
                 // A duplicated pointer will have the following two outer_fn arguments:
@@ -360,6 +387,7 @@ fn generate_enzyme_call<'ll>(
         let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
         match_args_from_caller_to_enzyme(
             &cx,
+            &builder,
             attrs.width,
             &mut args,
             &attrs.input_activity,
diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs
index ebe0b258c1b..22d593b80b8 100644
--- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs
+++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs
@@ -2,7 +2,7 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
 use rustc_hir::def_id::LOCAL_CRATE;
 use rustc_middle::bug;
 use rustc_middle::mir::mono::MonoItem;
-use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
+use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
 use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
 use tracing::{debug, trace};
 
@@ -22,23 +22,51 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
     for (i, ty) in sig.inputs().iter().enumerate() {
         if let Some(inner_ty) = ty.builtin_deref(true) {
             if inner_ty.is_slice() {
+                // Now we need to figure out the size of each slice element in memory to allow
+                // safety checks and usability improvements in the backend.
+                let sty = match inner_ty.builtin_index() {
+                    Some(sty) => sty,
+                    None => {
+                        panic!("slice element type unknown");
+                    }
+                };
+                let pci = PseudoCanonicalInput {
+                    typing_env: TypingEnv::fully_monomorphized(),
+                    value: sty,
+                };
+
+                let layout = tcx.layout_of(pci);
+                let elem_size = match layout {
+                    Ok(layout) => layout.size,
+                    Err(_) => {
+                        bug!("autodiff failed to compute slice element size");
+                    }
+                };
+                let elem_size: u32 = elem_size.bytes() as u32;
+
                 // We know that the length will be passed as extra arg.
                 if !da.is_empty() {
                     // We are looking at a slice. The length of that slice will become an
                     // extra integer on llvm level. Integers are always const.
                     // However, if the slice get's duplicated, we want to know to later check the
                     // size. So we mark the new size argument as FakeActivitySize.
+                    // There is one FakeActivitySize per slice, so for convenience we store the
+                    // slice element size in bytes in it. We will use the size in the backend.
                     let activity = match da[i] {
                         DiffActivity::DualOnly
                         | DiffActivity::Dual
+                        | DiffActivity::Dualv
                         | DiffActivity::DuplicatedOnly
-                        | DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
+                        | DiffActivity::Duplicated => {
+                            DiffActivity::FakeActivitySize(Some(elem_size))
+                        }
                         DiffActivity::Const => DiffActivity::Const,
                         _ => bug!("unexpected activity for ptr/ref"),
                     };
                     new_activities.push(activity);
                     new_positions.push(i + 1);
                 }
+
                 continue;
             }
         }