diff options
| author | Stuart Cook <Zalathar@users.noreply.github.com> | 2025-04-05 13:18:13 +1100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-05 13:18:13 +1100 |
| commit | c6bf3a01efd2ba8720cede198d3e9b9eadfc712c (patch) | |
| tree | b1b774445d0ab1a7074a4096a84161e4d1fa4936 /tests/codegen/autodiffv.rs | |
| parent | 2e4e196a5bfe8ac1fc960a722ca8ffd9e19b8f5d (diff) | |
| parent | 89d8948835c603ec0c8234cd94772b6ac8ef9a34 (diff) | |
| download | rust-c6bf3a01efd2ba8720cede198d3e9b9eadfc712c.tar.gz rust-c6bf3a01efd2ba8720cede198d3e9b9eadfc712c.zip | |
Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk
Autodiff batching
Enzyme supports batching, which is especially known from the ML side when training neural networks.
There we would normally have a training loop, where in each iteration we would pass in some data (e.g. an image), and a target vector. Based on how close we are with our prediction we compute our loss, and then use backpropagation to compute the gradients and update our weights.
That's quite inefficient, so what you normally do is passing in a batch of 8/16/.. images and targets, and compute the gradients for those all at once, allowing better optimizations.
Enzyme supports batching in two ways, the first one (which I implemented here) just accepts a Batch size,
and then each Dual/Duplicated argument has not one, but N shadow arguments. So instead of
```rs
for i in 0..100 {
df(x[i], y[i], 1234);
}
```
You can now do
```rs
for i in 0..100.step_by(4) {
df(x[i+0],x[i+1],x[i+2],x[i+3], y[i+0], y[i+1], y[i+2], y[i+3], 1234);
}
```
which will give the same results, but allows better compiler optimizations. See the testcase for details.
There is a second variant, where we can mark certain arguments and instead of having to pass in N shadow arguments, Enzyme assumes that the argument is N times longer. I.e. instead of accepting 4 slices with 12 floats each, we would accept one slice with 48 floats. I'll implement this over the next days.
I will also add more tests for both modes.
For any one preferring some more interactive explanation, here's a video of Tim's llvm dev talk, where he presents his work. https://www.youtube.com/watch?v=edvaLAL5RqU
I'll also add some other docs to the dev guide and user docs in another PR.
r? ghost
Tracking:
- https://github.com/rust-lang/rust/issues/124509
- https://github.com/rust-lang/rust/issues/135283
Diffstat (limited to 'tests/codegen/autodiffv.rs')
| -rw-r--r-- | tests/codegen/autodiffv.rs | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/tests/codegen/autodiffv.rs b/tests/codegen/autodiffv.rs new file mode 100644 index 00000000000..e0047116405 --- /dev/null +++ b/tests/codegen/autodiffv.rs @@ -0,0 +1,116 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme +// +// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many +// breakages. One benefit is that we match the IR generated by Enzyme only after running it +// through LLVM's O3 pipeline, which will remove most of the noise. +// However, our integration test could also be affected by changes in how rustc lowers MIR into +// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should +// reduce this test to only match the first lines and the ret instructions. + +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[autodiff(d_square3, Forward, Dual, DualOnly)] +#[autodiff(d_square2, Forward, 4, Dual, DualOnly)] +#[autodiff(d_square1, Forward, 4, Dual, Dual)] +#[no_mangle] +fn square(x: &f32) -> f32 { + x * x +} + +// d_sqaure2 +// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'") +// CHECK-NEXT: start: +// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 +// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4 +// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1 +// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4 +// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2 +// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 +// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 +// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 +// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 +// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 +// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 +// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 +// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 +// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 +// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer +// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 +// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 +// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 +// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 +// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 +// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 +// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 +// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 +// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 +// CHECK-NEXT: ret [4 x float] %19 +// CHECK-NEXT: } + +// d_square3, the extra float is the original return value (x * x) +// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'") +// CHECK-NEXT: start: +// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 +// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4 +// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1 +// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4 +// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2 +// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 +// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 +// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 +// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val +// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 +// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 +// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 +// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 +// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 +// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 +// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer +// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 +// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 +// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 +// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 +// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 +// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 +// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 +// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 +// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 +// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0 +// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1 +// CHECK-NEXT: ret { float, [4 x float] } %21 +// CHECK-NEXT: } + +fn main() { + let x = std::hint::black_box(3.0); + let output = square(&x); + dbg!(&output); + assert_eq!(9.0, output); + dbg!(square(&x)); + + let mut df_dx1 = 1.0; + let mut df_dx2 = 2.0; + let mut df_dx3 = 3.0; + let mut df_dx4 = 0.0; + let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4); + dbg!(o1, o2, o3, o4); + let [output2, o1, o2, o3, o4] = + d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4); + dbg!(o1, o2, o3, o4); + assert_eq!(output, output2); + assert!((6.0 - o1).abs() < 1e-10); + assert!((12.0 - o2).abs() < 1e-10); + assert!((18.0 - o3).abs() < 1e-10); + assert!((0.0 - o4).abs() < 1e-10); + assert_eq!(1.0, df_dx1); + assert_eq!(2.0, df_dx2); + assert_eq!(3.0, df_dx3); + assert_eq!(0.0, df_dx4); + assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1); + assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2); + assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3); + assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4); +} |
