about summary refs log tree commit diff
diff options
context:
space:
mode:
authorKaran Janthe <karanjanthe@gmail.com>2025-09-01 05:05:49 +0000
committerKaran Janthe <karanjanthe@gmail.com>2025-09-19 04:11:35 +0000
commit31541feb6f7e46c23141bdeb3e35ecd305bf8762 (patch)
tree60836416827be070e55eaf27b9d28a52edf03165
parent54f9376660707d4ca9fce51fd423658f75128ac4 (diff)
downloadrust-31541feb6f7e46c23141bdeb3e35ecd305bf8762.tar.gz
rust-31541feb6f7e46c23141bdeb3e35ecd305bf8762.zip
autodiff: add TypeTree support for arrays
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs42
-rw-r--r--tests/run-make/autodiff/type-trees/array-typetree/array.check4
-rw-r--r--tests/run-make/autodiff/type-trees/array-typetree/rmake.rs9
-rw-r--r--tests/run-make/autodiff/type-trees/array-typetree/test.rs15
4 files changed, 69 insertions, 1 deletions
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 741b5d7fd4e..02a4e4e2b15 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -2286,6 +2286,46 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
         }]);
     }
 
-    // FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types
+    if ty.is_array() {
+        if let ty::Array(element_ty, len_const) = ty.kind() {
+            let len = len_const.try_to_target_usize(tcx).unwrap_or(0);
+            if len == 0 {
+                return TypeTree::new();
+            }
+
+            let element_tree = typetree_from_ty(tcx, *element_ty);
+
+            let element_layout = tcx
+                .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty))
+                .ok()
+                .map(|layout| layout.size.bytes_usize())
+                .unwrap_or(0);
+
+            if element_layout == 0 {
+                return TypeTree::new();
+            }
+
+            let mut types = Vec::new();
+            for i in 0..len {
+                let base_offset = (i as usize * element_layout) as isize;
+
+                for elem_type in &element_tree.0 {
+                    types.push(Type {
+                        offset: if elem_type.offset == -1 {
+                            base_offset
+                        } else {
+                            base_offset + elem_type.offset
+                        },
+                        size: elem_type.size,
+                        kind: elem_type.kind,
+                        child: elem_type.child.clone(),
+                    });
+                }
+            }
+
+            return TypeTree(types);
+        }
+    }
+
     TypeTree::new()
 }
diff --git a/tests/run-make/autodiff/type-trees/array-typetree/array.check b/tests/run-make/autodiff/type-trees/array-typetree/array.check
new file mode 100644
index 00000000000..7513458b8ab
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/array-typetree/array.check
@@ -0,0 +1,4 @@
+; Check that array TypeTree metadata is correctly generated  
+; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes)
+
+CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[]:Pointer}"
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs
new file mode 100644
index 00000000000..20b6a066906
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/array-typetree/rmake.rs
@@ -0,0 +1,9 @@
+//@ needs-enzyme
+//@ ignore-cross-compile
+
+use run_make_support::{llvm_filecheck, rfs, rustc};
+
+fn main() {
+    rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
+    llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("test.ll")).run();
+}
diff --git a/tests/run-make/autodiff/type-trees/array-typetree/test.rs b/tests/run-make/autodiff/type-trees/array-typetree/test.rs
new file mode 100644
index 00000000000..f54ebf5a4c7
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/array-typetree/test.rs
@@ -0,0 +1,15 @@
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+#[autodiff_reverse(d_test, Duplicated, Active)]
+#[no_mangle]
+fn test_array(arr: &[f64; 5]) -> f64 {
+    arr[0] + arr[1] + arr[2] + arr[3] + arr[4]
+}
+
+fn main() {
+    let arr = [1.0, 2.0, 3.0, 4.0, 5.0];
+    let mut d_arr = [0.0; 5];
+    let _result = d_test(&arr, &mut d_arr, 1.0);
+}