From ca5bea3ebbc4725c187abf4eac68f6c57fa938c1 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Mon, 7 Apr 2025 07:11:52 -0400 Subject: move old tests, add sret test --- tests/codegen/autodiff.rs | 33 ----------- tests/codegen/autodiff/batched.rs | 116 ++++++++++++++++++++++++++++++++++++++ tests/codegen/autodiff/scalar.rs | 33 +++++++++++ tests/codegen/autodiff/sret.rs | 45 +++++++++++++++ tests/codegen/autodiffv.rs | 116 -------------------------------------- 5 files changed, 194 insertions(+), 149 deletions(-) delete mode 100644 tests/codegen/autodiff.rs create mode 100644 tests/codegen/autodiff/batched.rs create mode 100644 tests/codegen/autodiff/scalar.rs create mode 100644 tests/codegen/autodiff/sret.rs delete mode 100644 tests/codegen/autodiffv.rs (limited to 'tests/codegen') diff --git a/tests/codegen/autodiff.rs b/tests/codegen/autodiff.rs deleted file mode 100644 index 85358f5fcb6..00000000000 --- a/tests/codegen/autodiff.rs +++ /dev/null @@ -1,33 +0,0 @@ -//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -//@ no-prefer-dynamic -//@ needs-enzyme -#![feature(autodiff)] - -use std::autodiff::autodiff; - -#[autodiff(d_square, Reverse, Duplicated, Active)] -#[no_mangle] -fn square(x: &f64) -> f64 { - x * x -} - -// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'" -// CHECK-NEXT:invertstart: -// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val -// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val -// CHECK-NEXT: %1 = load double, ptr %"x'", align 8 -// CHECK-NEXT: %2 = fadd fast double %1, %0 -// CHECK-NEXT: store double %2, ptr %"x'", align 8 -// CHECK-NEXT: ret double %_0 -// CHECK-NEXT:} - -fn main() { - let x = std::hint::black_box(3.0); - let output = square(&x); - assert_eq!(9.0, output); - - let mut df_dx = 0.0; - let output_ = d_square(&x, &mut df_dx, 1.0); - assert_eq!(output, output_); - assert_eq!(6.0, df_dx); -} diff --git a/tests/codegen/autodiff/batched.rs b/tests/codegen/autodiff/batched.rs new file mode 100644 index 00000000000..e0047116405 --- /dev/null +++ b/tests/codegen/autodiff/batched.rs @@ -0,0 +1,116 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme +// +// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many +// breakages. One benefit is that we match the IR generated by Enzyme only after running it +// through LLVM's O3 pipeline, which will remove most of the noise. +// However, our integration test could also be affected by changes in how rustc lowers MIR into +// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should +// reduce this test to only match the first lines and the ret instructions. + +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[autodiff(d_square3, Forward, Dual, DualOnly)] +#[autodiff(d_square2, Forward, 4, Dual, DualOnly)] +#[autodiff(d_square1, Forward, 4, Dual, Dual)] +#[no_mangle] +fn square(x: &f32) -> f32 { + x * x +} + +// d_sqaure2 +// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'") +// CHECK-NEXT: start: +// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 +// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4 +// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1 +// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4 +// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2 +// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 +// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 +// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 +// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 +// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 +// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 +// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 +// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 +// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 +// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer +// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 +// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 +// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 +// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 +// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 +// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 +// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 +// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 +// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 +// CHECK-NEXT: ret [4 x float] %19 +// CHECK-NEXT: } + +// d_square3, the extra float is the original return value (x * x) +// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'") +// CHECK-NEXT: start: +// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 +// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4 +// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1 +// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4 +// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2 +// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 +// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 +// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 +// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val +// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 +// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 +// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 +// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 +// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 +// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 +// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer +// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 +// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 +// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 +// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 +// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 +// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 +// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 +// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 +// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 +// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0 +// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1 +// CHECK-NEXT: ret { float, [4 x float] } %21 +// CHECK-NEXT: } + +fn main() { + let x = std::hint::black_box(3.0); + let output = square(&x); + dbg!(&output); + assert_eq!(9.0, output); + dbg!(square(&x)); + + let mut df_dx1 = 1.0; + let mut df_dx2 = 2.0; + let mut df_dx3 = 3.0; + let mut df_dx4 = 0.0; + let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4); + dbg!(o1, o2, o3, o4); + let [output2, o1, o2, o3, o4] = + d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4); + dbg!(o1, o2, o3, o4); + assert_eq!(output, output2); + assert!((6.0 - o1).abs() < 1e-10); + assert!((12.0 - o2).abs() < 1e-10); + assert!((18.0 - o3).abs() < 1e-10); + assert!((0.0 - o4).abs() < 1e-10); + assert_eq!(1.0, df_dx1); + assert_eq!(2.0, df_dx2); + assert_eq!(3.0, df_dx3); + assert_eq!(0.0, df_dx4); + assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1); + assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2); + assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3); + assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4); +} diff --git a/tests/codegen/autodiff/scalar.rs b/tests/codegen/autodiff/scalar.rs new file mode 100644 index 00000000000..85358f5fcb6 --- /dev/null +++ b/tests/codegen/autodiff/scalar.rs @@ -0,0 +1,33 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[autodiff(d_square, Reverse, Duplicated, Active)] +#[no_mangle] +fn square(x: &f64) -> f64 { + x * x +} + +// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'" +// CHECK-NEXT:invertstart: +// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val +// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val +// CHECK-NEXT: %1 = load double, ptr %"x'", align 8 +// CHECK-NEXT: %2 = fadd fast double %1, %0 +// CHECK-NEXT: store double %2, ptr %"x'", align 8 +// CHECK-NEXT: ret double %_0 +// CHECK-NEXT:} + +fn main() { + let x = std::hint::black_box(3.0); + let output = square(&x); + assert_eq!(9.0, output); + + let mut df_dx = 0.0; + let output_ = d_square(&x, &mut df_dx, 1.0); + assert_eq!(output, output_); + assert_eq!(6.0, df_dx); +} diff --git a/tests/codegen/autodiff/sret.rs b/tests/codegen/autodiff/sret.rs new file mode 100644 index 00000000000..5ead90041ed --- /dev/null +++ b/tests/codegen/autodiff/sret.rs @@ -0,0 +1,45 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// This test is almost identical to the scalar.rs one, +// but we intentionally add a few more floats. +// `df` would ret `{ f64, f32, f32 }`, but is lowered as an sret. +// We therefore use this test to verify some of our sret handling. + +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[no_mangle] +#[autodiff(df, Reverse, Active, Active, Active)] +fn primal(x: f32, y: f32) -> f64 { + (x * x * y) as f64 +} + +// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y) +// CHECK-NEXT:start: +// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y) +// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0 +// CHECK-NEXT: store double %.elt, ptr %_0, align 8 +// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8 +// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1 +// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8 +// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12 +// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2 +// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4 +// CHECK-NEXT: ret void +// CHECK-NEXT:} + +fn main() { + let x = std::hint::black_box(3.0); + let y = std::hint::black_box(2.5); + let scalar = std::hint::black_box(1.0); + let (r1, r2, r3) = df(x, y, scalar); + // 3*3*1.5 = 22.5 + assert_eq!(r1, 22.5); + // 2*x*y = 2*3*2.5 = 15.0 + assert_eq!(r2, 15.0); + // x*x*1 = 3*3 = 9 + assert_eq!(r3, 9.0); +} diff --git a/tests/codegen/autodiffv.rs b/tests/codegen/autodiffv.rs deleted file mode 100644 index e0047116405..00000000000 --- a/tests/codegen/autodiffv.rs +++ /dev/null @@ -1,116 +0,0 @@ -//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -//@ no-prefer-dynamic -//@ needs-enzyme -// -// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many -// breakages. One benefit is that we match the IR generated by Enzyme only after running it -// through LLVM's O3 pipeline, which will remove most of the noise. -// However, our integration test could also be affected by changes in how rustc lowers MIR into -// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should -// reduce this test to only match the first lines and the ret instructions. - -#![feature(autodiff)] - -use std::autodiff::autodiff; - -#[autodiff(d_square3, Forward, Dual, DualOnly)] -#[autodiff(d_square2, Forward, 4, Dual, DualOnly)] -#[autodiff(d_square1, Forward, 4, Dual, Dual)] -#[no_mangle] -fn square(x: &f32) -> f32 { - x * x -} - -// d_sqaure2 -// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'") -// CHECK-NEXT: start: -// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 -// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4 -// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1 -// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4 -// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2 -// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 -// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 -// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 -// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 -// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 -// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 -// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 -// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 -// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 -// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer -// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 -// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 -// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 -// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 -// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 -// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 -// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 -// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 -// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 -// CHECK-NEXT: ret [4 x float] %19 -// CHECK-NEXT: } - -// d_square3, the extra float is the original return value (x * x) -// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'") -// CHECK-NEXT: start: -// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 -// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4 -// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1 -// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4 -// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2 -// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 -// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 -// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 -// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val -// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 -// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 -// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 -// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 -// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 -// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 -// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer -// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 -// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 -// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 -// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 -// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 -// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 -// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 -// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 -// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 -// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0 -// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1 -// CHECK-NEXT: ret { float, [4 x float] } %21 -// CHECK-NEXT: } - -fn main() { - let x = std::hint::black_box(3.0); - let output = square(&x); - dbg!(&output); - assert_eq!(9.0, output); - dbg!(square(&x)); - - let mut df_dx1 = 1.0; - let mut df_dx2 = 2.0; - let mut df_dx3 = 3.0; - let mut df_dx4 = 0.0; - let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4); - dbg!(o1, o2, o3, o4); - let [output2, o1, o2, o3, o4] = - d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4); - dbg!(o1, o2, o3, o4); - assert_eq!(output, output2); - assert!((6.0 - o1).abs() < 1e-10); - assert!((12.0 - o2).abs() < 1e-10); - assert!((18.0 - o3).abs() < 1e-10); - assert!((0.0 - o4).abs() < 1e-10); - assert_eq!(1.0, df_dx1); - assert_eq!(2.0, df_dx2); - assert_eq!(3.0, df_dx3); - assert_eq!(0.0, df_dx4); - assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1); - assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2); - assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3); - assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4); -} -- cgit 1.4.1-3-g733a5