about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs36
-rw-r--r--tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs9
-rw-r--r--tests/run-make/autodiff/type-trees/tuple-typetree/test.rs15
-rw-r--r--tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check4
4 files changed, 64 insertions, 0 deletions
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 82a41f403f8..c0a091551c9 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -2334,5 +2334,41 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
         }
     }
 
+    if let ty::Tuple(tuple_types) = ty.kind() {
+        if tuple_types.is_empty() {
+            return TypeTree::new();
+        }
+
+        let mut types = Vec::new();
+        let mut current_offset = 0;
+
+        for tuple_ty in tuple_types.iter() {
+            let element_tree = typetree_from_ty(tcx, tuple_ty);
+
+            let element_layout = tcx
+                .layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(tuple_ty))
+                .ok()
+                .map(|layout| layout.size.bytes_usize())
+                .unwrap_or(0);
+
+            for elem_type in &element_tree.0 {
+                types.push(Type {
+                    offset: if elem_type.offset == -1 {
+                        current_offset as isize
+                    } else {
+                        current_offset as isize + elem_type.offset
+                    },
+                    size: elem_type.size,
+                    kind: elem_type.kind,
+                    child: elem_type.child.clone(),
+                });
+            }
+
+            current_offset += element_layout;
+        }
+
+        return TypeTree(types);
+    }
+
     TypeTree::new()
 }
diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs b/tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs
new file mode 100644
index 00000000000..76913828901
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/tuple-typetree/rmake.rs
@@ -0,0 +1,9 @@
+//@ needs-enzyme
+//@ ignore-cross-compile
+
+use run_make_support::{llvm_filecheck, rfs, rustc};
+
+fn main() {
+    rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
+    llvm_filecheck().patterns("tuple.check").stdin_buf(rfs::read("test.ll")).run();
+}
diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/test.rs b/tests/run-make/autodiff/type-trees/tuple-typetree/test.rs
new file mode 100644
index 00000000000..32187b587a3
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/tuple-typetree/test.rs
@@ -0,0 +1,15 @@
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+#[autodiff_reverse(d_test, Duplicated, Active)]
+#[no_mangle]
+fn test_tuple(tuple: &(f64, f64, f64)) -> f64 {
+    tuple.0 + tuple.1 * 2.0 + tuple.2 * 3.0
+}
+
+fn main() {
+    let tuple = (1.0, 2.0, 3.0);
+    let mut d_tuple = (0.0, 0.0, 0.0);
+    let _result = d_test(&tuple, &mut d_tuple, 1.0);
+}
diff --git a/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check b/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check
new file mode 100644
index 00000000000..50aa25e96ae
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/tuple-typetree/tuple.check
@@ -0,0 +1,4 @@
+; Check that tuple TypeTree metadata is correctly generated
+; Should show Float@double at offsets 0, 8, 16 for (f64, f64, f64)
+
+CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_tuple{{.*}}"enzyme_type"="{[]:Pointer}"
\ No newline at end of file