about summary refs log tree commit diff
path: root/tests/codegen-llvm/autodiff/batched.rs
blob: 0ff6134bc07d58f0fd1ef18adc2fe3c680bf7b44 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//@ compile-flags: -Zautodiff=Enable,NoTT,NoPostopt -C opt-level=3  -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
//
// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many
// breakages. One benefit is that we match the IR generated by Enzyme only after running it
// through LLVM's O3 pipeline, which will remove most of the noise.
// However, our integration test could also be affected by changes in how rustc lowers MIR into
// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
// reduce this test to only match the first lines and the ret instructions.

#![feature(autodiff)]

use std::autodiff::autodiff_forward;

// Primal function under test. The three `autodiff_forward` attributes generate
// forward-mode derivatives of it, all of which are called from `main`:
// - `d_square3`: takes a single tangent input and its result is compared as a single value.
// - `d_square2`: the `4` matches the four tangent inputs passed in `main`; the result is
//   destructured there into four values.
// - `d_square1`: also takes four tangent inputs; the result is destructured into the
//   original return value followed by four values.
//
// NOTE(review): the CHECK lines in this file match IR value names such as `%_2` and `%_0`,
// which presumably derive from how this exact body is lowered - avoid restructuring the
// body without updating them.
#[autodiff_forward(d_square3, Dual, DualOnly)]
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
#[autodiff_forward(d_square1, 4, Dual, Dual)]
#[no_mangle]
#[inline(never)]
fn square(x: &f32) -> f32 {
    x * x
}

// d_square2
// CHECK: define internal [4 x float] @fwddiffe4square(ptr noalias noundef readonly align 4 captures(none) dereferenceable(4) %x, [4 x ptr] %"x'")
// CHECK-NEXT: start:
// CHECK-NEXT:   %0 = extractvalue [4 x ptr] %"x'", 0
// CHECK-NEXT:   %"_2'ipl" = load float, ptr %0, align 4
// CHECK-NEXT:   %1 = extractvalue [4 x ptr] %"x'", 1
// CHECK-NEXT:   %"_2'ipl1" = load float, ptr %1, align 4
// CHECK-NEXT:   %2 = extractvalue [4 x ptr] %"x'", 2
// CHECK-NEXT:   %"_2'ipl2" = load float, ptr %2, align 4
// CHECK-NEXT:   %3 = extractvalue [4 x ptr] %"x'", 3
// CHECK-NEXT:   %"_2'ipl3" = load float, ptr %3, align 4
// CHECK-NEXT:   %_2 = load float, ptr %x, align 4
// CHECK-NEXT:   %4 = fmul fast float %"_2'ipl", %_2
// CHECK-NEXT:   %5 = fmul fast float %"_2'ipl1", %_2
// CHECK-NEXT:   %6 = fmul fast float %"_2'ipl2", %_2
// CHECK-NEXT:   %7 = fmul fast float %"_2'ipl3", %_2
// CHECK-NEXT:   %8 = fmul fast float %"_2'ipl", %_2
// CHECK-NEXT:   %9 = fmul fast float %"_2'ipl1", %_2
// CHECK-NEXT:   %10 = fmul fast float %"_2'ipl2", %_2
// CHECK-NEXT:   %11 = fmul fast float %"_2'ipl3", %_2
// CHECK-NEXT:   %12 = fadd fast float %4, %8
// CHECK-NEXT:   %13 = insertvalue [4 x float] undef, float %12, 0
// CHECK-NEXT:   %14 = fadd fast float %5, %9
// CHECK-NEXT:   %15 = insertvalue [4 x float] %13, float %14, 1
// CHECK-NEXT:   %16 = fadd fast float %6, %10
// CHECK-NEXT:   %17 = insertvalue [4 x float] %15, float %16, 2
// CHECK-NEXT:   %18 = fadd fast float %7, %11
// CHECK-NEXT:   %19 = insertvalue [4 x float] %17, float %18, 3
// CHECK-NEXT:   ret [4 x float] %19
// CHECK-NEXT:   }

