diff options
| author | Marcelo DomÃnguez <dmmarcelo27@gmail.com> | 2025-08-14 15:42:14 +0000 |
|---|---|---|
| committer | Marcelo DomÃnguez <dmmarcelo27@gmail.com> | 2025-08-14 18:33:43 +0000 |
| commit | cdd4118204e48227fae6a09321b77c7c4a39a186 (patch) | |
| tree | 7f14b5d864f5ece2bc50fb911c406988fa08950e /tests/codegen-llvm/autodiff | |
| parent | e1d79b9aad9cfcfcf23c68ec8625411819d4e3f7 (diff) | |
| download | rust-cdd4118204e48227fae6a09321b77c7c4a39a186.tar.gz rust-cdd4118204e48227fae6a09321b77c7c4a39a186.zip | |
Update autodiff tests for the new intrinsics impl
Diffstat (limited to 'tests/codegen-llvm/autodiff')
| -rw-r--r-- | tests/codegen-llvm/autodiff/autodiffv2.rs | 114 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/batched.rs | 71 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/generic.rs | 17 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/identical_fnc.rs | 10 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/inline.rs | 23 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/scalar.rs | 3 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/sret.rs | 28 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/trait.rs | 30 |
8 files changed, 207 insertions, 89 deletions
diff --git a/tests/codegen-llvm/autodiff/autodiffv2.rs b/tests/codegen-llvm/autodiff/autodiffv2.rs new file mode 100644 index 00000000000..85aed6a183b --- /dev/null +++ b/tests/codegen-llvm/autodiff/autodiffv2.rs @@ -0,0 +1,114 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme +// +// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many +// breakages. One benefit is that we match the IR generated by Enzyme only after running it +// through LLVM's O3 pipeline, which will remove most of the noise. +// However, our integration test could also be affected by changes in how rustc lowers MIR into +// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should +// reduce this test to only match the first lines and the ret instructions. +// +// The function tested here has 4 inputs and 5 outputs, so we could either call forward-mode +// autodiff 4 times, or reverse mode 5 times. Since a forward-mode call is usually faster than +// reverse mode, we prefer it here. This file also tests a new optimization (batch mode), which +// allows us to call forward-mode autodiff only once, and get all 5 outputs in a single call. +// +// We support 2 different batch modes. `d_square2` has the same interface as scalar forward-mode, +// but each shadow argument is `width` times larger (thus 16 and 20 elements here). +// `d_square3` instead takes `width` (4) shadow arguments, which are all the same size as the +// original function arguments. +// +// FIXME(autodiff): We currently can't test `d_square1` and `d_square3` in the same file, since they +// generate the same dummy functions which get merged by LLVM, breaking pieces of our pipeline which +// try to rewrite the dummy functions later. We should consider to change to pure declarations both +// in our frontend and in the llvm backend to avoid these issues. + +#![feature(autodiff)] + +use std::autodiff::autodiff_forward; + +// CHECK: ; +#[no_mangle] +//#[autodiff(d_square1, Forward, Dual, Dual)] +#[autodiff_forward(d_square2, 4, Dualv, Dualv)] +#[autodiff_forward(d_square3, 4, Dual, Dual)] +fn square(x: &[f32], y: &mut [f32]) { + assert!(x.len() >= 4); + assert!(y.len() >= 5); + y[0] = 4.3 * x[0] + 1.2 * x[1] + 3.4 * x[2] + 2.1 * x[3]; + y[1] = 2.3 * x[0] + 4.5 * x[1] + 1.7 * x[2] + 6.4 * x[3]; + y[2] = 1.1 * x[0] + 3.3 * x[1] + 2.5 * x[2] + 4.7 * x[3]; + y[3] = 5.2 * x[0] + 1.4 * x[1] + 2.6 * x[2] + 3.8 * x[3]; + y[4] = 1.0 * x[0] + 2.0 * x[1] + 3.0 * x[2] + 4.0 * x[3]; +} + +fn main() { + let x1 = std::hint::black_box(vec![0.0, 1.0, 2.0, 3.0]); + + let dx1 = std::hint::black_box(vec![1.0; 12]); + + let z1 = std::hint::black_box(vec![1.0, 0.0, 0.0, 0.0]); + let z2 = std::hint::black_box(vec![0.0, 1.0, 0.0, 0.0]); + let z3 = std::hint::black_box(vec![0.0, 0.0, 1.0, 0.0]); + let z4 = std::hint::black_box(vec![0.0, 0.0, 0.0, 1.0]); + + let z5 = std::hint::black_box(vec![ + 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + ]); + + let mut y1 = std::hint::black_box(vec![0.0; 5]); + let mut y2 = std::hint::black_box(vec![0.0; 5]); + let mut y3 = std::hint::black_box(vec![0.0; 5]); + let mut y4 = std::hint::black_box(vec![0.0; 5]); + + let mut y5 = std::hint::black_box(vec![0.0; 5]); + + let mut y6 = std::hint::black_box(vec![0.0; 5]); + + let mut dy1_1 = std::hint::black_box(vec![0.0; 5]); + let mut dy1_2 = std::hint::black_box(vec![0.0; 5]); + let mut dy1_3 = std::hint::black_box(vec![0.0; 5]); + let mut dy1_4 = std::hint::black_box(vec![0.0; 5]); + + let mut dy2 = std::hint::black_box(vec![0.0; 20]); + + let mut dy3_1 = std::hint::black_box(vec![0.0; 5]); + let mut dy3_2 = std::hint::black_box(vec![0.0; 5]); + let mut dy3_3 = std::hint::black_box(vec![0.0; 5]); + let mut dy3_4 = std::hint::black_box(vec![0.0; 5]); + + // scalar. + //d_square1(&x1, &z1, &mut y1, &mut dy1_1); + //d_square1(&x1, &z2, &mut y2, &mut dy1_2); + //d_square1(&x1, &z3, &mut y3, &mut dy1_3); + //d_square1(&x1, &z4, &mut y4, &mut dy1_4); + + // assert y1 == y2 == y3 == y4 + //for i in 0..5 { + // assert_eq!(y1[i], y2[i]); + // assert_eq!(y1[i], y3[i]); + // assert_eq!(y1[i], y4[i]); + //} + + // batch mode A) + d_square2(&x1, &z5, &mut y5, &mut dy2); + + // assert y1 == y2 == y3 == y4 == y5 + //for i in 0..5 { + // assert_eq!(y1[i], y5[i]); + //} + + // batch mode B) + d_square3(&x1, &z1, &z2, &z3, &z4, &mut y6, &mut dy3_1, &mut dy3_2, &mut dy3_3, &mut dy3_4); + for i in 0..5 { + assert_eq!(y5[i], y6[i]); + } + + for i in 0..5 { + assert_eq!(dy2[0..5][i], dy3_1[i]); + assert_eq!(dy2[5..10][i], dy3_2[i]); + assert_eq!(dy2[10..15][i], dy3_3[i]); + assert_eq!(dy2[15..20][i], dy3_4[i]); + } +} diff --git a/tests/codegen-llvm/autodiff/batched.rs b/tests/codegen-llvm/autodiff/batched.rs index d27aed50e6c..306a6ed9d1f 100644 --- a/tests/codegen-llvm/autodiff/batched.rs +++ b/tests/codegen-llvm/autodiff/batched.rs @@ -17,11 +17,12 @@ use std::autodiff::autodiff_forward; #[autodiff_forward(d_square2, 4, Dual, DualOnly)] #[autodiff_forward(d_square1, 4, Dual, Dual)] #[no_mangle] +#[inline(never)] fn square(x: &f32) -> f32 { x * x } -// d_sqaure2 +// d_square2 // CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'") // CHECK-NEXT: start: // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 @@ -32,24 +33,20 @@ fn square(x: &f32) -> f32 { // CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 -// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 -// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 -// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 -// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 -// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 -// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 -// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer -// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 -// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 -// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 -// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 -// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 -// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 -// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 -// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 -// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 -// CHECK-NEXT: ret [4 x float] %19 -// CHECK-NEXT: } +// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00 +// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val +// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0 +// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00 +// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val +// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1 +// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00 +// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val +// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2 +// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00 +// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val +// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3 +// CHECK-NEXT: ret [4 x float] %15 +// CHECK-NEXT: } // d_square3, the extra float is the original return value (x * x) // CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'") @@ -63,26 +60,22 @@ fn square(x: &f32) -> f32 { // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 // CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val -// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 -// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 -// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 -// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 -// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 -// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 -// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer -// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 -// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 -// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 -// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 -// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 -// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 -// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 -// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 -// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 -// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0 -// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1 -// CHECK-NEXT: ret { float, [4 x float] } %21 -// CHECK-NEXT: } +// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00 +// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val +// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0 +// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00 +// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val +// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1 +// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00 +// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val +// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2 +// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00 +// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val +// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3 +// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0 +// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1 +// CHECK-NEXT: ret { float, [4 x float] } %17 +// CHECK-NEXT: } fn main() { let x = std::hint::black_box(3.0); diff --git a/tests/codegen-llvm/autodiff/generic.rs b/tests/codegen-llvm/autodiff/generic.rs index 2f674079be0..6f56460a2b6 100644 --- a/tests/codegen-llvm/autodiff/generic.rs +++ b/tests/codegen-llvm/autodiff/generic.rs @@ -6,27 +6,28 @@ use std::autodiff::autodiff_reverse; #[autodiff_reverse(d_square, Duplicated, Active)] +#[inline(never)] fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x } -// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called +// Ensure that `d_square::<f32>` code is generated // // CHECK: ; generic::square -// CHECK-NEXT: ; Function Attrs: -// CHECK-NEXT: define internal {{.*}} double +// CHECK-NEXT: ; Function Attrs: {{.*}} +// CHECK-NEXT: define internal {{.*}} float // CHECK-NEXT: start: // CHECK-NOT: ret -// CHECK: fmul double +// CHECK: fmul float -// Ensure that `d_square::<f32>` code is generated +// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called // // CHECK: ; generic::square -// CHECK-NEXT: ; Function Attrs: {{.*}} -// CHECK-NEXT: define internal {{.*}} float +// CHECK-NEXT: ; Function Attrs: +// CHECK-NEXT: define internal {{.*}} double // CHECK-NEXT: start: // CHECK-NOT: ret -// CHECK: fmul float +// CHECK: fmul double fn main() { let xf32: f32 = std::hint::black_box(3.0); diff --git a/tests/codegen-llvm/autodiff/identical_fnc.rs b/tests/codegen-llvm/autodiff/identical_fnc.rs index 1c25b3d09ab..6066f8cb34f 100644 --- a/tests/codegen-llvm/autodiff/identical_fnc.rs +++ b/tests/codegen-llvm/autodiff/identical_fnc.rs @@ -14,25 +14,27 @@ use std::autodiff::autodiff_reverse; #[autodiff_reverse(d_square, Duplicated, Active)] +#[inline(never)] fn square(x: &f64) -> f64 { x * x } #[autodiff_reverse(d_square2, Duplicated, Active)] +#[inline(never)] fn square2(x: &f64) -> f64 { x * x } // CHECK:; identical_fnc::main // CHECK-NEXT:; Function Attrs: -// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E() +// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17h6009e4f751bf9407E() // CHECK-NEXT:start: // CHECK-NOT:br // CHECK-NOT:ret // CHECK:; call identical_fnc::d_square -// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1) -// CHECK-NEXT:; call identical_fnc::d_square -// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2) +// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1) +// CHECK:; call identical_fnc::d_square +// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2) fn main() { let x = std::hint::black_box(3.0); diff --git a/tests/codegen-llvm/autodiff/inline.rs b/tests/codegen-llvm/autodiff/inline.rs deleted file mode 100644 index 65bed170207..00000000000 --- a/tests/codegen-llvm/autodiff/inline.rs +++ /dev/null @@ -1,23 +0,0 @@ -//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt -//@ no-prefer-dynamic -//@ needs-enzyme - -#![feature(autodiff)] - -use std::autodiff::autodiff_reverse; - -#[autodiff_reverse(d_square, Duplicated, Active)] -fn square(x: &f64) -> f64 { - x * x -} - -// CHECK: ; inline::d_square -// CHECK-NEXT: ; Function Attrs: alwaysinline -// CHECK-NOT: noinline -// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE -fn main() { - let x = std::hint::black_box(3.0); - let mut dx1 = std::hint::black_box(1.0); - let _ = d_square(&x, &mut dx1, 1.0); - assert_eq!(dx1, 6.0); -} diff --git a/tests/codegen-llvm/autodiff/scalar.rs b/tests/codegen-llvm/autodiff/scalar.rs index 096b4209e84..55b989f920d 100644 --- a/tests/codegen-llvm/autodiff/scalar.rs +++ b/tests/codegen-llvm/autodiff/scalar.rs @@ -7,11 +7,12 @@ use std::autodiff::autodiff_reverse; #[autodiff_reverse(d_square, Duplicated, Active)] #[no_mangle] +#[inline(never)] fn square(x: &f64) -> f64 { x * x } -// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'" +// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nonnull align 8 captures(none) %"x'") // CHECK-NEXT:invertstart: // CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val // CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val diff --git a/tests/codegen-llvm/autodiff/sret.rs b/tests/codegen-llvm/autodiff/sret.rs index d2fa85e3e37..dbc253ce894 100644 --- a/tests/codegen-llvm/autodiff/sret.rs +++ b/tests/codegen-llvm/autodiff/sret.rs @@ -13,30 +13,30 @@ use std::autodiff::autodiff_reverse; #[no_mangle] #[autodiff_reverse(df, Active, Active, Active)] +#[inline(never)] fn primal(x: f32, y: f32) -> f64 { (x * x * y) as f64 } -// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y) -// CHECK-NEXT:start: -// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y) -// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0 -// CHECK-NEXT: store double %.elt, ptr %_0, align 8 -// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8 -// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1 -// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8 -// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12 -// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2 -// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4 -// CHECK-NEXT: ret void -// CHECK-NEXT:} +// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y) +// CHECK-NEXT: invertstart: +// CHECK-NEXT: %_4 = fmul float %x, %x +// CHECK-NEXT: %_3 = fmul float %_4, %y +// CHECK-NEXT: %_0 = fpext float %_3 to double +// CHECK-NEXT: %0 = fadd fast float %y, %y +// CHECK-NEXT: %1 = fmul fast float %0, %x +// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0 +// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1 +// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2 +// CHECK-NEXT: ret { double, float, float } %4 +// CHECK-NEXT: } fn main() { let x = std::hint::black_box(3.0); let y = std::hint::black_box(2.5); let scalar = std::hint::black_box(1.0); let (r1, r2, r3) = df(x, y, scalar); - // 3*3*1.5 = 22.5 + // 3*3*2.5 = 22.5 assert_eq!(r1, 22.5); // 2*x*y = 2*3*2.5 = 15.0 assert_eq!(r2, 15.0); diff --git a/tests/codegen-llvm/autodiff/trait.rs b/tests/codegen-llvm/autodiff/trait.rs new file mode 100644 index 00000000000..701f3a9e843 --- /dev/null +++ b/tests/codegen-llvm/autodiff/trait.rs @@ -0,0 +1,30 @@ +//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// Just check it does not crash for now +// CHECK: ; +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +struct Foo { + a: f64, +} + +trait MyTrait { + fn f(&self, x: f64) -> f64; + fn df(&self, x: f64, seed: f64) -> (f64, f64); +} + +impl MyTrait for Foo { + #[autodiff_reverse(df, Const, Active, Active)] + fn f(&self, x: f64) -> f64 { + self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) + } +} + +fn main() { + let foo = Foo { a: 3.0f64 }; + dbg!(foo.df(1.0, 1.0)); +} |
