From e1258e79d6cb709b26ded97d32de6c55f355e2aa Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 23 Aug 2025 20:17:32 +0000 Subject: autodiff: Add basic TypeTree with NoTT flag Signed-off-by: Karan Janthe --- tests/codegen-llvm/autodiff/typetree.rs | 33 +++++++++++++++++++ .../autodiff/type-trees/nott-flag/nott.check | 3 ++ .../autodiff/type-trees/nott-flag/rmake.rs | 38 ++++++++++++++++++++++ .../run-make/autodiff/type-trees/nott-flag/test.rs | 15 +++++++++ .../autodiff/type-trees/nott-flag/with_tt.check | 3 ++ tests/ui/autodiff/flag_nott.rs | 19 +++++++++++ 6 files changed, 111 insertions(+) create mode 100644 tests/codegen-llvm/autodiff/typetree.rs create mode 100644 tests/run-make/autodiff/type-trees/nott-flag/nott.check create mode 100644 tests/run-make/autodiff/type-trees/nott-flag/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/nott-flag/test.rs create mode 100644 tests/run-make/autodiff/type-trees/nott-flag/with_tt.check create mode 100644 tests/ui/autodiff/flag_nott.rs (limited to 'tests') diff --git a/tests/codegen-llvm/autodiff/typetree.rs b/tests/codegen-llvm/autodiff/typetree.rs new file mode 100644 index 00000000000..3ad38d581b9 --- /dev/null +++ b/tests/codegen-llvm/autodiff/typetree.rs @@ -0,0 +1,33 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// Test that basic autodiff still works with our TypeTree infrastructure +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_simple, Duplicated, Active)] +#[no_mangle] +#[inline(never)] +fn simple(x: &f64) -> f64 { + 2.0 * x +} + +// CHECK-LABEL: @simple +// CHECK: fmul double + +// The derivative function should be generated normally +// CHECK-LABEL: diffesimple +// CHECK: fadd fast double + +fn main() { + let x = std::hint::black_box(3.0); + let output = simple(&x); + assert_eq!(6.0, output); + + let mut df_dx = 0.0; + let output_ = d_simple(&x, &mut df_dx, 1.0); + assert_eq!(output, output_); + assert_eq!(2.0, df_dx); +} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/nott.check b/tests/run-make/autodiff/type-trees/nott-flag/nott.check new file mode 100644 index 00000000000..56ef2f0bdf3 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/nott.check @@ -0,0 +1,3 @@ +// TODO(KMJ-007): Update this test when TypeTree integration is complete +// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{} +// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs new file mode 100644 index 00000000000..536164192dc --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs @@ -0,0 +1,38 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Test with NoTT flag - should not generate TypeTree metadata + let output_nott = rustc() + .input("test.rs") + .arg("-Zautodiff=Enable,NoTT,PrintTAFn=square") + .arg("-Zautodiff=NoPostopt") + .opt_level("3") + .arg("-Clto=fat") + .arg("-g") + .run(); + + // Write output for NoTT case + rfs::write("nott.stdout", output_nott.stdout_utf8()); + + // Test without NoTT flag - should generate TypeTree metadata + let output_with_tt = rustc() + .input("test.rs") + .arg("-Zautodiff=Enable,PrintTAFn=square") + .arg("-Zautodiff=NoPostopt") + .opt_level("3") + .arg("-Clto=fat") + .arg("-g") + .run(); + + // Write output for TypeTree case + rfs::write("with_tt.stdout", output_with_tt.stdout_utf8()); + + // Verify NoTT output has minimal TypeTree info + llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.stdout")).run(); + + // Verify normal output will have TypeTree info (once implemented) + llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.stdout")).run(); +} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/test.rs b/tests/run-make/autodiff/type-trees/nott-flag/test.rs new file mode 100644 index 00000000000..5c634eea035 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_square, Duplicated, Active)] +#[no_mangle] +fn square(x: &f64) -> f64 { + x * x +} + +fn main() { + let x = 2.0; + let mut dx = 0.0; + let _result = d_square(&x, &mut dx, 1.0); +} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check new file mode 100644 index 00000000000..56ef2f0bdf3 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check @@ -0,0 +1,3 @@ +// TODO(KMJ-007): Update this test when TypeTree integration is complete +// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{} +// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double} \ No newline at end of file diff --git a/tests/ui/autodiff/flag_nott.rs b/tests/ui/autodiff/flag_nott.rs new file mode 100644 index 00000000000..7a97d892cd8 --- /dev/null +++ b/tests/ui/autodiff/flag_nott.rs @@ -0,0 +1,19 @@ +//@ compile-flags: -Zautodiff=Enable,NoTT +//@ needs-enzyme +//@ check-pass + +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +// Test that NoTT flag is accepted and doesn't cause compilation errors +#[autodiff_reverse(d_square, Duplicated, Active)] +fn square(x: &f64) -> f64 { + x * x +} + +fn main() { + let x = 2.0; + let mut dx = 0.0; + let result = d_square(&x, &mut dx, 1.0); +} \ No newline at end of file -- cgit 1.4.1-3-g733a5 From beba6b9b8692d0db14c56292d4fb793fcd8ceb1e Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 23 Aug 2025 21:57:37 +0000 Subject: Update TypeTree tests to verify metadata attachment - Fix nott-flag test to emit LLVM IR and check enzyme_type attributes - Replace TODO comments with actual TypeTree metadata verification - Test that NoTT flag properly disables TypeTree generation - Test that TypeTree enabled generates proper enzyme_type attributes Signed-off-by: Karan Janthe --- .../autodiff/type-trees/nott-flag/nott.check | 8 +++-- .../autodiff/type-trees/nott-flag/rmake.rs | 42 +++++++++++----------- .../autodiff/type-trees/nott-flag/with_tt.check | 7 ++-- 3 files changed, 29 insertions(+), 28 deletions(-) (limited to 'tests') diff --git a/tests/run-make/autodiff/type-trees/nott-flag/nott.check b/tests/run-make/autodiff/type-trees/nott-flag/nott.check index 56ef2f0bdf3..8d23e2ee319 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/nott.check +++ b/tests/run-make/autodiff/type-trees/nott-flag/nott.check @@ -1,3 +1,5 @@ -// TODO(KMJ-007): Update this test when TypeTree integration is complete -// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{} -// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double} \ No newline at end of file +// Check that enzyme_type attributes are NOT present when NoTT flag is used +// This verifies the NoTT flag correctly disables TypeTree metadata + +CHECK: define{{.*}}@square +CHECK-NOT: "enzyme_type" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs index 536164192dc..bab863ca9ff 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs +++ b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs @@ -5,34 +5,32 @@ use run_make_support::{llvm_filecheck, rfs, rustc}; fn main() { // Test with NoTT flag - should not generate TypeTree metadata - let output_nott = rustc() + rustc() .input("test.rs") - .arg("-Zautodiff=Enable,NoTT,PrintTAFn=square") - .arg("-Zautodiff=NoPostopt") - .opt_level("3") - .arg("-Clto=fat") - .arg("-g") + .arg("-Zautodiff=Enable,NoTT") + .emit("llvm-ir") + .arg("-o") + .arg("nott.ll") .run(); - // Write output for NoTT case - rfs::write("nott.stdout", output_nott.stdout_utf8()); - // Test without NoTT flag - should generate TypeTree metadata - let output_with_tt = rustc() + rustc() .input("test.rs") - .arg("-Zautodiff=Enable,PrintTAFn=square") - .arg("-Zautodiff=NoPostopt") - .opt_level("3") - .arg("-Clto=fat") - .arg("-g") + .arg("-Zautodiff=Enable") + .emit("llvm-ir") + .arg("-o") + .arg("with_tt.ll") .run(); - // Write output for TypeTree case - rfs::write("with_tt.stdout", output_with_tt.stdout_utf8()); - - // Verify NoTT output has minimal TypeTree info - llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.stdout")).run(); + // Verify NoTT version does NOT have enzyme_type attributes + llvm_filecheck() + .patterns("nott.check") + .stdin_buf(rfs::read("nott.ll")) + .run(); - // Verify normal output will have TypeTree info (once implemented) - llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.stdout")).run(); + // Verify TypeTree version DOES have enzyme_type attributes + llvm_filecheck() + .patterns("with_tt.check") + .stdin_buf(rfs::read("with_tt.ll")) + .run(); } \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check index 56ef2f0bdf3..53eafc10a65 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check +++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check @@ -1,3 +1,4 @@ -// TODO(KMJ-007): Update this test when TypeTree integration is complete -// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{} -// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double} \ No newline at end of file +// Check that enzyme_type attributes are present when TypeTree is enabled +// This verifies our TypeTree metadata attachment is working + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file -- cgit 1.4.1-3-g733a5 From 5d3ebc3804299503387b7ea1427f1619d410c2b2 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 23 Aug 2025 21:57:54 +0000 Subject: Add TypeTree tests for scalar types - Add specific tests for f32, f64, i32, f16, f128 TypeTree generation - Verify correct enzyme_type metadata for each scalar type - Ensure TypeTree metadata matches expected Enzyme format Signed-off-by: Karan Janthe --- .../type-trees/scalar-types/f128-typetree/f128.check | 4 ++++ .../type-trees/scalar-types/f128-typetree/rmake.rs | 12 ++++++++++++ .../type-trees/scalar-types/f128-typetree/test.rs | 15 +++++++++++++++ .../type-trees/scalar-types/f16-typetree/f16.check | 4 ++++ .../type-trees/scalar-types/f16-typetree/rmake.rs | 12 ++++++++++++ .../autodiff/type-trees/scalar-types/f16-typetree/test.rs | 15 +++++++++++++++ .../type-trees/scalar-types/f32-typetree/f32.check | 4 ++++ .../type-trees/scalar-types/f32-typetree/rmake.rs | 12 ++++++++++++ .../autodiff/type-trees/scalar-types/f32-typetree/test.rs | 15 +++++++++++++++ .../type-trees/scalar-types/f64-typetree/f64.check | 4 ++++ .../type-trees/scalar-types/f64-typetree/rmake.rs | 12 ++++++++++++ .../autodiff/type-trees/scalar-types/f64-typetree/test.rs | 15 +++++++++++++++ .../type-trees/scalar-types/i32-typetree/i32.check | 4 ++++ .../type-trees/scalar-types/i32-typetree/rmake.rs | 12 ++++++++++++ .../autodiff/type-trees/scalar-types/i32-typetree/test.rs | 14 ++++++++++++++ 15 files changed, 154 insertions(+) create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/test.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/test.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/test.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/test.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs (limited to 'tests') diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check new file mode 100644 index 00000000000..31cef113420 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check @@ -0,0 +1,4 @@ +; Check that f128 TypeTree metadata is correctly generated +; f128 maps to Unknown in our current implementation since CConcreteType doesn't have DT_F128 + +CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f128{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/rmake.rs new file mode 100644 index 00000000000..44320ecdd57 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/rmake.rs @@ -0,0 +1,12 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile with TypeTree enabled and emit LLVM IR + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + + // Check that f128 TypeTree metadata is correctly generated + llvm_filecheck().patterns("f128.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/test.rs new file mode 100644 index 00000000000..5c71baa3e69 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff, f128)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_f128(x: &f128) -> f128 { + *x * *x +} + +fn main() { + let x = 2.0_f128; + let mut dx = 0.0_f128; + let _result = d_test(&x, &mut dx, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check new file mode 100644 index 00000000000..97a3964a569 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check @@ -0,0 +1,4 @@ +; Check that f16 TypeTree metadata is correctly generated +; Should show Half for f16 values and Pointer for references + +CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f16{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/rmake.rs new file mode 100644 index 00000000000..0aebdbf5520 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/rmake.rs @@ -0,0 +1,12 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile with TypeTree enabled and emit LLVM IR + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + + // Check that f16 TypeTree metadata is correctly generated + llvm_filecheck().patterns("f16.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/test.rs new file mode 100644 index 00000000000..6b68e8252f4 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff, f16)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_f16(x: &f16) -> f16 { + *x * *x +} + +fn main() { + let x = 2.0_f16; + let mut dx = 0.0_f16; + let _result = d_test(&x, &mut dx, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check new file mode 100644 index 00000000000..b32d5086d07 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check @@ -0,0 +1,4 @@ +; Check that f32 TypeTree metadata is correctly generated +; Should show Float@float for f32 values and Pointer for references + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/rmake.rs new file mode 100644 index 00000000000..ee3ab753bf5 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/rmake.rs @@ -0,0 +1,12 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile with TypeTree enabled and emit LLVM IR + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + + // Check that f32 TypeTree metadata is correctly generated + llvm_filecheck().patterns("f32.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/test.rs new file mode 100644 index 00000000000..56c118399ee --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_f32(x: &f32) -> f32 { + x * x +} + +fn main() { + let x = 2.0_f32; + let mut dx = 0.0_f32; + let _result = d_test(&x, &mut dx, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check new file mode 100644 index 00000000000..0e9d9c03a0f --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check @@ -0,0 +1,4 @@ +; Check that f64 TypeTree metadata is correctly generated +; Should show Float@double for f64 values and Pointer for references + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/rmake.rs new file mode 100644 index 00000000000..5fac9b23bc8 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/rmake.rs @@ -0,0 +1,12 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile with TypeTree enabled and emit LLVM IR + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + + // Check that f64 TypeTree metadata is correctly generated + llvm_filecheck().patterns("f64.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/test.rs new file mode 100644 index 00000000000..235360b76b2 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_f64(x: &f64) -> f64 { + x * x +} + +fn main() { + let x = 2.0_f64; + let mut dx = 0.0_f64; + let _result = d_test(&x, &mut dx, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check new file mode 100644 index 00000000000..aa8e5223e8e --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check @@ -0,0 +1,4 @@ +; Check that i32 TypeTree metadata is correctly generated +; Should show Integer for i32 values (integers are typically Const in autodiff) + +CHECK: define{{.*}}"enzyme_type"="{[]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[]:Integer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/rmake.rs new file mode 100644 index 00000000000..a40fd55d88a --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/rmake.rs @@ -0,0 +1,12 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile with TypeTree enabled and emit LLVM IR + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + + // Check that i32 TypeTree metadata is correctly generated + llvm_filecheck().patterns("i32.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs new file mode 100644 index 00000000000..530c137f62b --- /dev/null +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs @@ -0,0 +1,14 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Const, Active)] +#[no_mangle] +fn test_i32(x: i32) -> i32 { + x * x +} + +fn main() { + let x = 5_i32; + let _result = d_test(x, 1); +} -- cgit 1.4.1-3-g733a5 From 664e83b3e76b51fb5192e74a64eef3bc5bbd4e32 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Sat, 23 Aug 2025 23:10:48 +0000 Subject: added typetree support for memcpy --- compiler/rustc_codegen_gcc/src/builder.rs | 1 + compiler/rustc_codegen_gcc/src/intrinsic/mod.rs | 1 + compiler/rustc_codegen_llvm/src/abi.rs | 1 + compiler/rustc_codegen_llvm/src/builder.rs | 15 +++++++-- compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 1 + compiler/rustc_codegen_llvm/src/typetree.rs | 20 ++++------- compiler/rustc_codegen_llvm/src/va_arg.rs | 1 + compiler/rustc_codegen_ssa/src/mir/block.rs | 1 + compiler/rustc_codegen_ssa/src/mir/intrinsic.rs | 2 +- compiler/rustc_codegen_ssa/src/mir/statement.rs | 2 +- compiler/rustc_codegen_ssa/src/traits/builder.rs | 3 +- compiler/rustc_interface/src/tests.rs | 1 - compiler/rustc_middle/src/ty/mod.rs | 4 +-- tests/codegen-llvm/autodiff/typetree.rs | 2 +- .../type-trees/memcpy-typetree/memcpy-ir.check | 8 +++++ .../type-trees/memcpy-typetree/memcpy.check | 13 ++++++++ .../autodiff/type-trees/memcpy-typetree/memcpy.rs | 36 ++++++++++++++++++++ .../autodiff/type-trees/memcpy-typetree/rmake.rs | 39 ++++++++++++++++++++++ .../autodiff/type-trees/nott-flag/rmake.rs | 14 +++----- .../run-make/autodiff/type-trees/nott-flag/test.rs | 2 +- tests/ui/autodiff/flag_nott.rs | 2 +- 21 files changed, 135 insertions(+), 34 deletions(-) create mode 100644 tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check create mode 100644 tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check create mode 100644 tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs create mode 100644 tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs (limited to 'tests') diff --git a/compiler/rustc_codegen_gcc/src/builder.rs b/compiler/rustc_codegen_gcc/src/builder.rs index f7a7a3f8c7e..5657620879c 100644 --- a/compiler/rustc_codegen_gcc/src/builder.rs +++ b/compiler/rustc_codegen_gcc/src/builder.rs @@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> { _src_align: Align, size: RValue<'gcc>, flags: MemFlags, + _tt: Option, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_size_t(), false); diff --git a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs index 84fa56cf903..3b897a19835 100644 --- a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs +++ b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs @@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(self.layout.size.bytes()), MemFlags::empty(), + None, ); bx.lifetime_end(scratch, scratch_size); diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs index 11be7041167..61fadad6066 100644 --- a/compiler/rustc_codegen_llvm/src/abi.rs +++ b/compiler/rustc_codegen_llvm/src/abi.rs @@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(copy_bytes), MemFlags::empty(), + None, ); bx.lifetime_end(llscratch, scratch_size); } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 0f17cc9063a..83a9cf620f1 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow}; use std::ops::Deref; use std::{iter, ptr}; +use rustc_ast::expand::typetree::FncTree; pub(crate) mod autodiff; pub(crate) mod gpu_offload; @@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memcpy = unsafe { llvm::LLVMRustBuildMemCpy( self.llbuilder, dst, @@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + + // TypeTree metadata for memcpy is especially important: when Enzyme encounters + // a memcpy during autodiff, it needs to know the structure of the data being + // copied to properly track derivatives. For example, copying an array of floats + // vs. copying a struct with mixed types requires different derivative handling. + // The TypeTree tells Enzyme exactly what memory layout to expect. + if let Some(tt) = tt { + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 12b6ffa3c0b..9dd84ad9a4d 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -25,6 +25,7 @@ pub(crate) enum CConcreteType { DT_Half = 3, DT_Float = 4, DT_Double = 5, + // FIXME(KMJ-007): handle f128 using long double here(https://github.com/EnzymeAD/Enzyme/issues/1600) DT_Unknown = 6, } diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 434316464e6..8e6272a186b 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -1,8 +1,11 @@ -use std::ffi::{CString, c_char, c_uint}; - -use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; +use rustc_ast::expand::typetree::FncTree; +#[cfg(llvm_enzyme)] +use { + crate::attributes, + rustc_ast::expand::typetree::TypeTree as RustTypeTree, + std::ffi::{CString, c_char, c_uint}, +}; -use crate::attributes; use crate::llvm::{self, Value}; /// Converts a Rust TypeTree to Enzyme's internal TypeTree format @@ -50,15 +53,6 @@ fn to_enzyme_typetree( enzyme_tt } -#[cfg(not(llvm_enzyme))] -fn to_enzyme_typetree( - _rust_typetree: RustTypeTree, - _data_layout: &str, - _llcx: &llvm::Context, -) -> ! { - unimplemented!("TypeTree conversion not available without llvm_enzyme support") -} - // Attaches TypeTree information to LLVM function as enzyme_type attributes. #[cfg(llvm_enzyme)] pub(crate) fn add_tt<'ll>( diff --git a/compiler/rustc_codegen_llvm/src/va_arg.rs b/compiler/rustc_codegen_llvm/src/va_arg.rs index ab08125217f..d48c7cf874a 100644 --- a/compiler/rustc_codegen_llvm/src/va_arg.rs +++ b/compiler/rustc_codegen_llvm/src/va_arg.rs @@ -738,6 +738,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>( src_align, bx.const_u32(layout.layout.size().bytes() as u32), MemFlags::empty(), + None, ); tmp } else { diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index 1b218a0d339..b2dc4fe32b0 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -1626,6 +1626,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { align, bx.const_usize(copy_bytes), MemFlags::empty(), + None, ); // ...and then load it with the ABI type. llval = load_cast(bx, cast, llscratch, scratch_align); diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index 3c667b8e882..befa00c6861 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( if allow_overlap { bx.memmove(dst, align, src, align, size, flags); } else { - bx.memcpy(dst, align, src, align, size, flags); + bx.memcpy(dst, align, src, align, size, flags, None); } } diff --git a/compiler/rustc_codegen_ssa/src/mir/statement.rs b/compiler/rustc_codegen_ssa/src/mir/statement.rs index f164e0f9123..0a50d7f18db 100644 --- a/compiler/rustc_codegen_ssa/src/mir/statement.rs +++ b/compiler/rustc_codegen_ssa/src/mir/statement.rs @@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let align = pointee_layout.align; let dst = dst_val.immediate(); let src = src_val.immediate(); - bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty()); + bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None); } mir::StatementKind::FakeRead(..) | mir::StatementKind::Retag { .. } diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 4a5694e97fa..60296e36e0c 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -451,6 +451,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memmove( &mut self, @@ -507,7 +508,7 @@ pub trait BuilderMethods<'a, 'tcx>: temp.val.store_with_flags(self, dst.with_type(layout), flags); } else if !layout.is_zst() { let bytes = self.const_usize(layout.size.bytes()); - self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags); + self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None); } } diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 837acdadd57..0dc5b5af3ac 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -765,7 +765,6 @@ fn test_unstable_options_tracking_hash() { tracked!(allow_features, Some(vec![String::from("lang_items")])); tracked!(always_encode_mir, true); tracked!(assume_incomplete_release, true); - tracked!(autodiff, vec![AutoDiff::Enable]); tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]); tracked!(binary_dep_depinfo, true); tracked!(box_noalias, false); diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index e7130757f1d..741b5d7fd4e 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2280,12 +2280,12 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { let child = typetree_from_ty(tcx, inner_ty); return TypeTree(vec![Type { offset: -1, - size: 8, // TODO(KMJ-007): Get actual pointer size from target + size: tcx.data_layout.pointer_size().bytes_usize(), kind: Kind::Pointer, child, }]); } - // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types + // FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types TypeTree::new() } diff --git a/tests/codegen-llvm/autodiff/typetree.rs b/tests/codegen-llvm/autodiff/typetree.rs index 3ad38d581b9..1cb0c2fb68b 100644 --- a/tests/codegen-llvm/autodiff/typetree.rs +++ b/tests/codegen-llvm/autodiff/typetree.rs @@ -30,4 +30,4 @@ fn main() { let output_ = d_simple(&x, &mut df_dx, 1.0); assert_eq!(output, output_); assert_eq!(2.0, df_dx); -} \ No newline at end of file +} diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check new file mode 100644 index 00000000000..3a59f06dda1 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check @@ -0,0 +1,8 @@ +; Check that enzyme_type attributes are present in the LLVM IR function definition +; This verifies our TypeTree system correctly attaches metadata for Enzyme + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}" + +; Check that llvm.memcpy exists (either call or declare) +CHECK: {{(call|declare).*}}@llvm.memcpy + diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check new file mode 100644 index 00000000000..ae70830297a --- /dev/null +++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.check @@ -0,0 +1,13 @@ +CHECK: force_memcpy + +CHECK: @llvm.memcpy.p0.p0.i64 + +CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer}:{} + +CHECK-DAG: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double} + +CHECK-DAG: load double{{.*}}: {[-1]:Float@double} + +CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double} + +CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs new file mode 100644 index 00000000000..3c1029190c8 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy.rs @@ -0,0 +1,36 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; +use std::ptr; + +#[inline(never)] +fn force_memcpy(src: *const f64, dst: *mut f64, count: usize) { + unsafe { + ptr::copy_nonoverlapping(src, dst, count); + } +} + +#[autodiff_reverse(d_test_memcpy, Duplicated, Active)] +#[no_mangle] +fn test_memcpy(input: &[f64; 128]) -> f64 { + let mut local_data = [0.0f64; 128]; + + // Use a separate function to prevent inlining and optimization + force_memcpy(input.as_ptr(), local_data.as_mut_ptr(), 128); + + // Sum only first few elements to keep the computation simple + local_data[0] * local_data[0] + + local_data[1] * local_data[1] + + local_data[2] * local_data[2] + + local_data[3] * local_data[3] +} + +fn main() { + let input = [1.0; 128]; + let mut d_input = [0.0; 128]; + let result = test_memcpy(&input); + let result_d = d_test_memcpy(&input, &mut d_input, 1.0); + + assert_eq!(result, result_d); + println!("Memcpy test passed: result = {}", result); +} diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs new file mode 100644 index 00000000000..b4c650330fe --- /dev/null +++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/rmake.rs @@ -0,0 +1,39 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // First, compile to LLVM IR to check for enzyme_type attributes + let _ir_output = rustc() + .input("memcpy.rs") + .arg("-Zautodiff=Enable") + .arg("-Zautodiff=NoPostopt") + .opt_level("0") + .arg("--emit=llvm-ir") + .arg("-o") + .arg("main.ll") + .run(); + + // Then compile with TypeTree analysis output for the existing checks + let output = rustc() + .input("memcpy.rs") + .arg("-Zautodiff=Enable,PrintTAFn=test_memcpy") + .arg("-Zautodiff=NoPostopt") + .opt_level("3") + .arg("-Clto=fat") + .arg("-g") + .run(); + + let stdout = output.stdout_utf8(); + let stderr = output.stderr_utf8(); + let ir_content = rfs::read_to_string("main.ll"); + + rfs::write("memcpy.stdout", &stdout); + rfs::write("memcpy.stderr", &stderr); + rfs::write("main.ir", &ir_content); + + llvm_filecheck().patterns("memcpy.check").stdin_buf(stdout).run(); + + llvm_filecheck().patterns("memcpy-ir.check").stdin_buf(ir_content).run(); +} diff --git a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs index bab863ca9ff..de540b990ca 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs +++ b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs @@ -23,14 +23,8 @@ fn main() { .run(); // Verify NoTT version does NOT have enzyme_type attributes - llvm_filecheck() - .patterns("nott.check") - .stdin_buf(rfs::read("nott.ll")) - .run(); - + llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.ll")).run(); + // Verify TypeTree version DOES have enzyme_type attributes - llvm_filecheck() - .patterns("with_tt.check") - .stdin_buf(rfs::read("with_tt.ll")) - .run(); -} \ No newline at end of file + llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/nott-flag/test.rs b/tests/run-make/autodiff/type-trees/nott-flag/test.rs index 5c634eea035..de3549c37c6 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/test.rs +++ b/tests/run-make/autodiff/type-trees/nott-flag/test.rs @@ -12,4 +12,4 @@ fn main() { let x = 2.0; let mut dx = 0.0; let _result = d_square(&x, &mut dx, 1.0); -} \ No newline at end of file +} diff --git a/tests/ui/autodiff/flag_nott.rs b/tests/ui/autodiff/flag_nott.rs index 7a97d892cd8..faa9949fe81 100644 --- a/tests/ui/autodiff/flag_nott.rs +++ b/tests/ui/autodiff/flag_nott.rs @@ -16,4 +16,4 @@ fn main() { let x = 2.0; let mut dx = 0.0; let result = d_square(&x, &mut dx, 1.0); -} \ No newline at end of file +} -- cgit 1.4.1-3-g733a5 From 31541feb6f7e46c23141bdeb3e35ecd305bf8762 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Mon, 1 Sep 2025 05:05:49 +0000 Subject: autodiff: add TypeTree support for arrays --- compiler/rustc_middle/src/ty/mod.rs | 42 +++++++++++++++++++++- .../autodiff/type-trees/array-typetree/array.check | 4 +++ .../autodiff/type-trees/array-typetree/rmake.rs | 9 +++++ .../autodiff/type-trees/array-typetree/test.rs | 15 ++++++++ 4 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 tests/run-make/autodiff/type-trees/array-typetree/array.check create mode 100644 tests/run-make/autodiff/type-trees/array-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/array-typetree/test.rs (limited to 'tests') diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 741b5d7fd4e..02a4e4e2b15 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2286,6 +2286,46 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { }]); } - // FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types + if ty.is_array() { + if let ty::Array(element_ty, len_const) = ty.kind() { + let len = len_const.try_to_target_usize(tcx).unwrap_or(0); + if len == 0 { + return TypeTree::new(); + } + + let element_tree = typetree_from_ty(tcx, *element_ty); + + let element_layout = tcx + .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty)) + .ok() + .map(|layout| layout.size.bytes_usize()) + .unwrap_or(0); + + if element_layout == 0 { + return TypeTree::new(); + } + + let mut types = Vec::new(); + for i in 0..len { + let base_offset = (i as usize * element_layout) as isize; + + for elem_type in &element_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + base_offset + } else { + base_offset + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + } + + return TypeTree(types); + } + } + TypeTree::new() } diff --git a/tests/run-make/autodiff/type-trees/array-typetree/array.check b/tests/run-make/autodiff/type-trees/array-typetree/array.check new file mode 100644 index 00000000000..7513458b8ab --- /dev/null +++ b/tests/run-make/autodiff/type-trees/array-typetree/array.check @@ -0,0 +1,4 @@ +; Check that array TypeTree metadata is correctly generated +; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes) + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs new file mode 100644 index 00000000000..20b6a066906 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/array-typetree/test.rs b/tests/run-make/autodiff/type-trees/array-typetree/test.rs new file mode 100644 index 00000000000..f54ebf5a4c7 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/array-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_array(arr: &[f64; 5]) -> f64 { + arr[0] + arr[1] + arr[2] + arr[3] + arr[4] +} + +fn main() { + let arr = [1.0, 2.0, 3.0, 4.0, 5.0]; + let mut d_arr = [0.0; 5]; + let _result = d_test(&arr, &mut d_arr, 1.0); +} -- cgit 1.4.1-3-g733a5 From be3617b04031c7d4f227ce54c4ce6ceeae00981c Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Mon, 1 Sep 2025 06:17:34 +0000 Subject: autodiff: slice support in typetree --- compiler/rustc_middle/src/ty/mod.rs | 7 +++++++ .../run-make/autodiff/type-trees/slice-typetree/rmake.rs | 9 +++++++++ .../autodiff/type-trees/slice-typetree/slice.check | 4 ++++ .../run-make/autodiff/type-trees/slice-typetree/test.rs | 16 ++++++++++++++++ 4 files changed, 36 insertions(+) create mode 100644 tests/run-make/autodiff/type-trees/slice-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/slice-typetree/slice.check create mode 100644 tests/run-make/autodiff/type-trees/slice-typetree/test.rs (limited to 'tests') diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 02a4e4e2b15..82a41f403f8 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2327,5 +2327,12 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { } } + if ty.is_slice() { + if let ty::Slice(element_ty) = ty.kind() { + let element_tree = typetree_from_ty(tcx, *element_ty); + return element_tree; + } + } + TypeTree::new() } diff --git a/tests/run-make/autodiff/type-trees/slice-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/slice-typetree/rmake.rs new file mode 100644 index 00000000000..b81fb50bf1a --- /dev/null +++ b/tests/run-make/autodiff/type-trees/slice-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("slice.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/slice-typetree/slice.check b/tests/run-make/autodiff/type-trees/slice-typetree/slice.check new file mode 100644 index 00000000000..8fc9e9c4f1b --- /dev/null +++ b/tests/run-make/autodiff/type-trees/slice-typetree/slice.check @@ -0,0 +1,4 @@ +; Check that slice TypeTree metadata is correctly generated +; Should show Float@double for slice elements + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_slice{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/slice-typetree/test.rs b/tests/run-make/autodiff/type-trees/slice-typetree/test.rs new file mode 100644 index 00000000000..7117fa3844f --- /dev/null +++ b/tests/run-make/autodiff/type-trees/slice-typetree/test.rs @@ -0,0 +1,16 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_slice(slice: &[f64]) -> f64 { + slice.iter().sum() +} + +fn main() { + let arr = [1.0, 2.0, 3.0, 4.0, 5.0]; + let slice = &arr[..]; + let mut d_slice = [0.0; 5]; + let _result = d_test(slice, &mut d_slice[..], 1.0); +} -- cgit 1.4.1-3-g733a5 From 7c5fbfbdbbb389462e0ffb936ba9b16cffbce6ed Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Mon, 1 Sep 2025 16:14:31 +0000 Subject: autodiff: tuple support in typetree --- compiler/rustc_middle/src/ty/mod.rs | 36 ++++++++++++++++++++++ .../autodiff/type-trees/tuple-typetree/rmake.rs | 9 ++++++ .../autodiff/type-trees/tuple-typetree/test.rs | 15 +++++++++ .../autodiff/type-trees/tuple-typetree/tuple.check | 4 +++ 4 files changed, 64 insertions(+) create mode 100644 tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/tuple-typetree/test.rs create mode 100644 tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check (limited to 'tests') diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 82a41f403f8..c0a091551c9 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2334,5 +2334,41 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { } } + if let ty::Tuple(tuple_types) = ty.kind() { + if tuple_types.is_empty() { + return TypeTree::new(); + } + + let mut types = Vec::new(); + let mut current_offset = 0; + + for tuple_ty in tuple_types.iter() { + let element_tree = typetree_from_ty(tcx, tuple_ty); + + let element_layout = tcx + .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty)) + .ok() + .map(|layout| layout.size.bytes_usize()) + .unwrap_or(0); + + for elem_type in &element_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + current_offset as isize + } else { + current_offset as isize + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + + current_offset += element_layout; + } + + return TypeTree(types); + } + TypeTree::new() } diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs new file mode 100644 index 00000000000..76913828901 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("tuple.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/test.rs b/tests/run-make/autodiff/type-trees/tuple-typetree/test.rs new file mode 100644 index 00000000000..32187b587a3 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/tuple-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_tuple(tuple: &(f64, f64, f64)) -> f64 { + tuple.0 + tuple.1 * 2.0 + tuple.2 * 3.0 +} + +fn main() { + let tuple = (1.0, 2.0, 3.0); + let mut d_tuple = (0.0, 0.0, 0.0); + let _result = d_test(&tuple, &mut d_tuple, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check b/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check new file mode 100644 index 00000000000..50aa25e96ae --- /dev/null +++ b/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check @@ -0,0 +1,4 @@ +; Check that tuple TypeTree metadata is correctly generated +; Should show Float@double at offsets 0, 8, 16 for (f64, f64, f64) + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file -- cgit 1.4.1-3-g733a5 From 574f0b97d6f30cd6cedb165fde13cdec176611b8 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Mon, 1 Sep 2025 16:28:14 +0000 Subject: autodiff: struct support in typetree --- compiler/rustc_middle/src/ty/mod.rs | 32 ++++++++++++++++++++++ .../autodiff/type-trees/struct-typetree/rmake.rs | 9 ++++++ .../type-trees/struct-typetree/struct.check | 4 +++ .../autodiff/type-trees/struct-typetree/test.rs | 22 +++++++++++++++ 4 files changed, 67 insertions(+) create mode 100644 tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/struct-typetree/struct.check create mode 100644 tests/run-make/autodiff/type-trees/struct-typetree/test.rs (limited to 'tests') diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index c0a091551c9..703d417f96d 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2370,5 +2370,37 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { return TypeTree(types); } + if let ty::Adt(adt_def, args) = ty.kind() { + if adt_def.is_struct() { + let struct_layout = + tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty)); + if let Ok(layout) = struct_layout { + let mut types = Vec::new(); + + for (field_idx, field_def) in adt_def.all_fields().enumerate() { + let field_ty = field_def.ty(tcx, args); + let field_tree = typetree_from_ty(tcx, field_ty); + + let field_offset = layout.fields.offset(field_idx).bytes_usize(); + + for elem_type in &field_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + field_offset as isize + } else { + field_offset as isize + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + } + + return TypeTree(types); + } + } + } + TypeTree::new() } diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs new file mode 100644 index 00000000000..0af1b65ee18 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("struct.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/struct.check b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check new file mode 100644 index 00000000000..2f763f18c1c --- /dev/null +++ b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check @@ -0,0 +1,4 @@ +; Check that struct TypeTree metadata is correctly generated +; Should show Float@double at offsets 0, 8, 16 for Point struct fields + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/test.rs b/tests/run-make/autodiff/type-trees/struct-typetree/test.rs new file mode 100644 index 00000000000..cbe7b10e409 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/struct-typetree/test.rs @@ -0,0 +1,22 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[repr(C)] +struct Point { + x: f64, + y: f64, + z: f64, +} + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_struct(point: &Point) -> f64 { + point.x + point.y * 2.0 + point.z * 3.0 +} + +fn main() { + let point = Point { x: 1.0, y: 2.0, z: 3.0 }; + let mut d_point = Point { x: 0.0, y: 0.0, z: 0.0 }; + let _result = d_test(&point, &mut d_point, 1.0); +} -- cgit 1.4.1-3-g733a5 From 4f3f0f48e7b1e61818b2bcbe4451f89bb4f47049 Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Thu, 4 Sep 2025 11:17:34 +0000 Subject: autodiff: fixed test to be more precise for type tree checking --- compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 23 +++++++ compiler/rustc_codegen_llvm/src/typetree.rs | 70 +++++++++------------- compiler/rustc_middle/src/ty/mod.rs | 36 +++-------- .../autodiff/type-trees/array-typetree/array.check | 2 +- .../type-trees/memcpy-typetree/memcpy-ir.check | 2 +- .../type-trees/mixed-struct-typetree/mixed.check | 2 + .../type-trees/mixed-struct-typetree/rmake.rs | 16 +++++ .../type-trees/mixed-struct-typetree/test.rs | 23 +++++++ .../autodiff/type-trees/nott-flag/with_tt.check | 2 +- .../scalar-types/f128-typetree/f128.check | 4 +- .../type-trees/scalar-types/f16-typetree/f16.check | 4 +- .../type-trees/scalar-types/f32-typetree/f32.check | 2 +- .../type-trees/scalar-types/f64-typetree/f64.check | 2 +- .../type-trees/scalar-types/i32-typetree/i32.check | 4 +- .../type-trees/scalar-types/i32-typetree/test.rs | 7 ++- .../autodiff/type-trees/slice-typetree/slice.check | 2 +- .../type-trees/struct-typetree/struct.check | 2 +- .../autodiff/type-trees/tuple-typetree/tuple.check | 2 +- .../type-trees/type-analysis/vec/vec.check | 2 +- 19 files changed, 120 insertions(+), 87 deletions(-) create mode 100644 tests/run-make/autodiff/type-trees/mixed-struct-typetree/mixed.check create mode 100644 tests/run-make/autodiff/type-trees/mixed-struct-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/mixed-struct-typetree/test.rs (limited to 'tests') diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 1596dc48379..e63043b2122 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -118,6 +118,13 @@ pub(crate) mod Enzyme_AD { max_size: i64, add_offset: u64, ); + pub(crate) fn EnzymeTypeTreeInsertEq( + CTT: CTypeTreeRef, + indices: *const i64, + len: usize, + ct: CConcreteType, + ctx: &Context, + ); pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); } @@ -234,6 +241,16 @@ pub(crate) mod Fallback_AD { unimplemented!() } + pub(crate) unsafe fn EnzymeTypeTreeInsertEq( + CTT: CTypeTreeRef, + indices: *const i64, + len: usize, + ct: CConcreteType, + ctx: &Context, + ) { + unimplemented!() + } + pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char { unimplemented!() } @@ -312,6 +329,12 @@ impl TypeTree { self } + + pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) { + unsafe { + EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx); + } + } } impl Clone for TypeTree { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 8c0d255bba8..ae6a2da62b5 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -8,22 +8,24 @@ use { use crate::llvm::{self, Value}; -/// Converts a Rust TypeTree to Enzyme's internal TypeTree format -/// -/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree) -/// and converts it to Enzyme's internal C++ TypeTree representation that -/// Enzyme can understand during differentiation analysis. #[cfg(llvm_enzyme)] fn to_enzyme_typetree( rust_typetree: RustTypeTree, - data_layout: &str, + _data_layout: &str, llcx: &llvm::Context, ) -> llvm::TypeTree { - // Start with an empty TypeTree let mut enzyme_tt = llvm::TypeTree::new(); - - // Convert each Type in the Rust TypeTree to Enzyme format - for rust_type in rust_typetree.0 { + process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx); + enzyme_tt +} +#[cfg(llvm_enzyme)] +fn process_typetree_recursive( + enzyme_tt: &mut llvm::TypeTree, + rust_typetree: &RustTypeTree, + parent_indices: &[i64], + llcx: &llvm::Context, +) { + for rust_type in &rust_typetree.0 { let concrete_type = match rust_type.kind { rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, @@ -35,25 +37,27 @@ fn to_enzyme_typetree( rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown, }; - // Create a TypeTree for this specific type - let type_tt = llvm::TypeTree::from_type(concrete_type, llcx); - - // Apply offset if specified - let type_tt = if rust_type.offset == -1 { - type_tt // -1 means everywhere/no specific offset + let mut indices = parent_indices.to_vec(); + if !parent_indices.is_empty() { + if rust_type.offset == -1 { + indices.push(-1); + } else { + indices.push(rust_type.offset as i64); + } + } else if rust_type.offset == -1 { + indices.push(-1); } else { - // Apply specific offset positioning - type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0) - }; + indices.push(rust_type.offset as i64); + } - // Merge this type into the main TypeTree - enzyme_tt = enzyme_tt.merge(type_tt); - } + enzyme_tt.insert(&indices, concrete_type, llcx); - enzyme_tt + if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() { + process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx); + } + } } -// Attaches TypeTree information to LLVM function as enzyme_type attributes. #[cfg(llvm_enzyme)] pub(crate) fn add_tt<'ll>( llmod: &'ll llvm::Module, @@ -64,28 +68,20 @@ pub(crate) fn add_tt<'ll>( let inputs = tt.args; let ret_tt: RustTypeTree = tt.ret; - // Get LLVM data layout string for TypeTree conversion let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; let llvm_data_layout = std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) .expect("got a non-UTF8 data-layout from LLVM"); - // Attribute name that Enzyme recognizes for TypeTree information let attr_name = "enzyme_type"; let c_attr_name = CString::new(attr_name).unwrap(); - // Attach TypeTree attributes to each input parameter - // Enzyme uses these to understand parameter memory layouts during differentiation for (i, input) in inputs.iter().enumerate() { unsafe { - // Convert Rust TypeTree to Enzyme's internal format let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); - - // Serialize TypeTree to string format that Enzyme can parse let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); let c_str = std::ffi::CStr::from_ptr(c_str); - // Create LLVM string attribute with TypeTree information let attr = llvm::LLVMCreateStringAttribute( llcx, c_attr_name.as_ptr(), @@ -94,17 +90,11 @@ pub(crate) fn add_tt<'ll>( c_str.to_bytes().len() as c_uint, ); - // Attach attribute to the specific function parameter - // Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); - - // Free the C string to prevent memory leaks llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); } } - // Attach TypeTree attribute to the return type - // Enzyme needs this to understand how to handle return value derivatives unsafe { let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); @@ -118,15 +108,11 @@ pub(crate) fn add_tt<'ll>( c_str.to_bytes().len() as c_uint, ); - // Attach to function return type attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); - - // Free the C string llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); } } -// Fallback implementation when Enzyme is not available #[cfg(not(llvm_enzyme))] pub(crate) fn add_tt<'ll>( _llmod: &'ll llvm::Module, diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 703d417f96d..581f20e8492 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2261,10 +2261,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { x if x == tcx.types.f32 => (Kind::Float, 4), x if x == tcx.types.f64 => (Kind::Double, 8), x if x == tcx.types.f128 => (Kind::F128, 16), - _ => return TypeTree::new(), + _ => (Kind::Integer, 0), } } else { - return TypeTree::new(); + (Kind::Integer, 0) }; return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]); @@ -2295,32 +2295,14 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { let element_tree = typetree_from_ty(tcx, *element_ty); - let element_layout = tcx - .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty)) - .ok() - .map(|layout| layout.size.bytes_usize()) - .unwrap_or(0); - - if element_layout == 0 { - return TypeTree::new(); - } - let mut types = Vec::new(); - for i in 0..len { - let base_offset = (i as usize * element_layout) as isize; - - for elem_type in &element_tree.0 { - types.push(Type { - offset: if elem_type.offset == -1 { - base_offset - } else { - base_offset + elem_type.offset - }, - size: elem_type.size, - kind: elem_type.kind, - child: elem_type.child.clone(), - }); - } + for elem_type in &element_tree.0 { + types.push(Type { + offset: -1, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); } return TypeTree(types); diff --git a/tests/run-make/autodiff/type-trees/array-typetree/array.check b/tests/run-make/autodiff/type-trees/array-typetree/array.check index 7513458b8ab..0d38bdec17e 100644 --- a/tests/run-make/autodiff/type-trees/array-typetree/array.check +++ b/tests/run-make/autodiff/type-trees/array-typetree/array.check @@ -1,4 +1,4 @@ ; Check that array TypeTree metadata is correctly generated ; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes) -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check index 3a59f06dda1..0e6351ac4d3 100644 --- a/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check +++ b/tests/run-make/autodiff/type-trees/memcpy-typetree/memcpy-ir.check @@ -1,7 +1,7 @@ ; Check that enzyme_type attributes are present in the LLVM IR function definition ; This verifies our TypeTree system correctly attaches metadata for Enzyme -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}" +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" ; Check that llvm.memcpy exists (either call or declare) CHECK: {{(call|declare).*}}@llvm.memcpy diff --git a/tests/run-make/autodiff/type-trees/mixed-struct-typetree/mixed.check b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/mixed.check new file mode 100644 index 00000000000..584f5840843 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/mixed.check @@ -0,0 +1,2 @@ +; Check that mixed struct with large array generates correct detailed type tree +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_mixed_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Float@float}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/mixed-struct-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/rmake.rs new file mode 100644 index 00000000000..1c19963bc36 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/rmake.rs @@ -0,0 +1,16 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc() + .input("test.rs") + .arg("-Zautodiff=Enable") + .arg("-Zautodiff=NoPostopt") + .opt_level("0") + .emit("llvm-ir") + .run(); + + llvm_filecheck().patterns("mixed.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/mixed-struct-typetree/test.rs b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/test.rs new file mode 100644 index 00000000000..7a734980e61 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/mixed-struct-typetree/test.rs @@ -0,0 +1,23 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[repr(C)] +struct Container { + header: i64, + data: [f32; 1000], +} + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +#[inline(never)] +fn test_mixed_struct(container: &Container) -> f32 { + container.data[0] + container.data[999] +} + +fn main() { + let container = Container { header: 42, data: [1.0; 1000] }; + let mut d_container = Container { header: 0, data: [0.0; 1000] }; + let result = d_test(&container, &mut d_container, 1.0); + std::hint::black_box(result); +} diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check index 53eafc10a65..3c02003c882 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check +++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check @@ -1,4 +1,4 @@ // Check that enzyme_type attributes are present when TypeTree is enabled // This verifies our TypeTree metadata attachment is working -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check index 31cef113420..733e46aa45a 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check @@ -1,4 +1,4 @@ ; Check that f128 TypeTree metadata is correctly generated -; f128 maps to Unknown in our current implementation since CConcreteType doesn't have DT_F128 +; Should show Float@fp128 for f128 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f128{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check index 97a3964a569..9caca26b5cf 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check @@ -1,4 +1,4 @@ ; Check that f16 TypeTree metadata is correctly generated -; Should show Half for f16 values and Pointer for references +; Should show Float@half for f16 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"={{.*}}@test_f16{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@half}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check index b32d5086d07..ec12ba6b234 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check @@ -1,4 +1,4 @@ ; Check that f32 TypeTree metadata is correctly generated ; Should show Float@float for f32 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check index 0e9d9c03a0f..f1af270824d 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check @@ -1,4 +1,4 @@ ; Check that f64 TypeTree metadata is correctly generated ; Should show Float@double for f64 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check index aa8e5223e8e..fd9a94be810 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check @@ -1,4 +1,4 @@ ; Check that i32 TypeTree metadata is correctly generated -; Should show Integer for i32 values (integers are typically Const in autodiff) +; Should show Integer for i32 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[]:Integer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs index 530c137f62b..249803c5d9f 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/test.rs @@ -2,13 +2,14 @@ use std::autodiff::autodiff_reverse; -#[autodiff_reverse(d_test, Const, Active)] +#[autodiff_reverse(d_test, Duplicated, Active)] #[no_mangle] -fn test_i32(x: i32) -> i32 { +fn test_i32(x: &i32) -> i32 { x * x } fn main() { let x = 5_i32; - let _result = d_test(x, 1); + let mut dx = 0_i32; + let _result = d_test(&x, &mut dx, 1); } diff --git a/tests/run-make/autodiff/type-trees/slice-typetree/slice.check b/tests/run-make/autodiff/type-trees/slice-typetree/slice.check index 8fc9e9c4f1b..6543b616115 100644 --- a/tests/run-make/autodiff/type-trees/slice-typetree/slice.check +++ b/tests/run-make/autodiff/type-trees/slice-typetree/slice.check @@ -1,4 +1,4 @@ ; Check that slice TypeTree metadata is correctly generated ; Should show Float@double for slice elements -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_slice{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_slice{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/struct.check b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check index 2f763f18c1c..54956317e1e 100644 --- a/tests/run-make/autodiff/type-trees/struct-typetree/struct.check +++ b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check @@ -1,4 +1,4 @@ ; Check that struct TypeTree metadata is correctly generated ; Should show Float@double at offsets 0, 8, 16 for Point struct fields -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check b/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check index 50aa25e96ae..47647e78cc3 100644 --- a/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check +++ b/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check @@ -1,4 +1,4 @@ ; Check that tuple TypeTree metadata is correctly generated ; Should show Float@double at offsets 0, 8, 16 for (f64, f64, f64) -CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/type-analysis/vec/vec.check b/tests/run-make/autodiff/type-trees/type-analysis/vec/vec.check index dcf9508b69d..cdb70eb83fc 100644 --- a/tests/run-make/autodiff/type-trees/type-analysis/vec/vec.check +++ b/tests/run-make/autodiff/type-trees/type-analysis/vec/vec.check @@ -1,7 +1,7 @@ // CHECK: callee - {[-1]:Float@float} |{[-1]:Pointer}:{} // CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer} // CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 8, !dbg !{{[0-9]+}}: {[-1]:Pointer} -// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !102, !noundef !{{[0-9]+}}: {} +// CHECK-DAG: %{{[0-9]+}} = load ptr, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !nonnull !{{[0-9]+}}, !noundef !{{[0-9]+}}: {} // CHECK-DAG: %{{[0-9]+}} = getelementptr inbounds nuw i8, ptr %{{[0-9]+}}, i64 16, !dbg !{{[0-9]+}}: {[-1]:Pointer} // CHECK-DAG: %{{[0-9]+}} = load i64, ptr %{{[0-9]+}}, align 8, !dbg !{{[0-9]+}}, !noundef !{{[0-9]+}}: {} // CHECK-DAG: %{{[0-9]+}} = icmp eq i64 %{{[0-9]+}}, 0, !dbg !{{[0-9]+}}: {[-1]:Integer} -- cgit 1.4.1-3-g733a5 From 4520926bb527bd43edbf0de84c2b0c6a9c5fc5ce Mon Sep 17 00:00:00 2001 From: Karan Janthe Date: Thu, 11 Sep 2025 07:30:35 +0000 Subject: autodiff: recurion added for typetree --- compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 1 + compiler/rustc_codegen_llvm/src/typetree.rs | 10 +-- compiler/rustc_middle/src/ty/mod.rs | 76 ++++++++++++++-- .../autodiff/type-trees/nott-flag/with_tt.check | 2 +- .../type-trees/recursion-typetree/recursion.check | 3 + .../type-trees/recursion-typetree/rmake.rs | 9 ++ .../autodiff/type-trees/recursion-typetree/test.rs | 100 +++++++++++++++++++++ .../scalar-types/f128-typetree/f128.check | 2 +- .../type-trees/scalar-types/f16-typetree/f16.check | 2 +- .../type-trees/scalar-types/f32-typetree/f32.check | 2 +- .../type-trees/scalar-types/f64-typetree/f64.check | 2 +- .../type-trees/scalar-types/i32-typetree/i32.check | 2 +- 12 files changed, 191 insertions(+), 20 deletions(-) create mode 100644 tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check create mode 100644 tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs create mode 100644 tests/run-make/autodiff/type-trees/recursion-typetree/test.rs (limited to 'tests') diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index e63043b2122..b604f5139c8 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -127,6 +127,7 @@ pub(crate) mod Enzyme_AD { ); pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + pub(crate) fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint; } unsafe extern "C" { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index ae6a2da62b5..1a54884f6c5 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -39,11 +39,7 @@ fn process_typetree_recursive( let mut indices = parent_indices.to_vec(); if !parent_indices.is_empty() { - if rust_type.offset == -1 { - indices.push(-1); - } else { - indices.push(rust_type.offset as i64); - } + indices.push(rust_type.offset as i64); } else if rust_type.offset == -1 { indices.push(-1); } else { @@ -52,7 +48,9 @@ fn process_typetree_recursive( enzyme_tt.insert(&indices, concrete_type, llcx); - if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer && !rust_type.child.0.is_empty() { + if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer + && !rust_type.child.0.is_empty() + { process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx); } } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 581f20e8492..7ca2355947a 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2252,6 +2252,61 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { /// Generate TypeTree for a specific type. /// This function analyzes a Rust type and creates appropriate TypeTree metadata. pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + let mut visited = Vec::new(); + typetree_from_ty_inner(tcx, ty, 0, &mut visited) +} + +/// Internal recursive function for TypeTree generation with cycle detection and depth limiting. +fn typetree_from_ty_inner<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, +) -> TypeTree { + #[cfg(llvm_enzyme)] + { + unsafe extern "C" { + fn EnzymeGetMaxTypeDepth() -> ::std::os::raw::c_uint; + } + let max_depth = unsafe { EnzymeGetMaxTypeDepth() } as usize; + if depth > max_depth { + return TypeTree::new(); + } + } + + #[cfg(not(llvm_enzyme))] + if depth > 6 { + return TypeTree::new(); + } + + if visited.contains(&ty) { + return TypeTree::new(); + } + + visited.push(ty); + let result = typetree_from_ty_impl(tcx, ty, depth, visited); + visited.pop(); + result +} + +/// Implementation of TypeTree generation logic. +fn typetree_from_ty_impl<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, +) -> TypeTree { + typetree_from_ty_impl_inner(tcx, ty, depth, visited, false) +} + +/// Internal implementation with context about whether this is for a reference target. +fn typetree_from_ty_impl_inner<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + depth: usize, + visited: &mut Vec>, + is_reference_target: bool, +) -> TypeTree { if ty.is_scalar() { let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) @@ -2267,7 +2322,10 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { (Kind::Integer, 0) }; - return TypeTree(vec![Type { offset: -1, size, kind, child: TypeTree::new() }]); + // Use offset 0 for scalars that are direct targets of references (like &f64) + // Use offset -1 for scalars used directly (like function return types) + let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 }; + return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); } if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { @@ -2277,7 +2335,7 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { return TypeTree::new(); }; - let child = typetree_from_ty(tcx, inner_ty); + let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true); return TypeTree(vec![Type { offset: -1, size: tcx.data_layout.pointer_size().bytes_usize(), @@ -2292,9 +2350,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { if len == 0 { return TypeTree::new(); } - - let element_tree = typetree_from_ty(tcx, *element_ty); - + let element_tree = + typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); let mut types = Vec::new(); for elem_type in &element_tree.0 { types.push(Type { @@ -2311,7 +2368,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { if ty.is_slice() { if let ty::Slice(element_ty) = ty.kind() { - let element_tree = typetree_from_ty(tcx, *element_ty); + let element_tree = + typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); return element_tree; } } @@ -2325,7 +2383,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { let mut current_offset = 0; for tuple_ty in tuple_types.iter() { - let element_tree = typetree_from_ty(tcx, tuple_ty); + let element_tree = + typetree_from_ty_impl_inner(tcx, tuple_ty, depth + 1, visited, false); let element_layout = tcx .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty)) @@ -2361,7 +2420,8 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { for (field_idx, field_def) in adt_def.all_fields().enumerate() { let field_ty = field_def.ty(tcx, args); - let field_tree = typetree_from_ty(tcx, field_ty); + let field_tree = + typetree_from_ty_impl_inner(tcx, field_ty, depth + 1, visited, false); let field_offset = layout.fields.offset(field_idx).bytes_usize(); diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check index 3c02003c882..0b4c9119179 100644 --- a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check +++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check @@ -1,4 +1,4 @@ // Check that enzyme_type attributes are present when TypeTree is enabled // This verifies our TypeTree metadata attachment is working -CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@square{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check b/tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check new file mode 100644 index 00000000000..1960e7b816c --- /dev/null +++ b/tests/run-make/autodiff/type-trees/recursion-typetree/recursion.check @@ -0,0 +1,3 @@ +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_deep{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}" +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_graph{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,8]:Integer, [-1,16]:Integer, [-1,24]:Float@double}" +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_node{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs new file mode 100644 index 00000000000..78718f3a215 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/recursion-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("recursion.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/recursion-typetree/test.rs b/tests/run-make/autodiff/type-trees/recursion-typetree/test.rs new file mode 100644 index 00000000000..9d40bec1bf1 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/recursion-typetree/test.rs @@ -0,0 +1,100 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +// Self-referential struct to test recursion detection +#[derive(Clone)] +struct Node { + value: f64, + next: Option>, +} + +// Mutually recursive structs to test cycle detection +#[derive(Clone)] +struct GraphNodeA { + value: f64, + connections: Vec, +} + +#[derive(Clone)] +struct GraphNodeB { + weight: f64, + target: Option>, +} + +#[autodiff_reverse(d_test_node, Duplicated, Active)] +#[no_mangle] +fn test_node(node: &Node) -> f64 { + node.value * 2.0 +} + +#[autodiff_reverse(d_test_graph, Duplicated, Active)] +#[no_mangle] +fn test_graph(a: &GraphNodeA) -> f64 { + a.value * 3.0 +} + +// Simple depth test - deeply nested but not circular +#[derive(Clone)] +struct Level1 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level2 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level3 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level4 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level5 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level6 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level7 { + val: f64, + next: Option>, +} +#[derive(Clone)] +struct Level8 { + val: f64, +} + +#[autodiff_reverse(d_test_deep, Duplicated, Active)] +#[no_mangle] +fn test_deep(deep: &Level1) -> f64 { + deep.val * 4.0 +} + +fn main() { + let node = Node { value: 1.0, next: None }; + + let graph = GraphNodeA { value: 2.0, connections: vec![] }; + + let deep = Level1 { val: 5.0, next: None }; + + let mut d_node = Node { value: 0.0, next: None }; + + let mut d_graph = GraphNodeA { value: 0.0, connections: vec![] }; + + let mut d_deep = Level1 { val: 0.0, next: None }; + + let _result1 = d_test_node(&node, &mut d_node, 1.0); + let _result2 = d_test_graph(&graph, &mut d_graph, 1.0); + let _result3 = d_test_deep(&deep, &mut d_deep, 1.0); +} diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check index 733e46aa45a..23db64eea52 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f128-typetree/f128.check @@ -1,4 +1,4 @@ ; Check that f128 TypeTree metadata is correctly generated ; Should show Float@fp128 for f128 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@fp128}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@fp128}"{{.*}}@test_f128{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@fp128}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check index 9caca26b5cf..9adff68d36f 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f16-typetree/f16.check @@ -1,4 +1,4 @@ ; Check that f16 TypeTree metadata is correctly generated ; Should show Float@half for f16 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@half}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@half}"{{.*}}@test_f16{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@half}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check index ec12ba6b234..176630f57e8 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f32-typetree/f32.check @@ -1,4 +1,4 @@ ; Check that f32 TypeTree metadata is correctly generated ; Should show Float@float for f32 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@float}"{{.*}}@test_f32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check index f1af270824d..929cd379694 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/f64-typetree/f64.check @@ -1,4 +1,4 @@ ; Check that f64 TypeTree metadata is correctly generated ; Should show Float@double for f64 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Float@double}"{{.*}}@test_f64{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check index fd9a94be810..dee4aa5bbb6 100644 --- a/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check +++ b/tests/run-make/autodiff/type-trees/scalar-types/i32-typetree/i32.check @@ -1,4 +1,4 @@ ; Check that i32 TypeTree metadata is correctly generated ; Should show Integer for i32 values and Pointer for references -CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}" \ No newline at end of file +CHECK: define{{.*}}"enzyme_type"="{[-1]:Integer}"{{.*}}@test_i32{{.*}}"enzyme_type"="{[-1]:Pointer, [-1,0]:Integer}" \ No newline at end of file -- cgit 1.4.1-3-g733a5