about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src
diff options
context:
space:
mode:
authorStuart Cook <Zalathar@users.noreply.github.com>2025-04-07 22:29:21 +1000
committerGitHub <noreply@github.com>2025-04-07 22:29:21 +1000
commit5863b426b9e9b97424383da9005fe692c69673d5 (patch)
treef4aa3455fbd664e6fafa03e44bc79bbe23692c46 /compiler/rustc_codegen_llvm/src
parent0178254f46473eec4ac79bad750ea5fcd22362bf (diff)
parentca5bea3ebbc4725c187abf4eac68f6c57fa938c1 (diff)
downloadrust-5863b426b9e9b97424383da9005fe692c69673d5.tar.gz
rust-5863b426b9e9b97424383da9005fe692c69673d5.zip
Rollup merge of #139465 - EnzymeAD:autodiff-sret, r=oli-obk
add sret handling for scalar autodiff

r? `@oli-obk`

Fixing one of the todo's which I left in my previous batching PR.
This one handles sret for scalar autodiff.  `sret` mostly shows up when we try to return a lot of scalar floats.
People often start testing autodiff which toy functions which just use a few scalars as inputs and outputs, and those were the most likely to be affected by this issue. So this fix should make learning/teaching hopefully a bit easier.

Tracking:

- https://github.com/rust-lang/rust/issues/124509
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs24
1 files changed, 22 insertions, 2 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 7d264ba4d00..5e7ef27143b 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>(
         }
 
         if attrs.width == 1 {
-            todo!("Handle sret for scalar ad");
+            // Enzyme returns a struct of style:
+            // `{ original_ret(if requested), float, float, ... }`
+            let mut struct_elements = vec![];
+            if attrs.has_primal_ret() {
+                struct_elements.push(inner_ret_ty);
+            }
+            // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
+            // and therefore part of the return struct.
+            let param_tys = cx.func_params_types(fn_ty);
+            for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
+                if matches!(act, DiffActivity::Active) {
+                    // Now find the float type at position i based on the fn_ty,
+                    // to know what (f16/f32/f64/...) to add to the struct.
+                    struct_elements.push(param_ty);
+                }
+            }
+            ret_ty = cx.type_struct(&struct_elements, false);
         } else {
             // First we check if we also have to deal with the primal return.
             match attrs.mode {
@@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>(
                 // now store the result of the enzyme call into the sret pointer.
                 let sret_ptr = outer_args[0];
                 let call_ty = cx.val_ty(call);
-                assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
+                if attrs.width == 1 {
+                    assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
+                } else {
+                    assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
+                }
                 llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
             }
             builder.ret_void();