about summary refs log tree commit diff
path: root/tests/codegen-llvm/autodiff
diff options
context:
space:
mode:
Diffstat (limited to 'tests/codegen-llvm/autodiff')
-rw-r--r--tests/codegen-llvm/autodiff/autodiffv2.rs114
-rw-r--r--tests/codegen-llvm/autodiff/batched.rs71
-rw-r--r--tests/codegen-llvm/autodiff/generic.rs17
-rw-r--r--tests/codegen-llvm/autodiff/identical_fnc.rs10
-rw-r--r--tests/codegen-llvm/autodiff/inline.rs23
-rw-r--r--tests/codegen-llvm/autodiff/scalar.rs3
-rw-r--r--tests/codegen-llvm/autodiff/sret.rs28
-rw-r--r--tests/codegen-llvm/autodiff/trait.rs30
8 files changed, 207 insertions, 89 deletions
diff --git a/tests/codegen-llvm/autodiff/autodiffv2.rs b/tests/codegen-llvm/autodiff/autodiffv2.rs
new file mode 100644
index 00000000000..85aed6a183b
--- /dev/null
+++ b/tests/codegen-llvm/autodiff/autodiffv2.rs
@@ -0,0 +1,114 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+//
+// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many
+// breakages. One benefit is that we match the IR generated by Enzyme only after running it
+// through LLVM's O3 pipeline, which will remove most of the noise.
+// However, our integration test could also be affected by changes in how rustc lowers MIR into
+// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
+// reduce this test to only match the first lines and the ret instructions.
+//
+// The function tested here has 4 inputs and 5 outputs, so we could either call forward-mode
+// autodiff 4 times, or reverse mode 5 times. Since a forward-mode call is usually faster than
+// reverse mode, we prefer it here. This file also tests a new optimization (batch mode), which
+// allows us to call forward-mode autodiff only once, and get all 5 outputs in a single call.
+//
+// We support 2 different batch modes. `d_square2` has the same interface as scalar forward-mode,
+// but each shadow argument is `width` times larger (thus 16 and 20 elements here).
+// `d_square3` instead takes `width` (4) shadow arguments, which are all the same size as the
+// original function arguments.
+//
+// FIXME(autodiff): We currently can't test `d_square1` and `d_square3` in the same file, since they
+// generate the same dummy functions which get merged by LLVM, breaking pieces of our pipeline which
+// try to rewrite the dummy functions later. We should consider to change to pure declarations both
+// in our frontend and in the llvm backend to avoid these issues.
+
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_forward;
+
+// CHECK: ;
+#[no_mangle]
+//#[autodiff(d_square1, Forward, Dual, Dual)]
+#[autodiff_forward(d_square2, 4, Dualv, Dualv)]
+#[autodiff_forward(d_square3, 4, Dual, Dual)]
+fn square(x: &[f32], y: &mut [f32]) {
+    assert!(x.len() >= 4);
+    assert!(y.len() >= 5);
+    y[0] = 4.3 * x[0] + 1.2 * x[1] + 3.4 * x[2] + 2.1 * x[3];
+    y[1] = 2.3 * x[0] + 4.5 * x[1] + 1.7 * x[2] + 6.4 * x[3];
+    y[2] = 1.1 * x[0] + 3.3 * x[1] + 2.5 * x[2] + 4.7 * x[3];
+    y[3] = 5.2 * x[0] + 1.4 * x[1] + 2.6 * x[2] + 3.8 * x[3];
+    y[4] = 1.0 * x[0] + 2.0 * x[1] + 3.0 * x[2] + 4.0 * x[3];
+}
+
+fn main() {
+    let x1 = std::hint::black_box(vec![0.0, 1.0, 2.0, 3.0]);
+
+    let dx1 = std::hint::black_box(vec![1.0; 12]);
+
+    let z1 = std::hint::black_box(vec![1.0, 0.0, 0.0, 0.0]);
+    let z2 = std::hint::black_box(vec![0.0, 1.0, 0.0, 0.0]);
+    let z3 = std::hint::black_box(vec![0.0, 0.0, 1.0, 0.0]);
+    let z4 = std::hint::black_box(vec![0.0, 0.0, 0.0, 1.0]);
+
+    let z5 = std::hint::black_box(vec![
+        1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
+    ]);
+
+    let mut y1 = std::hint::black_box(vec![0.0; 5]);
+    let mut y2 = std::hint::black_box(vec![0.0; 5]);
+    let mut y3 = std::hint::black_box(vec![0.0; 5]);
+    let mut y4 = std::hint::black_box(vec![0.0; 5]);
+
+    let mut y5 = std::hint::black_box(vec![0.0; 5]);
+
+    let mut y6 = std::hint::black_box(vec![0.0; 5]);
+
+    let mut dy1_1 = std::hint::black_box(vec![0.0; 5]);
+    let mut dy1_2 = std::hint::black_box(vec![0.0; 5]);
+    let mut dy1_3 = std::hint::black_box(vec![0.0; 5]);
+    let mut dy1_4 = std::hint::black_box(vec![0.0; 5]);
+
+    let mut dy2 = std::hint::black_box(vec![0.0; 20]);
+
+    let mut dy3_1 = std::hint::black_box(vec![0.0; 5]);
+    let mut dy3_2 = std::hint::black_box(vec![0.0; 5]);
+    let mut dy3_3 = std::hint::black_box(vec![0.0; 5]);
+    let mut dy3_4 = std::hint::black_box(vec![0.0; 5]);
+
+    // scalar.
+    //d_square1(&x1, &z1, &mut y1, &mut dy1_1);
+    //d_square1(&x1, &z2, &mut y2, &mut dy1_2);
+    //d_square1(&x1, &z3, &mut y3, &mut dy1_3);
+    //d_square1(&x1, &z4, &mut y4, &mut dy1_4);
+
+    // assert y1 == y2 == y3 == y4
+    //for i in 0..5 {
+    //    assert_eq!(y1[i], y2[i]);
+    //    assert_eq!(y1[i], y3[i]);
+    //    assert_eq!(y1[i], y4[i]);
+    //}
+
+    // batch mode A)
+    d_square2(&x1, &z5, &mut y5, &mut dy2);
+
+    // assert y1 == y2 == y3 == y4 == y5
+    //for i in 0..5 {
+    //    assert_eq!(y1[i], y5[i]);
+    //}
+
+    // batch mode B)
+    d_square3(&x1, &z1, &z2, &z3, &z4, &mut y6, &mut dy3_1, &mut dy3_2, &mut dy3_3, &mut dy3_4);
+    for i in 0..5 {
+        assert_eq!(y5[i], y6[i]);
+    }
+
+    for i in 0..5 {
+        assert_eq!(dy2[0..5][i], dy3_1[i]);
+        assert_eq!(dy2[5..10][i], dy3_2[i]);
+        assert_eq!(dy2[10..15][i], dy3_3[i]);
+        assert_eq!(dy2[15..20][i], dy3_4[i]);
+    }
+}
diff --git a/tests/codegen-llvm/autodiff/batched.rs b/tests/codegen-llvm/autodiff/batched.rs
index d27aed50e6c..306a6ed9d1f 100644
--- a/tests/codegen-llvm/autodiff/batched.rs
+++ b/tests/codegen-llvm/autodiff/batched.rs
@@ -17,11 +17,12 @@ use std::autodiff::autodiff_forward;
 #[autodiff_forward(d_square2, 4, Dual, DualOnly)]
 #[autodiff_forward(d_square1, 4, Dual, Dual)]
 #[no_mangle]
