about summary refs log tree commit diff
path: root/tests/pretty
diff options
context:
space:
mode:
authorStuart Cook <Zalathar@users.noreply.github.com>2025-04-05 13:18:13 +1100
committerGitHub <noreply@github.com>2025-04-05 13:18:13 +1100
commitc6bf3a01efd2ba8720cede198d3e9b9eadfc712c (patch)
treeb1b774445d0ab1a7074a4096a84161e4d1fa4936 /tests/pretty
parent2e4e196a5bfe8ac1fc960a722ca8ffd9e19b8f5d (diff)
parent89d8948835c603ec0c8234cd94772b6ac8ef9a34 (diff)
downloadrust-c6bf3a01efd2ba8720cede198d3e9b9eadfc712c.tar.gz
rust-c6bf3a01efd2ba8720cede198d3e9b9eadfc712c.zip
Rollup merge of #137880 - EnzymeAD:autodiff-batching, r=oli-obk
Autodiff batching

Enzyme supports batching, which is especially known from the ML side when training neural networks.
There we would normally have a training loop, where in each iteration we would pass in some data (e.g. an image), and a target vector. Based on how close we are with our prediction we compute our loss, and then use backpropagation to compute the gradients and update our weights.
That's quite inefficient, so what you normally do is passing in a batch of 8/16/.. images and targets, and compute the gradients for those all at once, allowing better optimizations.

Enzyme supports batching in two ways, the first one (which I implemented here) just accepts a Batch size,
and then each Dual/Duplicated argument has not one, but N shadow arguments.  So instead of
```rs
for i in 0..100 {
   df(x[i], y[i], 1234);
}
```
You can now do
```rs
for i in 0..100.step_by(4) {
   df(x[i+0],x[i+1],x[i+2],x[i+3], y[i+0], y[i+1], y[i+2], y[i+3], 1234);
}
```
which will give the same results, but allows better compiler optimizations. See the testcase for details.

There is a second variant, where we can mark certain arguments and instead of having to pass in N shadow arguments, Enzyme assumes that the argument is N times longer. I.e. instead of accepting 4 slices with 12 floats each, we would accept one slice with 48 floats. I'll implement this over the next days.

I will also add more tests for both modes.

For any one preferring some more interactive explanation, here's a video of Tim's llvm dev talk, where he presents his work. https://www.youtube.com/watch?v=edvaLAL5RqU
I'll also add some other docs to the dev guide and user docs in another PR.

r? ghost

Tracking:

- https://github.com/rust-lang/rust/issues/124509
- https://github.com/rust-lang/rust/issues/135283
Diffstat (limited to 'tests/pretty')
-rw-r--r--tests/pretty/autodiff_forward.pp100
-rw-r--r--tests/pretty/autodiff_forward.rs18
-rw-r--r--tests/pretty/autodiff_reverse.pp22
3 files changed, 108 insertions, 32 deletions
diff --git a/tests/pretty/autodiff_forward.pp b/tests/pretty/autodiff_forward.pp
index dc7a2712f42..4b2fb6166ff 100644
--- a/tests/pretty/autodiff_forward.pp
+++ b/tests/pretty/autodiff_forward.pp
@@ -25,27 +25,31 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
 
     // We want to be sure that the same function can be differentiated in different ways
 
+
+    // Make sure, that we add the None for the default return.
+
+
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Forward, Dual, Const, Dual,)]
+#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
 #[inline(never)]
-pub fn df1(x: &[f64], bx: &[f64], y: f64) -> (f64, f64) {
+pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f1(x, y));
-    ::core::hint::black_box((bx,));
-    ::core::hint::black_box((f1(x, y), f64::default()))
+    ::core::hint::black_box((bx_0,));
+    ::core::hint::black_box(<(f64, f64)>::default())
 }
 #[rustc_autodiff]
 #[inline(never)]
 pub fn f2(x: &[f64], y: f64) -> f64 {
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Forward, Dual, Const, Const,)]
+#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
 #[inline(never)]
-pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 {
+pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f2(x, y));
-    ::core::hint::black_box((bx,));
+    ::core::hint::black_box((bx_0,));
     ::core::hint::black_box(f2(x, y))
 }
 #[rustc_autodiff]
@@ -53,20 +57,20 @@ pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 {
 pub fn f3(x: &[f64], y: f64) -> f64 {
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Forward, Dual, Const, Const,)]
+#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
 #[inline(never)]
-pub fn df3(x: &[f64], bx: &[f64], y: f64) -> f64 {
+pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f3(x, y));
-    ::core::hint::black_box((bx,));
+    ::core::hint::black_box((bx_0,));
     ::core::hint::black_box(f3(x, y))
 }
 #[rustc_autodiff]
 #[inline(never)]
 pub fn f4() {}
-#[rustc_autodiff(Forward, None)]
+#[rustc_autodiff(Forward, 1, None)]
 #[inline(never)]
-pub fn df4() {
+pub fn df4() -> () {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f4());
     ::core::hint::black_box(());
@@ -76,28 +80,82 @@ pub fn df4() {
 pub fn f5(x: &[f64], y: f64) -> f64 {
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Forward, Const, Dual, Const,)]
+#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
 #[inline(never)]
-pub fn df5_y(x: &[f64], y: f64, by: f64) -> f64 {
+pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f5(x, y));
-    ::core::hint::black_box((by,));
+    ::core::hint::black_box((by_0,));
     ::core::hint::black_box(f5(x, y))
 }
-#[rustc_autodiff(Forward, Dual, Const, Const,)]
+#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
 #[inline(never)]
