about summary refs log tree commit diff
path: root/tests/codegen-llvm/autodiff/generic.rs
blob: 6f56460a2b6d179cdc002bdc2e86f2958794482f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;

// Generic primal function under test. `autodiff_reverse` generates a
// reverse-mode gradient function named `d_square` from it. `Duplicated` pairs
// the `&T` argument with a `&mut T` shadow that receives the gradient.
// `Active` on the return means the caller passes the output seed as the final
// argument (compare the `d_square::<f64>` call in `main`).
#[autodiff_reverse(d_square, Duplicated, Active)]
// Kept out of line, presumably so that each instantiation is emitted as its own
// function definition for the FileCheck directives below to match.
#[inline(never)]
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
    // x * x; dereferencing twice is fine because `T: Copy`.
    *x * *x
}

// Ensure that `square::<f32>` code is generated (it is called directly from `main`)
//
// CHECK: ; generic::square
// CHECK-NEXT: ; Function Attrs: {{.*}}
// CHECK-NEXT: define internal {{.*}} float
// CHECK-NEXT: start:
// CHECK-NOT: ret
// CHECK: fmul float

// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
//
// CHECK: ; generic::square
// CHECK-NEXT: ; Function Attrs:
// CHECK-NEXT: define internal {{.*}} double
// CHECK-NEXT: start:
// CHECK-NOT: ret
// CHECK: fmul double

// Exercises both instantiations of the generic `square`. The `f32` one is used
// only through the primal function. The `f64` one is used only through the
// generated gradient function `d_square`.
fn main() {
    // `black_box` keeps the inputs opaque so the calls below are not
    // constant-folded away at opt-level=3.
    let xf32: f32 = std::hint::black_box(3.0);
    let xf64: f64 = std::hint::black_box(3.0);

    // Primal call for `f32`: 3.0 * 3.0.
    let outputf32 = square::<f32>(&xf32);
    assert_eq!(9.0, outputf32);

    // Shadow for the `Duplicated` argument; the reverse pass writes
    // d(x * x)/dx into it, starting from zero.
    let mut df_dxf64: f64 = std::hint::black_box(0.0);

    // Reverse-mode call for `f64` with seed 1.0 for the `Active` return.
    // `square::<f64>` itself is never called directly from this crate.
    let output_f64 = d_square::<f64>(&xf64, &mut df_dxf64, 1.0);
    // `d_square` also returns the primal value square(3.0).
    assert_eq!(9.0, output_f64);
    // d(x * x)/dx at x = 3.0 is 2 * 3.0.
    assert_eq!(6.0, df_dxf64);
}