about summary refs log tree commit diff
diff options
context:
space:
mode:
authorKaran Janthe <karanjanthe@gmail.com>2025-09-01 16:28:14 +0000
committerKaran Janthe <karanjanthe@gmail.com>2025-09-19 04:11:35 +0000
commit574f0b97d6f30cd6cedb165fde13cdec176611b8 (patch)
treee8df77057028818c27791a663ac01b539169e6d5
parent7c5fbfbdbbb389462e0ffb936ba9b16cffbce6ed (diff)
downloadrust-574f0b97d6f30cd6cedb165fde13cdec176611b8.tar.gz
rust-574f0b97d6f30cd6cedb165fde13cdec176611b8.zip
autodiff: struct support in typetree
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs32
-rw-r--r--tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs9
-rw-r--r--tests/run-make/autodiff/type-trees/struct-typetree/struct.check4
-rw-r--r--tests/run-make/autodiff/type-trees/struct-typetree/test.rs22
4 files changed, 67 insertions, 0 deletions
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index c0a091551c9..703d417f96d 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -2370,5 +2370,37 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
         return TypeTree(types);
     }
 
+    if let ty::Adt(adt_def, args) = ty.kind() {
+        if adt_def.is_struct() {
+            let struct_layout =
+                tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty));
+            if let Ok(layout) = struct_layout {
+                let mut types = Vec::new();
+
+                for (field_idx, field_def) in adt_def.all_fields().enumerate() {
+                    let field_ty = field_def.ty(tcx, args);
+                    let field_tree = typetree_from_ty(tcx, field_ty);
+
+                    let field_offset = layout.fields.offset(field_idx).bytes_usize();
+
+                    for elem_type in &field_tree.0 {
+                        types.push(Type {
+                            offset: if elem_type.offset == -1 {
+                                field_offset as isize
+                            } else {
+                                field_offset as isize + elem_type.offset
+                            },
+                            size: elem_type.size,
+                            kind: elem_type.kind,
+                            child: elem_type.child.clone(),
+                        });
+                    }
+                }
+
+                return TypeTree(types);
+            }
+        }
+    }
+
     TypeTree::new()
 }
diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs
new file mode 100644
index 00000000000..0af1b65ee18
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs
@@ -0,0 +1,9 @@
+//@ needs-enzyme
+//@ ignore-cross-compile
+
+use run_make_support::{llvm_filecheck, rfs, rustc};
+
+fn main() {
+    rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
+    llvm_filecheck().patterns("struct.check").stdin_buf(rfs::read("test.ll")).run();
+}
diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/struct.check b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check
new file mode 100644
index 00000000000..2f763f18c1c
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check
@@ -0,0 +1,4 @@
+; Check that struct TypeTree metadata is correctly generated
+; Should show Float@double at offsets 0, 8, 16 for Point struct fields
+
+CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}"
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/test.rs b/tests/run-make/autodiff/type-trees/struct-typetree/test.rs
new file mode 100644
index 00000000000..cbe7b10e409
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/struct-typetree/test.rs
@@ -0,0 +1,22 @@
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+#[repr(C)]
+struct Point {
+    x: f64,
+    y: f64,
+    z: f64,
+}
+
+#[autodiff_reverse(d_test, Duplicated, Active)]
+#[no_mangle]
+fn test_struct(point: &Point) -> f64 {
+    point.x + point.y * 2.0 + point.z * 3.0
+}
+
+fn main() {
+    let point = Point { x: 1.0, y: 2.0, z: 3.0 };
+    let mut d_point = Point { x: 0.0, y: 0.0, z: 0.0 };
+    let _result = d_test(&point, &mut d_point, 1.0);
+}