diff options
| author | Stuart Cook <Zalathar@users.noreply.github.com> | 2025-04-07 22:29:21 +1000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-07 22:29:21 +1000 |
| commit | 5863b426b9e9b97424383da9005fe692c69673d5 (patch) | |
| tree | f4aa3455fbd664e6fafa03e44bc79bbe23692c46 /tests | |
| parent | 0178254f46473eec4ac79bad750ea5fcd22362bf (diff) | |
| parent | ca5bea3ebbc4725c187abf4eac68f6c57fa938c1 (diff) | |
| download | rust-5863b426b9e9b97424383da9005fe692c69673d5.tar.gz rust-5863b426b9e9b97424383da9005fe692c69673d5.zip | |
Rollup merge of #139465 - EnzymeAD:autodiff-sret, r=oli-obk
add sret handling for scalar autodiff r? `@oli-obk` Fixing one of the todo's which I left in my previous batching PR. This one handles sret for scalar autodiff. `sret` mostly shows up when we try to return a lot of scalar floats. People often start testing autodiff which toy functions which just use a few scalars as inputs and outputs, and those were the most likely to be affected by this issue. So this fix should make learning/teaching hopefully a bit easier. Tracking: - https://github.com/rust-lang/rust/issues/124509
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/codegen/autodiff/batched.rs (renamed from tests/codegen/autodiffv.rs) | 0 | ||||
| -rw-r--r-- | tests/codegen/autodiff/scalar.rs (renamed from tests/codegen/autodiff.rs) | 0 | ||||
| -rw-r--r-- | tests/codegen/autodiff/sret.rs | 45 |
3 files changed, 45 insertions, 0 deletions
diff --git a/tests/codegen/autodiffv.rs b/tests/codegen/autodiff/batched.rs index e0047116405..e0047116405 100644 --- a/tests/codegen/autodiffv.rs +++ b/tests/codegen/autodiff/batched.rs diff --git a/tests/codegen/autodiff.rs b/tests/codegen/autodiff/scalar.rs index 85358f5fcb6..85358f5fcb6 100644 --- a/tests/codegen/autodiff.rs +++ b/tests/codegen/autodiff/scalar.rs diff --git a/tests/codegen/autodiff/sret.rs b/tests/codegen/autodiff/sret.rs new file mode 100644 index 00000000000..5ead90041ed --- /dev/null +++ b/tests/codegen/autodiff/sret.rs @@ -0,0 +1,45 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// This test is almost identical to the scalar.rs one, +// but we intentionally add a few more floats. +// `df` would ret `{ f64, f32, f32 }`, but is lowered as an sret. +// We therefore use this test to verify some of our sret handling. + +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[no_mangle] +#[autodiff(df, Reverse, Active, Active, Active)] +fn primal(x: f32, y: f32) -> f64 { + (x * x * y) as f64 +} + +// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y) +// CHECK-NEXT:start: +// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y) +// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0 +// CHECK-NEXT: store double %.elt, ptr %_0, align 8 +// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8 +// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1 +// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8 +// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12 +// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2 +// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4 +// CHECK-NEXT: ret void +// CHECK-NEXT:} + +fn main() { + let x = std::hint::black_box(3.0); + let y = std::hint::black_box(2.5); + let scalar = std::hint::black_box(1.0); + let (r1, r2, r3) = df(x, y, scalar); + // 3*3*1.5 = 22.5 + assert_eq!(r1, 22.5); + // 2*x*y = 2*3*2.5 = 15.0 + assert_eq!(r2, 15.0); + // x*x*1 = 3*3 = 9 + assert_eq!(r3, 9.0); +} |
