about summary refs log tree commit diff
path: root/tests/codegen-llvm/autodiff
diff options
context:
space:
mode:
Diffstat (limited to 'tests/codegen-llvm/autodiff')
-rw-r--r--tests/codegen-llvm/autodiff/batched.rs116
-rw-r--r--tests/codegen-llvm/autodiff/generic.rs42
-rw-r--r--tests/codegen-llvm/autodiff/identical_fnc.rs45
-rw-r--r--tests/codegen-llvm/autodiff/inline.rs23
-rw-r--r--tests/codegen-llvm/autodiff/scalar.rs33
-rw-r--r--tests/codegen-llvm/autodiff/sret.rs45
6 files changed, 304 insertions, 0 deletions
diff --git a/tests/codegen-llvm/autodiff/batched.rs b/tests/codegen-llvm/autodiff/batched.rs
new file mode 100644
index 00000000000..d27aed50e6c
--- /dev/null
+++ b/tests/codegen-llvm/autodiff/batched.rs
@@ -0,0 +1,116 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+//
+// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many
+// breakages. One benefit is that we match the IR generated by Enzyme only after running it
+// through LLVM's O3 pipeline, which will remove most of the noise.
+// However, our integration test could also be affected by changes in how rustc lowers MIR into
+// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
+// reduce this test to only match the first lines and the ret instructions.
+
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_forward;
+
+#[autodiff_forward(d_square3, Dual, DualOnly)]
+#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
+#[autodiff_forward(d_square1, 4, Dual, Dual)]
+#[no_mangle]
+fn square(x: &f32) -> f32 {
+    x * x
+}
+
+// d_sqaure2
+// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
+// CHECK-NEXT: start:
+// CHECK-NEXT:   %0 = extractvalue [4 x ptr] %"x'", 0
+// CHECK-NEXT:   %"_2'ipl" = load float, ptr %0, align 4
+// CHECK-NEXT:   %1 = extractvalue [4 x ptr] %"x'", 1
+// CHECK-NEXT:   %"_2'ipl1" = load float, ptr %1, align 4
+// CHECK-NEXT:   %2 = extractvalue [4 x ptr] %"x'", 2
+// CHECK-NEXT:   %"_2'ipl2" = load float, ptr %2, align 4
+// CHECK-NEXT:   %3 = extractvalue [4 x ptr] %"x'", 3
+// CHECK-NEXT:   %"_2'ipl3" = load float, ptr %3, align 4
+// CHECK-NEXT:   %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
+// CHECK-NEXT:   %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
+// CHECK-NEXT:   %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
+// CHECK-NEXT:   %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
+// CHECK-NEXT:   %8 = fadd fast <4 x float> %7, %7
+// CHECK-NEXT:   %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
+// CHECK-NEXT:   %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:   %11 = fmul fast <4 x float> %8, %10
+// CHECK-NEXT:   %12 = extractelement <4 x float> %11, i64 0
+// CHECK-NEXT:   %13 = insertvalue [4 x float] undef, float %12, 0
+// CHECK-NEXT:   %14 = extractelement <4 x float> %11, i64 1
+// CHECK-NEXT:   %15 = insertvalue [4 x float] %13, float %14, 1
+// CHECK-NEXT:   %16 = extractelement <4 x float> %11, i64 2
+// CHECK-NEXT:   %17 = insertvalue [4 x float] %15, float %16, 2
+// CHECK-NEXT:   %18 = extractelement <4 x float> %11, i64 3
+// CHECK-NEXT:   %19 = insertvalue [4 x float] %17, float %18, 3
+// CHECK-NEXT:   ret [4 x float] %19
+// CHECK-NEXT: }
+
+// d_square3, the extra float is the original return value (x * x)
+// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
+// CHECK-NEXT: start:
+// CHECK-NEXT:   %0 = extractvalue [4 x ptr] %"x'", 0
+// CHECK-NEXT:   %"_2'ipl" = load float, ptr %0, align 4
+// CHECK-NEXT:   %1 = extractvalue [4 x ptr] %"x'", 1
+// CHECK-NEXT:   %"_2'ipl1" = load float, ptr %1, align 4
+// CHECK-NEXT:   %2 = extractvalue [4 x ptr] %"x'", 2
+// CHECK-NEXT:   %"_2'ipl2" = load float, ptr %2, align 4
+// CHECK-NEXT:   %3 = extractvalue [4 x ptr] %"x'", 3
+// CHECK-NEXT:   %"_2'ipl3" = load float, ptr %3, align 4
+// CHECK-NEXT:   %_0 = fmul float %x.0.val, %x.0.val
+// CHECK-NEXT:   %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
+// CHECK-NEXT:   %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
+// CHECK-NEXT:   %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
+// CHECK-NEXT:   %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
+// CHECK-NEXT:   %8 = fadd fast <4 x float> %7, %7
+// CHECK-NEXT:   %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
+// CHECK-NEXT:   %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:   %11 = fmul fast <4 x float> %8, %10
+// CHECK-NEXT:   %12 = extractelement <4 x float> %11, i64 0
+// CHECK-NEXT:   %13 = insertvalue [4 x float] undef, float %12, 0
+// CHECK-NEXT:   %14 = extractelement <4 x float> %11, i64 1
+// CHECK-NEXT:   %15 = insertvalue [4 x float] %13, float %14, 1
+// CHECK-NEXT:   %16 = extractelement <4 x float> %11, i64 2
+// CHECK-NEXT:   %17 = insertvalue [4 x float] %15, float %16, 2
+// CHECK-NEXT:   %18 = extractelement <4 x float> %11, i64 3
+// CHECK-NEXT:   %19 = insertvalue [4 x float] %17, float %18, 3
+// CHECK-NEXT:   %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
+// CHECK-NEXT:   %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
+// CHECK-NEXT:   ret { float, [4 x float] } %21
+// CHECK-NEXT: }
+
+fn main() {
+    let x = std::hint::black_box(3.0);
+    let output = square(&x);
+    dbg!(&output);
+    assert_eq!(9.0, output);
+    dbg!(square(&x));
+
+    let mut df_dx1 = 1.0;
+    let mut df_dx2 = 2.0;
+    let mut df_dx3 = 3.0;
+    let mut df_dx4 = 0.0;
+    let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
+    dbg!(o1, o2, o3, o4);
+    let [output2, o1, o2, o3, o4] =
+        d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
+    dbg!(o1, o2, o3, o4);
+    assert_eq!(output, output2);
+    assert!((6.0 - o1).abs() < 1e-10);
+    assert!((12.0 - o2).abs() < 1e-10);
+    assert!((18.0 - o3).abs() < 1e-10);
+    assert!((0.0 - o4).abs() < 1e-10);
+    assert_eq!(1.0, df_dx1);
+    assert_eq!(2.0, df_dx2);
+    assert_eq!(3.0, df_dx3);
+    assert_eq!(0.0, df_dx4);
+    assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1);
+    assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2);
+    assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3);
+    assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4);
+}
diff --git a/tests/codegen-llvm/autodiff/generic.rs b/tests/codegen-llvm/autodiff/generic.rs
new file mode 100644
index 00000000000..2f674079be0
--- /dev/null
+++ b/tests/codegen-llvm/autodiff/generic.rs
@@ -0,0 +1,42 @@
+//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+#[autodiff_reverse(d_square, Duplicated, Active)]
+fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
+    *x * *x
+}
+
+// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
+//
+// CHECK: ; generic::square
+// CHECK-NEXT: ; Function Attrs:
+// CHECK-NEXT: define internal {{.*}} double
+// CHECK-NEXT: start:
+// CHECK-NOT: ret
+// CHECK: fmul double
+
+// Ensure that `d_square::<f32>` code is generated
+//
+// CHECK: ; generic::square
+// CHECK-NEXT: ; Function Attrs: {{.*}}
+// CHECK-NEXT: define internal {{.*}} float
+// CHECK-NEXT: start:
+// CHECK-NOT: ret
+// CHECK: fmul float
+
+fn main() {
+    let xf32: f32 = std::hint::black_box(3.0);
+    let xf64: f64 = std::hint::black_box(3.0);
+
+    let outputf32 = square::<f32>(&xf32);
+    assert_eq!(9.0, outputf32);
+
+    let mut df_dxf64: f64 = std::hint::black_box(0.0);
+
+    let output_f64 = d_square::<f64>(&xf64, &mut df_dxf64, 1.0);
+    assert_eq!(6.0, df_dxf64);
+}
diff --git a/tests/codegen-llvm/autodiff/identical_fnc.rs b/tests/codegen-llvm/autodiff/identical_fnc.rs
new file mode 100644
index 00000000000..1c25b3d09ab
--- /dev/null
+++ b/tests/codegen-llvm/autodiff/identical_fnc.rs
@@ -0,0 +1,45 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+//
+// Each autodiff invocation creates a new placeholder function, which we will replace on llvm-ir
+// level. If a user tries to differentiate two identical functions within the same compilation unit,
+// then LLVM might merge them in release mode before AD. In that case we can't rewrite one of the
+// merged placeholder function anymore, and compilation would fail. We prevent this by disabling
+// LLVM's merge_function pass before AD. Here we implicetely test that our solution keeps working.
+// We also explicetly test that we keep running merge_function after AD, by checking for two
+// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+#[autodiff_reverse(d_square, Duplicated, Active)]
+fn square(x: &f64) -> f64 {
+    x * x
+}
+
+#[autodiff_reverse(d_square2, Duplicated, Active)]
+fn square2(x: &f64) -> f64 {
+    x * x
+}
+
+// CHECK:; identical_fnc::main
+// CHECK-NEXT:; Function Attrs:
+// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
+// CHECK-NEXT:start:
+// CHECK-NOT:br
+// CHECK-NOT:ret
+// CHECK:; call identical_fnc::d_square
+// CHECK-NEXT:  call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
+// CHECK-NEXT:; call identical_fnc::d_square
+// CHECK-NEXT:  call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
+
+fn main() {
+    let x = std::hint::black_box(3.0);
+    let mut dx1 = std::hint::black_box(1.0);
+    let mut dx2 = std::hint::black_box(1.0);
+    let _ = d_square(&x, &mut dx1, 1.0);
+    let _ = d_square2(&x, &mut dx2, 1.0);
+    assert_eq!(dx1, 6.0);
+    assert_eq!(dx2, 6.0);
+}
diff --git a/tests/codegen-llvm/autodiff/inline.rs b/tests/codegen-llvm/autodiff/inline.rs
new file mode 100644
index 00000000000..65bed170207
--- /dev/null
+++ b/tests/codegen-llvm/autodiff/inline.rs
@@ -0,0 +1,23 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat -Zautodiff=NoPostopt
+//@ no-prefer-dynamic
+//@ needs-enzyme
+
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+#[autodiff_reverse(d_square, Duplicated, Active)]
+fn square(x: &f64) -> f64 {
+    x * x
+}
+
+// CHECK: ; inline::d_square
+// CHECK-NEXT: ; Function Attrs: alwaysinline
+// CHECK-NOT: noinline
+// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE
+fn main() {
+    let x = std::hint::black_box(3.0);
+    let mut dx1 = std::hint::black_box(1.0);
+    let _ = d_square(&x, &mut dx1, 1.0);
+    assert_eq!(dx1, 6.0);
+}
diff --git a/tests/codegen-llvm/autodiff/scalar.rs b/tests/codegen-llvm/autodiff/scalar.rs
new file mode 100644
index 00000000000..096b4209e84
--- /dev/null
+++ b/tests/codegen-llvm/autodiff/scalar.rs
@@ -0,0 +1,33 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+#[autodiff_reverse(d_square, Duplicated, Active)]
+#[no_mangle]
+fn square(x: &f64) -> f64 {
+    x * x
+}
+
+// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
+// CHECK-NEXT:invertstart:
+// CHECK-NEXT:  %_0 = fmul double %x.0.val, %x.0.val
+// CHECK-NEXT:  %0 = fadd fast double %x.0.val, %x.0.val
+// CHECK-NEXT:  %1 = load double, ptr %"x'", align 8
+// CHECK-NEXT:  %2 = fadd fast double %1, %0
+// CHECK-NEXT:  store double %2, ptr %"x'", align 8
+// CHECK-NEXT:  ret double %_0
+// CHECK-NEXT:}
+
+fn main() {
+    let x = std::hint::black_box(3.0);
+    let output = square(&x);
+    assert_eq!(9.0, output);
+
+    let mut df_dx = 0.0;
+    let output_ = d_square(&x, &mut df_dx, 1.0);
+    assert_eq!(output, output_);
+    assert_eq!(6.0, df_dx);
+}
diff --git a/tests/codegen-llvm/autodiff/sret.rs b/tests/codegen-llvm/autodiff/sret.rs
new file mode 100644
index 00000000000..d2fa85e3e37
--- /dev/null
+++ b/tests/codegen-llvm/autodiff/sret.rs
@@ -0,0 +1,45 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+
+// This test is almost identical to the scalar.rs one,
+// but we intentionally add a few more floats.
+// `df` would ret `{ f64, f32, f32 }`, but is lowered as an sret.
+// We therefore use this test to verify some of our sret handling.
+
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+#[no_mangle]
+#[autodiff_reverse(df, Active, Active, Active)]
+fn primal(x: f32, y: f32) -> f64 {
+    (x * x * y) as f64
+}
+
+// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y)
+// CHECK-NEXT:start:
+// CHECK-NEXT:  %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y)
+// CHECK-NEXT:  %.elt = extractvalue { double, float, float } %0, 0
+// CHECK-NEXT:  store double %.elt, ptr %_0, align 8
+// CHECK-NEXT:  %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8
+// CHECK-NEXT:  %.elt2 = extractvalue { double, float, float } %0, 1
+// CHECK-NEXT:  store float %.elt2, ptr %_0.repack1, align 8
+// CHECK-NEXT:  %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12
+// CHECK-NEXT:  %.elt4 = extractvalue { double, float, float } %0, 2
+// CHECK-NEXT:  store float %.elt4, ptr %_0.repack3, align 4
+// CHECK-NEXT:  ret void
+// CHECK-NEXT:}
+
+fn main() {
+    let x = std::hint::black_box(3.0);
+    let y = std::hint::black_box(2.5);
+    let scalar = std::hint::black_box(1.0);
+    let (r1, r2, r3) = df(x, y, scalar);
+    // 3*3*1.5 = 22.5
+    assert_eq!(r1, 22.5);
+    // 2*x*y = 2*3*2.5 = 15.0
+    assert_eq!(r2, 15.0);
+    // x*x*1 = 3*3 = 9
+    assert_eq!(r3, 9.0);
+}