about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs11
-rw-r--r--tests/ui/autodiff/zst.rs17
2 files changed, 27 insertions, 1 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index a19a0d867ac..78deffa3a7a 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -29,6 +29,7 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
 
     let mut new_activities = vec![];
     let mut new_positions = vec![];
+    let mut del_activities = 0;
     for (i, ty) in sig.inputs().iter().enumerate() {
         if let Some(inner_ty) = ty.builtin_deref(true) {
             if inner_ty.is_slice() {
@@ -90,12 +91,20 @@ pub(crate) fn adjust_activity_to_abi<'tcx>(
             }
         };
 
+        // For ZST, just ignore and don't add its activity, as this arg won't be present
+        // in the LLVM passed to Enzyme.
+        // FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const`
+        if layout.is_zst() {
+            del_activities += 1;
+            da.remove(i);
+        }
+
         // If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
         // Otherwise, the number of activities won't match the number of LLVM arguments and
         // this will lead to errors when verifying the Enzyme call.
         if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
             new_activities.push(da[i].clone());
-            new_positions.push(i + 1);
+            new_positions.push(i + 1 - del_activities);
         }
     }
     // now add the extra activities coming from slices
diff --git a/tests/ui/autodiff/zst.rs b/tests/ui/autodiff/zst.rs
new file mode 100644
index 00000000000..7b9b5f5f20b
--- /dev/null
+++ b/tests/ui/autodiff/zst.rs
@@ -0,0 +1,17 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+//@ build-pass
+
+// Check that differentiating functions with ZST args does not break
+
+#![feature(autodiff)]
+
+#[core::autodiff::autodiff_forward(fd_inner, Const, Dual)]
+fn f(_zst: (), _x: &mut f64) {}
+
+fn fd(x: &mut f64, xd: &mut f64) {
+    fd_inner((), x, xd);
+}
+
+fn main() {}