about summary refs log tree commit diff
path: root/compiler/rustc_builtin_macros/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_builtin_macros/src')
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs23
1 files changed, 17 insertions, 6 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 4161829480d..daebd516499 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -799,8 +799,19 @@ mod llvm_enzyme {
                         d_inputs.push(shadow_arg.clone());
                     }
                 }
-                DiffActivity::Dual | DiffActivity::DualOnly => {
-                    for i in 0..x.width {
+                DiffActivity::Dual
+                | DiffActivity::DualOnly
+                | DiffActivity::Dualv
+                | DiffActivity::DualvOnly => {
+                    // the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
+                    // Enzyme to not expect N arguments, but one argument (which is instead larger).
+                    let iterations =
+                        if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
+                            1
+                        } else {
+                            x.width
+                        };
+                    for i in 0..iterations {
                         let mut shadow_arg = arg.clone();
                         let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
                             ident.name
@@ -823,7 +834,7 @@ mod llvm_enzyme {
                 DiffActivity::Const => {
                     // Nothing to do here.
                 }
-                DiffActivity::None | DiffActivity::FakeActivitySize => {
+                DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
                     panic!("Should not happen");
                 }
             }
@@ -887,8 +898,8 @@ mod llvm_enzyme {
                 }
             };
 
-            if let DiffActivity::Dual = x.ret_activity {
-                let kind = if x.width == 1 {
+            if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
+                let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
                     // Dual can only be used for f32/f64 ret.
                     // In that case we return now a tuple with two floats.
                     TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
@@ -903,7 +914,7 @@ mod llvm_enzyme {
                 let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
                 d_decl.output = FnRetTy::Ty(ty);
             }
-            if let DiffActivity::DualOnly = x.ret_activity {
+            if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
                 // No need to change the return type,
                 // we will just return the shadow in place of the primal return.
                 // However, if we have a width > 1, then we don't return -> T, but -> [T; width]