about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
diff options
context:
space:
mode:
authorManuel Drehwald <git@manuel.drehwald.info>2025-04-07 07:07:16 -0400
committerManuel Drehwald <git@manuel.drehwald.info>2025-04-07 07:07:16 -0400
commitd6467d34ae4057f493e2706a5625e0784f2a68bf (patch)
treed46b489b72a6194363a5686286a9e0ac051027d3 /compiler/rustc_codegen_llvm/src/builder/autodiff.rs
parent2fa8b11f0933dae9b4e5d287cc10c989218e8b36 (diff)
downloadrust-d6467d34ae4057f493e2706a5625e0784f2a68bf.tar.gz
rust-d6467d34ae4057f493e2706a5625e0784f2a68bf.zip
handle sret for scalar autodiff
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder/autodiff.rs')
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs24
1 files changed, 22 insertions, 2 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 7d264ba4d00..5e7ef27143b 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>(
         }
 
         if attrs.width == 1 {
-            todo!("Handle sret for scalar ad");
+            // Enzyme returns a struct of style:
+            // `{ original_ret(if requested), float, float, ... }`
+            let mut struct_elements = vec![];
+            if attrs.has_primal_ret() {
+                struct_elements.push(inner_ret_ty);
+            }
+            // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
+            // and therefore part of the return struct.
+            let param_tys = cx.func_params_types(fn_ty);
+            for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
+                if matches!(act, DiffActivity::Active) {
+                    // Now find the float type at position i based on the fn_ty,
+                    // to know what (f16/f32/f64/...) to add to the struct.
+                    struct_elements.push(param_ty);
+                }
+            }
+            ret_ty = cx.type_struct(&struct_elements, false);
         } else {
             // First we check if we also have to deal with the primal return.
             match attrs.mode {
@@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>(
                 // now store the result of the enzyme call into the sret pointer.
                 let sret_ptr = outer_args[0];
                 let call_ty = cx.val_ty(call);
-                assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
+                if attrs.width == 1 {
+                    assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
+                } else {
+                    assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
+                }
                 llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
             }
             builder.ret_void();