about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2025-04-07 12:58:15 +0000
committerbors <bors@rust-lang.org>2025-04-07 12:58:15 +0000
commite643f59f6da3a84f43e75dea99afaa5b041ea6bf (patch)
treea5449e907bcfbb64285463eaa9ab0d29775f8fe6 /compiler/rustc_codegen_llvm/src/builder/autodiff.rs
parent8fb32ab8e563124fe0968a2878b7f5b5d0e8d722 (diff)
parent6e0b67419c71b0adcd6108d268d7eda5330bd392 (diff)
downloadrust-e643f59f6da3a84f43e75dea99afaa5b041ea6bf.tar.gz
rust-e643f59f6da3a84f43e75dea99afaa5b041ea6bf.zip
Auto merge of #139482 - Zalathar:rollup-h2ht1y6, r=Zalathar
Rollup of 9 pull requests

Successful merges:

 - #139035 (Add new `PatKind::Missing` variants)
 - #139108 (Simplify `thir::PatKind::ExpandedConstant`)
 - #139112 (Implement `super let`)
 - #139365 (Default auto traits: fix perf)
 - #139397 (coverage: Build the CGU's global file table as late as possible)
 - #139455 ( Remove support for `extern "rust-intrinsic"` blocks)
 - #139461 (Stop calling `source_span` query in significant drop order code)
 - #139465 (add sret handling for scalar autodiff)
 - #139466 (Trivial tweaks to stop tracking source span directly)

r? `@ghost`
`@rustbot` modify labels: rollup
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder/autodiff.rs')
-rw-r--r--compiler/rustc_codegen_llvm/src/builder/autodiff.rs24
1 files changed, 22 insertions, 2 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 7d264ba4d00..5e7ef27143b 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>(
         }
 
         if attrs.width == 1 {
-            todo!("Handle sret for scalar ad");
+            // Enzyme returns a struct of style:
+            // `{ original_ret(if requested), float, float, ... }`
+            let mut struct_elements = vec![];
+            if attrs.has_primal_ret() {
+                struct_elements.push(inner_ret_ty);
+            }
+            // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
+            // and therefore part of the return struct.
+            let param_tys = cx.func_params_types(fn_ty);
+            for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
+                if matches!(act, DiffActivity::Active) {
+                    // Now find the float type at position i based on the fn_ty,
+                    // to know what (f16/f32/f64/...) to add to the struct.
+                    struct_elements.push(param_ty);
+                }
+            }
+            ret_ty = cx.type_struct(&struct_elements, false);
         } else {
             // First we check if we also have to deal with the primal return.
             match attrs.mode {
@@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>(
                 // now store the result of the enzyme call into the sret pointer.
                 let sret_ptr = outer_args[0];
                 let call_ty = cx.val_ty(call);
-                assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
+                if attrs.width == 1 {
+                    assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
+                } else {
+                    assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
+                }
                 llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
             }
             builder.ret_void();