+#[inline(never)]
 fn square(x: &f32) -> f32 {
     x * x
 }
 
-// d_sqaure2
+// d_square2
 // CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
 // CHECK-NEXT: start:
 // CHECK-NEXT:   %0 = extractvalue [4 x ptr] %"x'", 0
@@ -32,24 +33,20 @@ fn square(x: &f32) -> f32 {
 // CHECK-NEXT:   %"_2'ipl2" = load float, ptr %2, align 4
 // CHECK-NEXT:   %3 = extractvalue [4 x ptr] %"x'", 3
 // CHECK-NEXT:   %"_2'ipl3" = load float, ptr %3, align 4
-// CHECK-NEXT:   %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
-// CHECK-NEXT:   %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
-// CHECK-NEXT:   %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
-// CHECK-NEXT:   %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
-// CHECK-NEXT:   %8 = fadd fast <4 x float> %7, %7
-// CHECK-NEXT:   %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
-// CHECK-NEXT:   %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:   %11 = fmul fast <4 x float> %8, %10
-// CHECK-NEXT:   %12 = extractelement <4 x float> %11, i64 0
-// CHECK-NEXT:   %13 = insertvalue [4 x float] undef, float %12, 0
-// CHECK-NEXT:   %14 = extractelement <4 x float> %11, i64 1
-// CHECK-NEXT:   %15 = insertvalue [4 x float] %13, float %14, 1
-// CHECK-NEXT:   %16 = extractelement <4 x float> %11, i64 2
-// CHECK-NEXT:   %17 = insertvalue [4 x float] %15, float %16, 2
-// CHECK-NEXT:   %18 = extractelement <4 x float> %11, i64 3
-// CHECK-NEXT:   %19 = insertvalue [4 x float] %17, float %18, 3
-// CHECK-NEXT:   ret [4 x float] %19
-// CHECK-NEXT: }
+// CHECK-NEXT:   %4 = fmul float %"_2'ipl", 2.000000e+00
+// CHECK-NEXT:   %5 = fmul fast float %4, %x.0.val
+// CHECK-NEXT:   %6 = insertvalue [4 x float] undef, float %5, 0
+// CHECK-NEXT:   %7 = fmul float %"_2'ipl1", 2.000000e+00
+// CHECK-NEXT:   %8 = fmul fast float %7, %x.0.val
+// CHECK-NEXT:   %9 = insertvalue [4 x float] %6, float %8, 1
+// CHECK-NEXT:   %10 = fmul float %"_2'ipl2", 2.000000e+00
+// CHECK-NEXT:   %11 = fmul fast float %10, %x.0.val
+// CHECK-NEXT:   %12 = insertvalue [4 x float] %9, float %11, 2
+// CHECK-NEXT:   %13 = fmul float %"_2'ipl3", 2.000000e+00
+// CHECK-NEXT:   %14 = fmul fast float %13, %x.0.val
+// CHECK-NEXT:   %15 = insertvalue [4 x float] %12, float %14, 3
+// CHECK-NEXT:   ret [4 x float] %15
+// CHECK-NEXT:   }
 
 // d_square3, the extra float is the original return value (x * x)
 // CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
@@ -63,26 +60,22 @@ fn square(x: &f32) -> f32 {
 // CHECK-NEXT:   %3 = extractvalue [4 x ptr] %"x'", 3
 // CHECK-NEXT:   %"_2'ipl3" = load float, ptr %3, align 4
 // CHECK-NEXT:   %_0 = fmul float %x.0.val, %x.0.val
-// CHECK-NEXT:   %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
-// CHECK-NEXT:   %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
-// CHECK-NEXT:   %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
-// CHECK-NEXT:   %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
-// CHECK-NEXT:   %8 = fadd fast <4 x float> %7, %7
-// CHECK-NEXT:   %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
-// CHECK-NEXT:   %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:   %11 = fmul fast <4 x float> %8, %10
-// CHECK-NEXT:   %12 = extractelement <4 x float> %11, i64 0
-// CHECK-NEXT:   %13 = insertvalue [4 x float] undef, float %12, 0
-// CHECK-NEXT:   %14 = extractelement <4 x float> %11, i64 1
-// CHECK-NEXT:   %15 = insertvalue [4 x float] %13, float %14, 1
-// CHECK-NEXT:   %16 = extractelement <4 x float> %11, i64 2
-// CHECK-NEXT:   %17 = insertvalue [4 x float] %15, float %16, 2
-// CHECK-NEXT:   %18 = extractelement <4 x float> %11, i64 3
-// CHECK-NEXT:   %19 = insertvalue [4 x float] %17, float %18, 3
-// CHECK-NEXT:   %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
-// CHECK-NEXT:   %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
-// CHECK-NEXT:   ret { float, [4 x float] } %21
-// CHECK-NEXT: }
+// CHECK-NEXT:   %4 = fmul float %"_2'ipl", 2.000000e+00
+// CHECK-NEXT:   %5 = fmul fast float %4, %x.0.val
+// CHECK-NEXT:   %6 = insertvalue [4 x float] undef, float %5, 0
+// CHECK-NEXT:   %7 = fmul float %"_2'ipl1", 2.000000e+00
+// CHECK-NEXT:   %8 = fmul fast float %7, %x.0.val
+// CHECK-NEXT:   %9 = insertvalue [4 x float] %6, float %8, 1
+// CHECK-NEXT:   %10 = fmul float %"_2'ipl2", 2.000000e+00
+// CHECK-NEXT:   %11 = fmul fast float %10, %x.0.val
+// CHECK-NEXT:   %12 = insertvalue [4 x float] %9, float %11, 2
+// CHECK-NEXT:   %13 = fmul float %"_2'ipl3", 2.000000e+00
+// CHECK-NEXT:   %14 = fmul fast float %13, %x.0.val
+// CHECK-NEXT:   %15 = insertvalue [4 x float] %12, float %14, 3
+// CHECK-NEXT:   %16 = insertvalue { float, [4 x float] } undef, float %_0, 0
+// CHECK-NEXT:   %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1
+// CHECK-NEXT:   ret { float, [4 x float] } %17
+// CHECK-NEXT:   }
 
 fn main() {
     let x = std::hint::black_box(3.0);
diff --git a/tests/codegen-llvm/autodiff/generic.rs b/tests/codegen-llvm/autodiff/generic.rs
index 2f674079be0..6f56460a2b6 100644
--- a/tests/codegen-llvm/autodiff/generic.rs
+++ b/tests/codegen-llvm/autodiff/generic.rs
@@ -6,27 +6,28 @@
 use std::autodiff::autodiff_reverse;
 
 #[autodiff_reverse(d_square, Duplicated, Active)]
+#[inline(never)]
 fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
     *x * *x
 }
 
-// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
+// Ensure that `d_square::<f32>` code is generated
 //
 // CHECK: ; generic::square
-// CHECK-NEXT: ; Function Attrs:
-// CHECK-NEXT: define internal {{.*}} double
+// CHECK-NEXT: ; Function Attrs: {{.*}}
+// CHECK-NEXT: define internal {{.*}} float
 // CHECK-NEXT: start:
 // CHECK-NOT: ret
-// CHECK: fmul double
+// CHECK: fmul float
 
-// Ensure that `d_square::<f32>` code is generated
+// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
 //
 // CHECK: ; generic::square
-// CHECK-NEXT: ; Function Attrs: {{.*}}
-// CHECK-NEXT: define internal {{.*}} float
+// CHECK-NEXT: ; Function Attrs:
+// CHECK-NEXT: define internal {{.*}} double
 // CHECK-NEXT: start:
 // CHECK-NOT: ret
-// CHECK: fmul float
+// CHECK: fmul double
 
 fn main() {
     let xf32: f32 = std::hint::black_box(3.0);
diff --git a/tests/codegen-llvm/autodiff/identical_fnc.rs b/tests/codegen-llvm/autodiff/identical_fnc.rs
index 1c25b3d09ab..6066f8cb34f 100644
--- a/tests/codegen-llvm/autodiff/identical_fnc.rs
+++ b/tests/codegen-llvm/autodiff/identical_fnc.rs
@@ -14,25 +14,27 @@
 use std::autodiff::autodiff_reverse;
 
 #[autodiff_reverse(d_square, Duplicated, Active)]
+#[inline(never)]
 fn square(x: &f64) -> f64 {
     x * x
 }
 
 #[autodiff_reverse(d_square2, Duplicated, Active)]
+#[inline(never)]
 fn square2(x: &f64) -> f64 {
     x * x
 }
 
 // CHECK:; identical_fnc::main
 // CHECK-NEXT:; Function Attrs:
-// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
+// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17h6009e4f751bf9407E()
 // CHECK-NEXT:start:
 // CHECK-NOT:br
 // CHECK-NOT:ret
 // CHECK:; call identical_fnc::d_square
-// CHECK-NEXT:  call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
-// CHECK-NEXT:; call identical_fnc::d_square
-// CHECK-NEXT:  call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
+// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1)
+// CHECK:; call identical_fnc::d_square
+// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square17hcb5768e95528c35fE(double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2)
 
 fn main() {
     let x = std::hint::black_box(3.0);
diff --git a/tests/codegen-llvm/autodiff/inline.rs b/tests/codegen-llvm/autodiff/inline.rs
deleted file mode 100644
index 65bed170207..00000000000
--- a/tests/codegen-llvm/autodiff/inline.rs
+++ /dev/null
@@ -1,23 +0,0 @@
-//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat -Zautodiff=NoPostopt
-//@ no-prefer-dynamic
-//@ needs-enzyme
-
-#![feature(autodiff)]
-
-use std::autodiff::autodiff_reverse;
-
-#[autodiff_reverse(d_square, Duplicated, Active)]
-fn square(x: &f64) -> f64 {
-    x * x
-}
-
-// CHECK: ; inline::d_square
-// CHECK-NEXT: ; Function Attrs: alwaysinline
-// CHECK-NOT: noinline
-// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE
-fn main() {
-    let x = std::hint::black_box(3.0);
-    let mut dx1 = std::hint::black_box(1.0);
-    let _ = d_square(&x, &mut dx1, 1.0);
-    assert_eq!(dx1, 6.0);
-}
diff --git a/tests/codegen-llvm/autodiff/scalar.rs b/tests/codegen-llvm/autodiff/scalar.rs
index 096b4209e84..55b989f920d 100644
--- a/tests/codegen-llvm/autodiff/scalar.rs
+++ b/tests/codegen-llvm/autodiff/scalar.rs
@@ -7,11 +7,12 @@ use std::autodiff::autodiff_reverse;
 
 #[autodiff_reverse(d_square, Duplicated, Active)]
 #[no_mangle]
+#[inline(never)]
 fn square(x: &f64) -> f64 {
     x * x
 }
 
-// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
+// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nonnull align 8 captures(none) %"x'")
 // CHECK-NEXT:invertstart:
 // CHECK-NEXT:  %_0 = fmul double %x.0.val, %x.0.val
 // CHECK-NEXT:  %0 = fadd fast double %x.0.val, %x.0.val
diff --git a/tests/codegen-llvm/autodiff/sret.rs b/tests/codegen-llvm/autodiff/sret.rs
index d2fa85e3e37..dbc253ce894 100644
--- a/tests/codegen-llvm/autodiff/sret.rs
+++ b/tests/codegen-llvm/autodiff/sret.rs
@@ -13,30 +13,30 @@ use std::autodiff::autodiff_reverse;
 
 #[no_mangle]
 #[autodiff_reverse(df, Active, Active, Active)]
+#[inline(never)]
 fn primal(x: f32, y: f32) -> f64 {
     (x * x * y) as f64
 }
 
-// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y)
-// CHECK-NEXT:start:
-// CHECK-NEXT:  %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y)
-// CHECK-NEXT:  %.elt = extractvalue { double, float, float } %0, 0
-// CHECK-NEXT:  store double %.elt, ptr %_0, align 8
-// CHECK-NEXT:  %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8
-// CHECK-NEXT:  %.elt2 = extractvalue { double, float, float } %0, 1
-// CHECK-NEXT:  store float %.elt2, ptr %_0.repack1, align 8
-// CHECK-NEXT:  %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12
-// CHECK-NEXT:  %.elt4 = extractvalue { double, float, float } %0, 2
-// CHECK-NEXT:  store float %.elt4, ptr %_0.repack3, align 4
-// CHECK-NEXT:  ret void
-// CHECK-NEXT:}
+// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y)
+// CHECK-NEXT: invertstart:
+// CHECK-NEXT: %_4 = fmul float %x, %x
+// CHECK-NEXT: %_3 = fmul float %_4, %y
+// CHECK-NEXT: %_0 = fpext float %_3 to double
+// CHECK-NEXT: %0 = fadd fast float %y, %y
+// CHECK-NEXT: %1 = fmul fast float %0, %x
+// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0
+// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1
+// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2
+// CHECK-NEXT: ret { double, float, float } %4
+// CHECK-NEXT: }
 
 fn main() {
     let x = std::hint::black_box(3.0);
     let y = std::hint::black_box(2.5);
     let scalar = std::hint::black_box(1.0);
     let (r1, r2, r3) = df(x, y, scalar);
-    // 3*3*1.5 = 22.5
+    // 3*3*2.5 = 22.5
     assert_eq!(r1, 22.5);
     // 2*x*y = 2*3*2.5 = 15.0
     assert_eq!(r2, 15.0);
diff --git a/tests/codegen-llvm/autodiff/trait.rs b/tests/codegen-llvm/autodiff/trait.rs
new file mode 100644
index 00000000000..701f3a9e843
--- /dev/null
+++ b/tests/codegen-llvm/autodiff/trait.rs
@@ -0,0 +1,30 @@
+//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+
+// Just check it does not crash for now
+// CHECK: ;
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+struct Foo {
+    a: f64,
+}
+
+trait MyTrait {
+    fn f(&self, x: f64) -> f64;
+    fn df(&self, x: f64, seed: f64) -> (f64, f64);
+}
+
+impl MyTrait for Foo {
+    #[autodiff_reverse(df, Const, Active, Active)]
+    fn f(&self, x: f64) -> f64 {
+        self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
+    }
+}
+
+fn main() {
+    let foo = Foo { a: 3.0f64 };
+    dbg!(foo.df(1.0, 1.0));
+}