diff options
| author | Karan Janthe <karanjanthe@gmail.com> | 2025-09-01 05:05:49 +0000 |
|---|---|---|
| committer | Karan Janthe <karanjanthe@gmail.com> | 2025-09-19 04:11:35 +0000 |
| commit | 31541feb6f7e46c23141bdeb3e35ecd305bf8762 (patch) | |
| tree | 60836416827be070e55eaf27b9d28a52edf03165 | |
| parent | 54f9376660707d4ca9fce51fd423658f75128ac4 (diff) | |
| download | rust-31541feb6f7e46c23141bdeb3e35ecd305bf8762.tar.gz rust-31541feb6f7e46c23141bdeb3e35ecd305bf8762.zip | |
autodiff: add TypeTree support for arrays
4 files changed, 69 insertions, 1 deletions
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 741b5d7fd4e..02a4e4e2b15 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2286,6 +2286,46 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { }]); } - // FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types + if ty.is_array() { + if let ty::Array(element_ty, len_const) = ty.kind() { + let len = len_const.try_to_target_usize(tcx).unwrap_or(0); + if len == 0 { + return TypeTree::new(); + } + + let element_tree = typetree_from_ty(tcx, *element_ty); + + let element_layout = tcx + .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty)) + .ok() + .map(|layout| layout.size.bytes_usize()) + .unwrap_or(0); + + if element_layout == 0 { + return TypeTree::new(); + } + + let mut types = Vec::new(); + for i in 0..len { + let base_offset = (i as usize * element_layout) as isize; + + for elem_type in &element_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + base_offset + } else { + base_offset + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + } + + return TypeTree(types); + } + } + TypeTree::new() } diff --git a/tests/run-make/autodiff/type-trees/array-typetree/array.check b/tests/run-make/autodiff/type-trees/array-typetree/array.check new file mode 100644 index 00000000000..7513458b8ab --- /dev/null +++ b/tests/run-make/autodiff/type-trees/array-typetree/array.check @@ -0,0 +1,4 @@ +; Check that array TypeTree metadata is correctly generated +; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes) + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs new file mode 100644 index 00000000000..20b6a066906 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/array-typetree/test.rs b/tests/run-make/autodiff/type-trees/array-typetree/test.rs new file mode 100644 index 00000000000..f54ebf5a4c7 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/array-typetree/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_array(arr: &[f64; 5]) -> f64 { + arr[0] + arr[1] + arr[2] + arr[3] + arr[4] +} + +fn main() { + let arr = [1.0, 2.0, 3.0, 4.0, 5.0]; + let mut d_arr = [0.0; 5]; + let _result = d_test(&arr, &mut d_arr, 1.0); +} |
