diff options
| author | Karan Janthe <karanjanthe@gmail.com> | 2025-09-01 16:28:14 +0000 |
|---|---|---|
| committer | Karan Janthe <karanjanthe@gmail.com> | 2025-09-19 04:11:35 +0000 |
| commit | 574f0b97d6f30cd6cedb165fde13cdec176611b8 (patch) | |
| tree | e8df77057028818c27791a663ac01b539169e6d5 | |
| parent | 7c5fbfbdbbb389462e0ffb936ba9b16cffbce6ed (diff) | |
| download | rust-574f0b97d6f30cd6cedb165fde13cdec176611b8.tar.gz rust-574f0b97d6f30cd6cedb165fde13cdec176611b8.zip | |
autodiff: struct support in typetree
4 files changed, 67 insertions, 0 deletions
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index c0a091551c9..703d417f96d 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -2370,5 +2370,37 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { return TypeTree(types); } + if let ty::Adt(adt_def, args) = ty.kind() { + if adt_def.is_struct() { + let struct_layout = + tcx.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(ty)); + if let Ok(layout) = struct_layout { + let mut types = Vec::new(); + + for (field_idx, field_def) in adt_def.all_fields().enumerate() { + let field_ty = field_def.ty(tcx, args); + let field_tree = typetree_from_ty(tcx, field_ty); + + let field_offset = layout.fields.offset(field_idx).bytes_usize(); + + for elem_type in &field_tree.0 { + types.push(Type { + offset: if elem_type.offset == -1 { + field_offset as isize + } else { + field_offset as isize + elem_type.offset + }, + size: elem_type.size, + kind: elem_type.kind, + child: elem_type.child.clone(), + }); + } + } + + return TypeTree(types); + } + } + } + TypeTree::new() } diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs new file mode 100644 index 00000000000..0af1b65ee18 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/struct-typetree/rmake.rs @@ -0,0 +1,9 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run(); + llvm_filecheck().patterns("struct.check").stdin_buf(rfs::read("test.ll")).run(); +} diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/struct.check b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check new file mode 100644 index 00000000000..2f763f18c1c --- /dev/null +++ b/tests/run-make/autodiff/type-trees/struct-typetree/struct.check @@ -0,0 +1,4 @@ +; Check that struct TypeTree metadata is correctly generated +; Should show Float@double at offsets 0, 8, 16 for Point struct fields + +CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_struct{{.*}}"enzyme_type"="{[]:Pointer}" \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/struct-typetree/test.rs b/tests/run-make/autodiff/type-trees/struct-typetree/test.rs new file mode 100644 index 00000000000..cbe7b10e409 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/struct-typetree/test.rs @@ -0,0 +1,22 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[repr(C)] +struct Point { + x: f64, + y: f64, + z: f64, +} + +#[autodiff_reverse(d_test, Duplicated, Active)] +#[no_mangle] +fn test_struct(point: &Point) -> f64 { + point.x + point.y * 2.0 + point.z * 3.0 +} + +fn main() { + let point = Point { x: 1.0, y: 2.0, z: 3.0 }; + let mut d_point = Point { x: 0.0, y: 0.0, z: 0.0 }; + let _result = d_test(&point, &mut d_point, 1.0); +} |