-pub fn df5_x(x: &[f64], bx: &[f64], y: f64) -> f64 {
+pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f5(x, y));
-    ::core::hint::black_box((bx,));
+    ::core::hint::black_box((bx_0,));
     ::core::hint::black_box(f5(x, y))
 }
-#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
+#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
 #[inline(never)]
-pub fn df5_rev(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
+pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f5(x, y));
-    ::core::hint::black_box((dx, dret));
+    ::core::hint::black_box((dx_0, dret));
     ::core::hint::black_box(f5(x, y))
 }
+struct DoesNotImplDefault;
+#[rustc_autodiff]
+#[inline(never)]
+pub fn f6() -> DoesNotImplDefault {
+    ::core::panicking::panic("not implemented")
+}
+#[rustc_autodiff(Forward, 1, Const)]
+#[inline(never)]
+pub fn df6() -> DoesNotImplDefault {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f6());
+    ::core::hint::black_box(());
+    ::core::hint::black_box(f6())
+}
+#[rustc_autodiff]
+#[inline(never)]
+pub fn f7(x: f32) -> () {}
+#[rustc_autodiff(Forward, 1, Const, None)]
+#[inline(never)]
+pub fn df7(x: f32) -> () {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f7(x));
+    ::core::hint::black_box(());
+}
+#[no_mangle]
+#[rustc_autodiff]
+#[inline(never)]
+fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
+#[rustc_autodiff(Forward, 4, Dual, Dual)]
+#[inline(never)]
+fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
+    -> [f32; 5usize] {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f8(x));
+    ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
+    ::core::hint::black_box(<[f32; 5usize]>::default())
+}
+#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
+#[inline(never)]
+fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
+    -> [f32; 4usize] {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f8(x));
+    ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
+    ::core::hint::black_box(<[f32; 4usize]>::default())
+}
+#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
+#[inline(never)]
+fn f8_1(x: &f32, bx_0: &f32) -> f32 {
+    unsafe { asm!("NOP", options(pure, nomem)); };
+    ::core::hint::black_box(f8(x));
+    ::core::hint::black_box((bx_0,));
+    ::core::hint::black_box(<f32>::default())
+}
 fn main() {}
diff --git a/tests/pretty/autodiff_forward.rs b/tests/pretty/autodiff_forward.rs
index bc558211632..a765738c2a8 100644
--- a/tests/pretty/autodiff_forward.rs
+++ b/tests/pretty/autodiff_forward.rs
@@ -36,4 +36,22 @@ pub fn f5(x: &[f64], y: f64) -> f64 {
     unimplemented!()
 }
 
+struct DoesNotImplDefault;
+#[autodiff(df6, Forward, Const)]
+pub fn f6() -> DoesNotImplDefault {
+    unimplemented!()
+}
+
+// Make sure, that we add the None for the default return.
+#[autodiff(df7, Forward, Const)]
+pub fn f7(x: f32) -> () {}
+
+#[autodiff(f8_1, Forward, Dual, DualOnly)]
+#[autodiff(f8_2, Forward, 4, Dual, DualOnly)]
+#[autodiff(f8_3, Forward, 4, Dual, Dual)]
+#[no_mangle]
+fn f8(x: &f32) -> f32 {
+    unimplemented!()
+}
+
 fn main() {}
diff --git a/tests/pretty/autodiff_reverse.pp b/tests/pretty/autodiff_reverse.pp
index b2cf0244af4..31920694a3a 100644
--- a/tests/pretty/autodiff_reverse.pp
+++ b/tests/pretty/autodiff_reverse.pp
@@ -28,18 +28,18 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
 
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
+#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
 #[inline(never)]
-pub fn df1(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
+pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f1(x, y));
-    ::core::hint::black_box((dx, dret));
+    ::core::hint::black_box((dx_0, dret));
     ::core::hint::black_box(f1(x, y))
 }
 #[rustc_autodiff]
 #[inline(never)]
 pub fn f2() {}
-#[rustc_autodiff(Reverse, None)]
+#[rustc_autodiff(Reverse, 1, None)]
 #[inline(never)]
 pub fn df2() {
     unsafe { asm!("NOP", options(pure, nomem)); };
@@ -51,12 +51,12 @@ pub fn df2() {
 pub fn f3(x: &[f64], y: f64) -> f64 {
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
+#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
 #[inline(never)]
-pub fn df3(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
+pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f3(x, y));
-    ::core::hint::black_box((dx, dret));
+    ::core::hint::black_box((dx_0, dret));
     ::core::hint::black_box(f3(x, y))
 }
 enum Foo { Reverse, }
@@ -64,7 +64,7 @@ use Foo::Reverse;
 #[rustc_autodiff]
 #[inline(never)]
 pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
-#[rustc_autodiff(Reverse, Const, None)]
+#[rustc_autodiff(Reverse, 1, Const, None)]
 #[inline(never)]
 pub fn df4(x: f32) {
     unsafe { asm!("NOP", options(pure, nomem)); };
@@ -76,11 +76,11 @@ pub fn df4(x: f32) {
 pub fn f5(x: *const f32, y: &f32) {
     ::core::panicking::panic("not implemented")
 }
-#[rustc_autodiff(Reverse, DuplicatedOnly, Duplicated, None)]
+#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
 #[inline(never)]
-pub unsafe fn df5(x: *const f32, dx: *mut f32, y: &f32, dy: &mut f32) {
+pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
     unsafe { asm!("NOP", options(pure, nomem)); };
     ::core::hint::black_box(f5(x, y));
-    ::core::hint::black_box((dx, dy));
+    ::core::hint::black_box((dx_0, dy_0));
 }
 fn main() {}