diff options
| author | Manuel Drehwald <git@manuel.drehwald.info> | 2025-04-07 07:07:16 -0400 |
|---|---|---|
| committer | Manuel Drehwald <git@manuel.drehwald.info> | 2025-04-07 07:07:16 -0400 |
| commit | d6467d34ae4057f493e2706a5625e0784f2a68bf (patch) | |
| tree | d46b489b72a6194363a5686286a9e0ac051027d3 /compiler/rustc_codegen_llvm/src/builder/autodiff.rs | |
| parent | 2fa8b11f0933dae9b4e5d287cc10c989218e8b36 (diff) | |
| download | rust-d6467d34ae4057f493e2706a5625e0784f2a68bf.tar.gz rust-d6467d34ae4057f493e2706a5625e0784f2a68bf.zip | |
handle sret for scalar autodiff
Diffstat (limited to 'compiler/rustc_codegen_llvm/src/builder/autodiff.rs')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder/autodiff.rs | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 7d264ba4d00..5e7ef27143b 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>( } if attrs.width == 1 { - todo!("Handle sret for scalar ad"); + // Enzyme returns a struct of style: + // `{ original_ret(if requested), float, float, ... }` + let mut struct_elements = vec![]; + if attrs.has_primal_ret() { + struct_elements.push(inner_ret_ty); + } + // Next, we push the list of active floats, since they will be lowered to `enzyme_out`, + // and therefore part of the return struct. + let param_tys = cx.func_params_types(fn_ty); + for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) { + if matches!(act, DiffActivity::Active) { + // Now find the float type at position i based on the fn_ty, + // to know what (f16/f32/f64/...) to add to the struct. + struct_elements.push(param_ty); + } + } + ret_ty = cx.type_struct(&struct_elements, false); } else { // First we check if we also have to deal with the primal return. match attrs.mode { @@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>( // now store the result of the enzyme call into the sret pointer. let sret_ptr = outer_args[0]; let call_ty = cx.val_ty(call); - assert_eq!(cx.type_kind(call_ty), TypeKind::Array); + if attrs.width == 1 { + assert_eq!(cx.type_kind(call_ty), TypeKind::Struct); + } else { + assert_eq!(cx.type_kind(call_ty), TypeKind::Array); + } llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); } builder.ret_void(); |