// d_square1, the extra float is the original return value (x * x)
// CHECK: define internal { float, [4 x float] } @fwddiffe4square.1(ptr noalias noundef readonly align 4 captures(none) dereferenceable(4) %x, [4 x ptr] %"x'")
// CHECK-NEXT: start:
// CHECK-NEXT:   %0 = extractvalue [4 x ptr] %"x'", 0
// CHECK-NEXT:   %"_2'ipl" = load float, ptr %0, align 4
// CHECK-NEXT:   %1 = extractvalue [4 x ptr] %"x'", 1
// CHECK-NEXT:   %"_2'ipl1" = load float, ptr %1, align 4
// CHECK-NEXT:   %2 = extractvalue [4 x ptr] %"x'", 2
// CHECK-NEXT:   %"_2'ipl2" = load float, ptr %2, align 4
// CHECK-NEXT:   %3 = extractvalue [4 x ptr] %"x'", 3
// CHECK-NEXT:   %"_2'ipl3" = load float, ptr %3, align 4
// CHECK-NEXT:   %_2 = load float, ptr %x, align 4
// CHECK-NEXT:   %_0 = fmul float %_2, %_2
// CHECK-NEXT:   %4 = fmul fast float %"_2'ipl", %_2
// CHECK-NEXT:   %5 = fmul fast float %"_2'ipl1", %_2
// CHECK-NEXT:   %6 = fmul fast float %"_2'ipl2", %_2
// CHECK-NEXT:   %7 = fmul fast float %"_2'ipl3", %_2
// CHECK-NEXT:   %8 = fmul fast float %"_2'ipl", %_2
// CHECK-NEXT:   %9 = fmul fast float %"_2'ipl1", %_2
// CHECK-NEXT:   %10 = fmul fast float %"_2'ipl2", %_2
// CHECK-NEXT:   %11 = fmul fast float %"_2'ipl3", %_2
// CHECK-NEXT:   %12 = fadd fast float %4, %8
// CHECK-NEXT:   %13 = insertvalue [4 x float] undef, float %12, 0
// CHECK-NEXT:   %14 = fadd fast float %5, %9
// CHECK-NEXT:   %15 = insertvalue [4 x float] %13, float %14, 1
// CHECK-NEXT:   %16 = fadd fast float %6, %10
// CHECK-NEXT:   %17 = insertvalue [4 x float] %15, float %16, 2
// CHECK-NEXT:   %18 = fadd fast float %7, %11
// CHECK-NEXT:   %19 = insertvalue [4 x float] %17, float %18, 3
// CHECK-NEXT:   %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
// CHECK-NEXT:   %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
// CHECK-NEXT:   ret { float, [4 x float] } %21
// CHECK-NEXT:   }

// Calls `square` and each of its three generated forward-mode derivatives and states the
// values they are expected to produce for x = 3 and tangent seeds 1, 2, 3 and 0.
//
// NOTE(review): as a codegen test this file is presumably only compiled and matched
// against the CHECK lines, not executed - confirm before relying on these assertions.
fn main() {
    let x = std::hint::black_box(3.0);
    let output = square(&x);
    dbg!(&output);
    assert_eq!(9.0, output);
    dbg!(square(&x));

    // One tangent seed per batch lane.
    let mut df_dx1 = 1.0;
    let mut df_dx2 = 2.0;
    let mut df_dx3 = 3.0;
    let mut df_dx4 = 0.0;

    // Batched derivative that returns only the four tangents.
    let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
    dbg!(o1, o2, o3, o4);

    // Batched derivative that returns the original result first, then the four tangents.
    let [output2, o1, o2, o3, o4] =
        d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
    dbg!(o1, o2, o3, o4);
    assert_eq!(output, output2);

    // d(x * x) = 2 * x * dx, so with x = 3 each tangent is 6 * seed.
    assert!((6.0 - o1).abs() < 1e-10);
    assert!((12.0 - o2).abs() < 1e-10);
    assert!((18.0 - o3).abs() < 1e-10);
    assert!((0.0 - o4).abs() < 1e-10);

    // The seeds are expected to be left unchanged by the calls above.
    assert_eq!(1.0, df_dx1);
    assert_eq!(2.0, df_dx2);
    assert_eq!(3.0, df_dx3);
    assert_eq!(0.0, df_dx4);

    // The unbatched derivative computes the same tangent as the matching batch lane, so it
    // is compared against `oN` itself: the factor 2 from d(x * x) is already part of `oN`,
    // and comparing against `2.0 * oN` would only hold for a zero seed.
    assert_eq!(d_square3(&x, &mut df_dx1), o1);
    assert_eq!(d_square3(&x, &mut df_dx2), o2);
    assert_eq!(d_square3(&x, &mut df_dx3), o3);
    assert_eq!(d_square3(&x, &mut df_dx4), o4);
}