about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs44
-rw-r--r--compiler/rustc_codegen_llvm/src/intrinsic.rs3
-rw-r--r--tests/codegen-llvm/autodiff/abi_handling.rs210
-rw-r--r--tests/ui/autodiff/zst.rs17
4 files changed, 271 insertions, 3 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 6ddf53cdc87..b66e3dfdeec 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -3,8 +3,9 @@ use std::ptr;
 use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
 use rustc_codegen_ssa::common::TypeKind;
 use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
-use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
+use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
 use rustc_middle::{bug, ty};
+use rustc_target::callconv::PassMode;
 use tracing::debug;
 
 use crate::builder::{Builder, PlaceRef, UNNAMED};
@@ -16,9 +17,12 @@ use crate::value::Value;
 
 pub(crate) fn adjust_activity_to_abi<'tcx>(
     tcx: TyCtxt<'tcx>,
-    fn_ty: Ty<'tcx>,
+    instance: Instance<'tcx>,
+    typing_env: TypingEnv<'tcx>,
     da: &mut Vec<DiffActivity>,
 ) {
+    let fn_ty = instance.ty(tcx, typing_env);
+
     if !matches!(fn_ty.kind(), ty::FnDef(..)) {
         bug!("expected fn def for autodiff, got {:?}", fn_ty);
     }
@@ -27,8 +31,16 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
     // All we do is decide how to handle the arguments.
     let sig = fn_ty.fn_sig(tcx).skip_binder();
 
+    // FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
+    let Ok(fn_abi) =
+        tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
+    else {
+        bug!("failed to get fn_abi of instance with empty varargs");
+    };
+
     let mut new_activities = vec![];
     let mut new_positions = vec![];
+    let mut del_activities = 0;
     for (i, ty) in sig.inputs().iter().enumerate() {
         if let Some(inner_ty) = ty.builtin_deref(true) {
             if inner_ty.is_slice() {
@@ -80,6 +92,34 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
                 continue;
             }
         }
+
+        let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };
+
+        let layout = match tcx.layout_of(pci) {
+            Ok(layout) => layout.layout,
+            Err(_) => {
+                bug!("failed to compute layout for type {:?}", ty);
+            }
+        };
+
+        let pass_mode = &fn_abi.args[i].mode;
+
+        // For ZST, just ignore and don't add its activity, as this arg won't be present
+        // in the LLVM passed to Enzyme.
+        // Some targets pass ZST indirectly in the C ABI, in that case, handle it as a normal arg
+        // FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const`
+        if *pass_mode == PassMode::Ignore {
+            del_activities += 1;
+            da.remove(i);
+        }
+
+        // If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
+        // Otherwise, the number of activities won't match the number of LLVM arguments and
+        // this will lead to errors when verifying the Enzyme call.
+        if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
+            new_activities.push(da[i].clone());
+            new_positions.push(i + 1 - del_activities);
+        }
     }
     // now add the extra activities coming from slices
     // Reverse order to not invalidate the indices
diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index 85f71f331a4..e7f4a357048 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -1208,7 +1208,8 @@ fn codegen_autodiff<'ll, 'tcx>(
 
     adjust_activity_to_abi(
         tcx,
-        fn_source.ty(tcx, TypingEnv::fully_monomorphized()),
+        fn_source,
+        TypingEnv::fully_monomorphized(),
         &mut diff_attrs.input_activity,
     );
 
diff --git a/tests/codegen-llvm/autodiff/abi_handling.rs b/tests/codegen-llvm/autodiff/abi_handling.rs
new file mode 100644
index 00000000000..454ec698b91
--- /dev/null
+++ b/tests/codegen-llvm/autodiff/abi_handling.rs
@@ -0,0 +1,210 @@
+//@ revisions: debug release
+
+//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat
+//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+
+// This test checks that Rust types are lowered to LLVM-IR types in a way
+// we expect and Enzyme can handle. We explicitly check release mode to
+// ensure that LLVM's O3 pipeline doesn't rewrite function signatures
+// into forms that Enzyme can't process correctly.
+
+#![feature(autodiff)]
+
+use std::autodiff::{autodiff_forward, autodiff_reverse};
+
+#[derive(Copy, Clone)]
+struct Input {
+    x: f32,
+    y: f32,
+}
+
+#[derive(Copy, Clone)]
+struct Wrapper {
+    z: f32,
+}
+
+#[derive(Copy, Clone)]
+struct NestedInput {
+    x: f32,
+    y: Wrapper,
+}
+
+fn square(x: f32) -> f32 {
+    x * x
+}
+
+// CHECK-LABEL: ; abi_handling::df1
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal { float, float }
+// debug-SAME: (ptr align 4 %x, ptr align 4 %bx_0)
+// release-NEXT: define internal fastcc float
+// release-SAME: (float %x.0.val, float %x.4.val)
+
+// CHECK-LABEL: ; abi_handling::f1
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal float
+// debug-SAME: (ptr align 4 %x)
+// release-NEXT: define internal fastcc noundef float
+// release-SAME: (float %x.0.val, float %x.4.val)
+#[autodiff_forward(df1, Dual, Dual)]
+#[inline(never)]
+fn f1(x: &[f32; 2]) -> f32 {
+    x[0] + x[1]
+}
+
+// CHECK-LABEL: ; abi_handling::df2
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal { float, float }
+// debug-SAME: (ptr %f, float %x, float %dret)
+// release-NEXT: define internal fastcc float
+// release-SAME: (float noundef %x)
+
+// CHECK-LABEL: ; abi_handling::f2
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal float
+// debug-SAME: (ptr %f, float %x)
+// release-NEXT: define internal fastcc noundef float
+// release-SAME: (float noundef %x)
+#[autodiff_reverse(df2, Const, Active, Active)]
+#[inline(never)]
+fn f2(f: fn(f32) -> f32, x: f32) -> f32 {
+    f(x)
+}
+
+// CHECK-LABEL: ; abi_handling::df3
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal { float, float }
+// debug-SAME: (ptr align 4 %x, ptr align 4 %bx_0, ptr align 4 %y, ptr align 4 %by_0)
+// release-NEXT: define internal fastcc { float, float }
+// release-SAME: (float %x.0.val)
+
+// CHECK-LABEL: ; abi_handling::f3
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal float
+// debug-SAME: (ptr align 4 %x, ptr align 4 %y)
+// release-NEXT: define internal fastcc noundef float
+// release-SAME: (float %x.0.val)
+#[autodiff_forward(df3, Dual, Dual, Dual)]
+#[inline(never)]
+fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 {
+    *x * *y
+}
+
+// CHECK-LABEL: ; abi_handling::df4
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal { float, float }
+// debug-SAME: (float %x.0, float %x.1, float %bx_0.0, float %bx_0.1)
+// release-NEXT: define internal fastcc { float, float }
+// release-SAME: (float noundef %x.0, float noundef %x.1)
+
+// CHECK-LABEL: ; abi_handling::f4
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal float
+// debug-SAME: (float %x.0, float %x.1)
+// release-NEXT: define internal fastcc noundef float
+// release-SAME: (float noundef %x.0, float noundef %x.1)
+#[autodiff_forward(df4, Dual, Dual)]
+#[inline(never)]
+fn f4(x: (f32, f32)) -> f32 {
+    x.0 * x.1
+}
+
+// CHECK-LABEL: ; abi_handling::df5
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal { float, float }
+// debug-SAME: (float %i.0, float %i.1, float %bi_0.0, float %bi_0.1)
+// release-NEXT: define internal fastcc { float, float }
+// release-SAME: (float noundef %i.0, float noundef %i.1)
+
+// CHECK-LABEL: ; abi_handling::f5
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal float
+// debug-SAME: (float %i.0, float %i.1)
+// release-NEXT: define internal fastcc noundef float
+// release-SAME: (float noundef %i.0, float noundef %i.1)
+#[autodiff_forward(df5, Dual, Dual)]
+#[inline(never)]
+fn f5(i: Input) -> f32 {
+    i.x + i.y
+}
+
+// CHECK-LABEL: ; abi_handling::df6
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal { float, float }
+// debug-SAME: (float %i.0, float %i.1, float %bi_0.0, float %bi_0.1)
+// release-NEXT: define internal fastcc { float, float }
+// release-SAME: float noundef %i.0, float noundef %i.1
+// release-SAME: float noundef %bi_0.0, float noundef %bi_0.1
+
+// CHECK-LABEL: ; abi_handling::f6
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal float
+// debug-SAME: (float %i.0, float %i.1)
+// release-NEXT: define internal fastcc noundef float
+// release-SAME: (float noundef %i.0, float noundef %i.1)
+#[autodiff_forward(df6, Dual, Dual)]
+#[inline(never)]
+fn f6(i: NestedInput) -> f32 {
+    i.x + i.y.z * i.y.z
+}
+
+// CHECK-LABEL: ; abi_handling::df7
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal { float, float }
+// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1, ptr align 4 %bx_0.0, ptr align 4 %bx_0.1)
+// release-NEXT: define internal fastcc { float, float }
+// release-SAME: (float %x.0.0.val, float %x.1.0.val)
+
+// CHECK-LABEL: ; abi_handling::f7
+// CHECK-NEXT: Function Attrs
+// debug-NEXT: define internal float
+// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1)
+// release-NEXT: define internal fastcc noundef float
+// release-SAME: (float %x.0.0.val, float %x.1.0.val)
+#[autodiff_forward(df7, Dual, Dual)]
+#[inline(never)]
+fn f7(x: (&f32, &f32)) -> f32 {
+    x.0 * x.1
+}
+
+fn main() {
+    let x = std::hint::black_box(2.0);
+    let y = std::hint::black_box(3.0);
+    let z = std::hint::black_box(4.0);
+    static Y: f32 = std::hint::black_box(3.2);
+
+    let in_f1 = [x, y];
+    dbg!(f1(&in_f1));
+    let res_f1 = df1(&in_f1, &[1.0, 0.0]);
+    dbg!(res_f1);
+
+    dbg!(f2(square, x));
+    let res_f2 = df2(square, x, 1.0);
+    dbg!(res_f2);
+
+    dbg!(f3(&x, &Y));
+    let res_f3 = df3(&x, &Y, &1.0, &0.0);
+    dbg!(res_f3);
+
+    let in_f4 = (x, y);
+    dbg!(f4(in_f4));
+    let res_f4 = df4(in_f4, (1.0, 0.0));
+    dbg!(res_f4);
+
+    let in_f5 = Input { x, y };
+    dbg!(f5(in_f5));
+    let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 });
+    dbg!(res_f5);
+
+    let in_f6 = NestedInput { x, y: Wrapper { z: y } };
+    dbg!(f6(in_f6));
+    let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } });
+    dbg!(res_f6);
+
+    let in_f7 = (&x, &y);
+    dbg!(f7(in_f7));
+    let res_f7 = df7(in_f7, (&1.0, &0.0));
+    dbg!(res_f7);
+}
diff --git a/tests/ui/autodiff/zst.rs b/tests/ui/autodiff/zst.rs
new file mode 100644
index 00000000000..7b9b5f5f20b
--- /dev/null
+++ b/tests/ui/autodiff/zst.rs
@@ -0,0 +1,17 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+//@ build-pass
+
+// Check that differentiating functions with ZST args does not break
+
+#![feature(autodiff)]
+
+#[core::autodiff::autodiff_forward(fd_inner, Const, Dual)]
+fn f(_zst: (), _x: &mut f64) {}
+
+fn fd(x: &mut f64, xd: &mut f64) {
+    fd_inner((), x, xd);
+}
+
+fn main() {}