diff options
| author | Stuart Cook <Zalathar@users.noreply.github.com> | 2025-04-07 22:29:21 +1000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-07 22:29:21 +1000 |
| commit | 5863b426b9e9b97424383da9005fe692c69673d5 (patch) | |
| tree | f4aa3455fbd664e6fafa03e44bc79bbe23692c46 /compiler/rustc_codegen_llvm/src | |
| parent | 0178254f46473eec4ac79bad750ea5fcd22362bf (diff) | |
| parent | ca5bea3ebbc4725c187abf4eac68f6c57fa938c1 (diff) | |
| download | rust-5863b426b9e9b97424383da9005fe692c69673d5.tar.gz rust-5863b426b9e9b97424383da9005fe692c69673d5.zip | |
Rollup merge of #139465 - EnzymeAD:autodiff-sret, r=oli-obk
add sret handling for scalar autodiff r? `@oli-obk` Fixing one of the todo's which I left in my previous batching PR. This one handles sret for scalar autodiff. `sret` mostly shows up when we try to return a lot of scalar floats. People often start testing autodiff which toy functions which just use a few scalars as inputs and outputs, and those were the most likely to be affected by this issue. So this fix should make learning/teaching hopefully a bit easier. Tracking: - https://github.com/rust-lang/rust/issues/124509
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder/autodiff.rs | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 7d264ba4d00..5e7ef27143b 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>( } if attrs.width == 1 { - todo!("Handle sret for scalar ad"); + // Enzyme returns a struct of style: + // `{ original_ret(if requested), float, float, ... }` + let mut struct_elements = vec![]; + if attrs.has_primal_ret() { + struct_elements.push(inner_ret_ty); + } + // Next, we push the list of active floats, since they will be lowered to `enzyme_out`, + // and therefore part of the return struct. + let param_tys = cx.func_params_types(fn_ty); + for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) { + if matches!(act, DiffActivity::Active) { + // Now find the float type at position i based on the fn_ty, + // to know what (f16/f32/f64/...) to add to the struct. + struct_elements.push(param_ty); + } + } + ret_ty = cx.type_struct(&struct_elements, false); } else { // First we check if we also have to deal with the primal return. match attrs.mode { @@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>( // now store the result of the enzyme call into the sret pointer. let sret_ptr = outer_args[0]; let call_ty = cx.val_ty(call); - assert_eq!(cx.type_kind(call_ty), TypeKind::Array); + if attrs.width == 1 { + assert_eq!(cx.type_kind(call_ty), TypeKind::Struct); + } else { + assert_eq!(cx.type_kind(call_ty), TypeKind::Array); + } llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); } builder.ret_void(); |
