diff options
Diffstat (limited to 'tests/codegen-llvm/autodiff')
| -rw-r--r-- | tests/codegen-llvm/autodiff/abi_handling.rs | 4 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/batched.rs | 2 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/scalar.rs | 2 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/sret.rs | 2 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/typetree.rs | 33 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/void_ret.rs | 41 |
6 files changed, 79 insertions, 5 deletions
diff --git a/tests/codegen-llvm/autodiff/abi_handling.rs b/tests/codegen-llvm/autodiff/abi_handling.rs index 454ec698b91..5c8126898a8 100644 --- a/tests/codegen-llvm/autodiff/abi_handling.rs +++ b/tests/codegen-llvm/autodiff/abi_handling.rs @@ -1,7 +1,7 @@ //@ revisions: debug release -//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat -//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@[debug] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=0 -Clto=fat +//@[release] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat //@ no-prefer-dynamic //@ needs-enzyme diff --git a/tests/codegen-llvm/autodiff/batched.rs b/tests/codegen-llvm/autodiff/batched.rs index 306a6ed9d1f..dc82403212f 100644 --- a/tests/codegen-llvm/autodiff/batched.rs +++ b/tests/codegen-llvm/autodiff/batched.rs @@ -1,4 +1,4 @@ -//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat //@ no-prefer-dynamic //@ needs-enzyme // diff --git a/tests/codegen-llvm/autodiff/scalar.rs b/tests/codegen-llvm/autodiff/scalar.rs index 55b989f920d..53672a89230 100644 --- a/tests/codegen-llvm/autodiff/scalar.rs +++ b/tests/codegen-llvm/autodiff/scalar.rs @@ -1,4 +1,4 @@ -//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] diff --git a/tests/codegen-llvm/autodiff/sret.rs b/tests/codegen-llvm/autodiff/sret.rs index dbc253ce894..498cd3fea01 100644 --- a/tests/codegen-llvm/autodiff/sret.rs +++ b/tests/codegen-llvm/autodiff/sret.rs @@ -1,4 +1,4 @@ -//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat //@ no-prefer-dynamic //@ needs-enzyme diff --git a/tests/codegen-llvm/autodiff/typetree.rs b/tests/codegen-llvm/autodiff/typetree.rs new file mode 100644 index 00000000000..1cb0c2fb68b --- /dev/null +++ b/tests/codegen-llvm/autodiff/typetree.rs @@ -0,0 +1,33 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// Test that basic autodiff still works with our TypeTree infrastructure +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_simple, Duplicated, Active)] +#[no_mangle] +#[inline(never)] +fn simple(x: &f64) -> f64 { + 2.0 * x +} + +// CHECK-LABEL: @simple +// CHECK: fmul double + +// The derivative function should be generated normally +// CHECK-LABEL: diffesimple +// CHECK: fadd fast double + +fn main() { + let x = std::hint::black_box(3.0); + let output = simple(&x); + assert_eq!(6.0, output); + + let mut df_dx = 0.0; + let output_ = d_simple(&x, &mut df_dx, 1.0); + assert_eq!(output, output_); + assert_eq!(2.0, df_dx); +} diff --git a/tests/codegen-llvm/autodiff/void_ret.rs b/tests/codegen-llvm/autodiff/void_ret.rs new file mode 100644 index 00000000000..98c6b98eef4 --- /dev/null +++ b/tests/codegen-llvm/autodiff/void_ret.rs @@ -0,0 +1,41 @@ +//@ compile-flags: -Zautodiff=Enable,NoTT,NoPostopt -C no-prepopulate-passes -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +#![feature(autodiff)] +use std::autodiff::*; + +// Usually we would store the return value of the differentiated function. +// However, if the return type is void or an empty struct, +// we don't need to store anything. Verify this, since it caused a bug. + +// CHECK:; void_ret::main +// CHECK-NEXT: ; Function Attrs: +// CHECK-NEXT: define internal +// CHECK-NOT: store {} undef, ptr undef +// CHECK: ret void + +#[autodiff_reverse(bar, Duplicated, Duplicated)] +pub fn foo(r: &[f64; 10], res: &mut f64) { + let mut output = [0.0; 10]; + output[0] = r[0]; + output[1] = r[1] * r[2]; + output[2] = r[4] * r[5]; + output[3] = r[2] * r[6]; + output[4] = r[1] * r[7]; + output[5] = r[2] * r[8]; + output[6] = r[1] * r[9]; + output[7] = r[5] * r[6]; + output[8] = r[5] * r[7]; + output[9] = r[4] * r[8]; + *res = output.iter().sum(); +} +fn main() { + let inputs = Box::new([3.1; 10]); + let mut d_inputs = Box::new([0.0; 10]); + let mut res = Box::new(0.0); + let mut d_res = Box::new(1.0); + + bar(&inputs, &mut d_inputs, &mut res, &mut d_res); + dbg!(&d_inputs); +} |
