about summary refs log tree commit diff
path: root/tests/codegen/autodiff/batched.rs
diff options
context:
space:
mode:
authorGuillaume Gomez <guillaume1.gomez@gmail.com>2025-07-21 14:34:12 +0200
committerGuillaume Gomez <guillaume1.gomez@gmail.com>2025-07-22 14:28:48 +0200
commita27f3e3fd1e4d16160f8885b6b06665b5319f56c (patch)
treeb033935392cbadf6f85d2dbddf433a88e323aeeb /tests/codegen/autodiff/batched.rs
parented93c1783b404d728d4809973a0550eb33cd293f (diff)
downloadrust-a27f3e3fd1e4d16160f8885b6b06665b5319f56c.tar.gz
rust-a27f3e3fd1e4d16160f8885b6b06665b5319f56c.zip
Rename `tests/codegen` into `tests/codegen-llvm`
Diffstat (limited to 'tests/codegen/autodiff/batched.rs')
-rw-r--r--tests/codegen/autodiff/batched.rs116
1 files changed, 0 insertions, 116 deletions
diff --git a/tests/codegen/autodiff/batched.rs b/tests/codegen/autodiff/batched.rs
deleted file mode 100644
index d27aed50e6c..00000000000
--- a/tests/codegen/autodiff/batched.rs
+++ /dev/null
@@ -1,116 +0,0 @@
-//@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat
-//@ no-prefer-dynamic
-//@ needs-enzyme
-//
-// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many
-// breakages. One benefit is that we match the IR generated by Enzyme only after running it
-// through LLVM's O3 pipeline, which will remove most of the noise.
-// However, our integration test could also be affected by changes in how rustc lowers MIR into
-// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
-// reduce this test to only match the first lines and the ret instructions.
-
-#![feature(autodiff)]
-
-use std::autodiff::autodiff_forward;
-
-#[autodiff_forward(d_square3, Dual, DualOnly)]
-#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
-#[autodiff_forward(d_square1, 4, Dual, Dual)]
-#[no_mangle]
-fn square(x: &f32) -> f32 {
-    x * x
-}
-
-// d_sqaure2
-// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
-// CHECK-NEXT: start:
-// CHECK-NEXT:   %0 = extractvalue [4 x ptr] %"x'", 0
-// CHECK-NEXT:   %"_2'ipl" = load float, ptr %0, align 4
-// CHECK-NEXT:   %1 = extractvalue [4 x ptr] %"x'", 1
-// CHECK-NEXT:   %"_2'ipl1" = load float, ptr %1, align 4
-// CHECK-NEXT:   %2 = extractvalue [4 x ptr] %"x'", 2
-// CHECK-NEXT:   %"_2'ipl2" = load float, ptr %2, align 4
-// CHECK-NEXT:   %3 = extractvalue [4 x ptr] %"x'", 3
-// CHECK-NEXT:   %"_2'ipl3" = load float, ptr %3, align 4
-// CHECK-NEXT:   %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
-// CHECK-NEXT:   %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
-// CHECK-NEXT:   %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
-// CHECK-NEXT:   %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
-// CHECK-NEXT:   %8 = fadd fast <4 x float> %7, %7
-// CHECK-NEXT:   %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
-// CHECK-NEXT:   %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:   %11 = fmul fast <4 x float> %8, %10
-// CHECK-NEXT:   %12 = extractelement <4 x float> %11, i64 0
-// CHECK-NEXT:   %13 = insertvalue [4 x float] undef, float %12, 0
-// CHECK-NEXT:   %14 = extractelement <4 x float> %11, i64 1
-// CHECK-NEXT:   %15 = insertvalue [4 x float] %13, float %14, 1
-// CHECK-NEXT:   %16 = extractelement <4 x float> %11, i64 2
-// CHECK-NEXT:   %17 = insertvalue [4 x float] %15, float %16, 2
-// CHECK-NEXT:   %18 = extractelement <4 x float> %11, i64 3
-// CHECK-NEXT:   %19 = insertvalue [4 x float] %17, float %18, 3
-// CHECK-NEXT:   ret [4 x float] %19
-// CHECK-NEXT: }
-
-// d_square3, the extra float is the original return value (x * x)
-// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
-// CHECK-NEXT: start:
-// CHECK-NEXT:   %0 = extractvalue [4 x ptr] %"x'", 0
-// CHECK-NEXT:   %"_2'ipl" = load float, ptr %0, align 4
-// CHECK-NEXT:   %1 = extractvalue [4 x ptr] %"x'", 1
-// CHECK-NEXT:   %"_2'ipl1" = load float, ptr %1, align 4
-// CHECK-NEXT:   %2 = extractvalue [4 x ptr] %"x'", 2
-// CHECK-NEXT:   %"_2'ipl2" = load float, ptr %2, align 4
-// CHECK-NEXT:   %3 = extractvalue [4 x ptr] %"x'", 3
-// CHECK-NEXT:   %"_2'ipl3" = load float, ptr %3, align 4
-// CHECK-NEXT:   %_0 = fmul float %x.0.val, %x.0.val
-// CHECK-NEXT:   %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
-// CHECK-NEXT:   %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
-// CHECK-NEXT:   %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
-// CHECK-NEXT:   %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
-// CHECK-NEXT:   %8 = fadd fast <4 x float> %7, %7
-// CHECK-NEXT:   %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
-// CHECK-NEXT:   %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:   %11 = fmul fast <4 x float> %8, %10
-// CHECK-NEXT:   %12 = extractelement <4 x float> %11, i64 0
-// CHECK-NEXT:   %13 = insertvalue [4 x float] undef, float %12, 0
-// CHECK-NEXT:   %14 = extractelement <4 x float> %11, i64 1
-// CHECK-NEXT:   %15 = insertvalue [4 x float] %13, float %14, 1
-// CHECK-NEXT:   %16 = extractelement <4 x float> %11, i64 2
-// CHECK-NEXT:   %17 = insertvalue [4 x float] %15, float %16, 2
-// CHECK-NEXT:   %18 = extractelement <4 x float> %11, i64 3
-// CHECK-NEXT:   %19 = insertvalue [4 x float] %17, float %18, 3
-// CHECK-NEXT:   %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
-// CHECK-NEXT:   %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
-// CHECK-NEXT:   ret { float, [4 x float] } %21
-// CHECK-NEXT: }
-
-fn main() {
-    let x = std::hint::black_box(3.0);
-    let output = square(&x);
-    dbg!(&output);
-    assert_eq!(9.0, output);
-    dbg!(square(&x));
-
-    let mut df_dx1 = 1.0;
-    let mut df_dx2 = 2.0;
-    let mut df_dx3 = 3.0;
-    let mut df_dx4 = 0.0;
-    let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
-    dbg!(o1, o2, o3, o4);
-    let [output2, o1, o2, o3, o4] =
-        d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
-    dbg!(o1, o2, o3, o4);
-    assert_eq!(output, output2);
-    assert!((6.0 - o1).abs() < 1e-10);
-    assert!((12.0 - o2).abs() < 1e-10);
-    assert!((18.0 - o3).abs() < 1e-10);
-    assert!((0.0 - o4).abs() < 1e-10);
-    assert_eq!(1.0, df_dx1);
-    assert_eq!(2.0, df_dx2);
-    assert_eq!(3.0, df_dx3);
-    assert_eq!(0.0, df_dx4);
-    assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1);
-    assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2);
-    assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3);
-    assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4);
-}