about summary refs log tree commit diff
diff options
context:
space:
mode:
authorKaran Janthe <karanjanthe@gmail.com>2025-08-23 20:17:32 +0000
committerKaran Janthe <karanjanthe@gmail.com>2025-09-19 04:02:19 +0000
commite1258e79d6cb709b26ded97d32de6c55f355e2aa (patch)
tree120c3432caefd49bd9b4227318c70c2519c346dc
parent2f4dfc753fd86c672aa4145940db075a8a149f17 (diff)
downloadrust-e1258e79d6cb709b26ded97d32de6c55f355e2aa.tar.gz
rust-e1258e79d6cb709b26ded97d32de6c55f355e2aa.zip
autodiff: Add basic TypeTree with NoTT flag
Signed-off-by: Karan Janthe <karanjanthe@gmail.com>
-rw-r--r--compiler/rustc_ast/src/expand/autodiff_attrs.rs17
-rw-r--r--compiler/rustc_codegen_llvm/src/back/lto.rs2
-rw-r--r--compiler/rustc_interface/src/tests.rs1
-rw-r--r--compiler/rustc_middle/src/error.rs1
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs80
-rw-r--r--compiler/rustc_session/src/config.rs2
-rw-r--r--compiler/rustc_session/src/options.rs3
-rw-r--r--tests/codegen-llvm/autodiff/typetree.rs33
-rw-r--r--tests/run-make/autodiff/type-trees/nott-flag/nott.check3
-rw-r--r--tests/run-make/autodiff/type-trees/nott-flag/rmake.rs38
-rw-r--r--tests/run-make/autodiff/type-trees/nott-flag/test.rs15
-rw-r--r--tests/run-make/autodiff/type-trees/nott-flag/with_tt.check3
-rw-r--r--tests/ui/autodiff/flag_nott.rs19
13 files changed, 212 insertions, 5 deletions
diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
index 33451f99748..90f15753e99 100644
--- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs
+++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs
@@ -6,6 +6,7 @@
 use std::fmt::{self, Display, Formatter};
 use std::str::FromStr;
 
+use crate::expand::typetree::TypeTree;
 use crate::expand::{Decodable, Encodable, HashStable_Generic};
 use crate::{Ty, TyKind};
 
@@ -84,6 +85,8 @@ pub struct AutoDiffItem {
     /// The name of the function being generated
     pub target: String,
     pub attrs: AutoDiffAttrs,
+    pub inputs: Vec<TypeTree>,
+    pub output: TypeTree,
 }
 
 #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -275,14 +278,22 @@ impl AutoDiffAttrs {
         !matches!(self.mode, DiffMode::Error | DiffMode::Source)
     }
 
-    pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
-        AutoDiffItem { source, target, attrs: self }
+    pub fn into_item(
+        self,
+        source: String,
+        target: String,
+        inputs: Vec<TypeTree>,
+        output: TypeTree,
+    ) -> AutoDiffItem {
+        AutoDiffItem { source, target, inputs, output, attrs: self }
     }
 }
 
 impl fmt::Display for AutoDiffItem {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "Differentiating {} -> {}", self.source, self.target)?;
-        write!(f, " with attributes: {:?}", self.attrs)
+        write!(f, " with attributes: {:?}", self.attrs)?;
+        write!(f, " with inputs: {:?}", self.inputs)?;
+        write!(f, " with output: {:?}", self.output)
     }
 }
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index 78107d95e5a..5ac3a87c158 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -563,6 +563,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
             config::AutoDiff::Enable => {}
             // We handle this below
             config::AutoDiff::NoPostopt => {}
+            // Disables TypeTree generation
+            config::AutoDiff::NoTT => {}
         }
     }
     // This helps with handling enums for now.
diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs
index 7730bddc0f1..837acdadd57 100644
--- a/compiler/rustc_interface/src/tests.rs
+++ b/compiler/rustc_interface/src/tests.rs
@@ -766,6 +766,7 @@ fn test_unstable_options_tracking_hash() {
     tracked!(always_encode_mir, true);
     tracked!(assume_incomplete_release, true);
     tracked!(autodiff, vec![AutoDiff::Enable]);
+    tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]);
     tracked!(binary_dep_depinfo, true);
     tracked!(box_noalias, false);
     tracked!(
diff --git a/compiler/rustc_middle/src/error.rs b/compiler/rustc_middle/src/error.rs
index e3e1393b5f9..ef014af7beb 100644
--- a/compiler/rustc_middle/src/error.rs
+++ b/compiler/rustc_middle/src/error.rs
@@ -37,7 +37,6 @@ pub(crate) struct OpaqueHiddenTypeMismatch<'tcx> {
     pub sub: TypeMismatchReason,
 }
 
-// FIXME(autodiff): I should get used somewhere
 #[derive(Diagnostic)]
 #[diag(middle_unsupported_union)]
 pub struct UnsupportedUnion {
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 0ffef393a33..df42c400317 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *};
 pub use generics::*;
 pub use intrinsic::IntrinsicDef;
 use rustc_abi::{Align, FieldIdx, Integer, IntegerType, ReprFlags, ReprOptions, VariantIdx};
+use rustc_ast::expand::typetree::{FncTree, Kind, Type, TypeTree};
 use rustc_ast::node_id::NodeMap;
 pub use rustc_ast_ir::{Movability, Mutability, try_visit};
 use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet};
@@ -2216,3 +2217,82 @@ pub struct DestructuredConst<'tcx> {
     pub variant: Option<VariantIdx>,
     pub fields: &'tcx [ty::Const<'tcx>],
 }
+
+/// Generate TypeTree information for autodiff.
+/// This function creates TypeTree metadata that describes the memory layout
+/// of function parameters and return types for Enzyme autodiff.
+pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
+    // Check if TypeTrees are disabled via NoTT flag
+    if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) {
+        return FncTree { args: vec![], ret: TypeTree::new() };
+    }
+
+    // Check if this is actually a function type
+    if !fn_ty.is_fn() {
+        return FncTree { args: vec![], ret: TypeTree::new() };
+    }
+
+    // Get the function signature
+    let fn_sig = fn_ty.fn_sig(tcx);
+    let sig = tcx.instantiate_bound_regions_with_erased(fn_sig);
+
+    // Create TypeTrees for each input parameter
+    let mut args = vec![];
+    for ty in sig.inputs().iter() {
+        let type_tree = typetree_from_ty(tcx, *ty);
+        args.push(type_tree);
+    }
+
+    // Create TypeTree for return type
+    let ret = typetree_from_ty(tcx, sig.output());
+
+    FncTree { args, ret }
+}
+
+/// Generate TypeTree for a specific type.
+/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
+fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
+    // Handle basic scalar types
+    if ty.is_scalar() {
+        let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
+            (Kind::Integer, ty.primitive_size(tcx).bytes_usize())
+        } else if ty.is_floating_point() {
+            match ty {
+                x if x == tcx.types.f32 => (Kind::Float, 4),
+                x if x == tcx.types.f64 => (Kind::Double, 8),
+                _ => return TypeTree::new(), // Unknown float type
+            }
+        } else {
+            // TODO(KMJ-007): Handle other scalar types if needed
+            return TypeTree::new();
+        };
+        
+        return TypeTree(vec![Type { 
+            offset: -1, 
+            size, 
+            kind, 
+            child: TypeTree::new() 
+        }]);
+    }
+
+    // Handle references and pointers
+    if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
+        let inner_ty = if let Some(inner) = ty.builtin_deref(true) {
+            inner
+        } else {
+            // TODO(KMJ-007): Handle complex pointer types
+            return TypeTree::new();
+        };
+
+        let child = typetree_from_ty(tcx, inner_ty);
+        return TypeTree(vec![Type {
+            offset: -1,
+            size: 8, // TODO(KMJ-007): Get actual pointer size from target
+            kind: Kind::Pointer,
+            child,
+        }]);
+    }
+
+    // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
+    TypeTree::new()
+}
diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs
index 297df7c2c97..3a035fcf9e8 100644
--- a/compiler/rustc_session/src/config.rs
+++ b/compiler/rustc_session/src/config.rs
@@ -257,6 +257,8 @@ pub enum AutoDiff {
     LooseTypes,
     /// Runs Enzyme's aggressive inlining
     Inline,
+    /// Disable Type Tree
+    NoTT,
 }
 
 /// Settings for `-Z instrument-xray` flag.
diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs
index 69facde6936..ee53dd39bc8 100644
--- a/compiler/rustc_session/src/options.rs
+++ b/compiler/rustc_session/src/options.rs
@@ -792,7 +792,7 @@ mod desc {
     pub(crate) const parse_list: &str = "a space-separated list of strings";
     pub(crate) const parse_list_with_polarity: &str =
         "a comma-separated list of strings, with elements beginning with + or -";
-    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
+    pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`, `NoTT`";
     pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`";
     pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
     pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
@@ -1479,6 +1479,7 @@ pub mod parse {
                 "PrintPasses" => AutoDiff::PrintPasses,
                 "LooseTypes" => AutoDiff::LooseTypes,
                 "Inline" => AutoDiff::Inline,
+                "NoTT" => AutoDiff::NoTT,
                 _ => {
                     // FIXME(ZuseZ4): print an error saying which value is not recognized
                     return false;
diff --git a/tests/codegen-llvm/autodiff/typetree.rs b/tests/codegen-llvm/autodiff/typetree.rs
new file mode 100644
index 00000000000..3ad38d581b9
--- /dev/null
+++ b/tests/codegen-llvm/autodiff/typetree.rs
@@ -0,0 +1,33 @@
+//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
+//@ no-prefer-dynamic
+//@ needs-enzyme
+
+// Test that basic autodiff still works with our TypeTree infrastructure
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+#[autodiff_reverse(d_simple, Duplicated, Active)]
+#[no_mangle]
+#[inline(never)]
+fn simple(x: &f64) -> f64 {
+    2.0 * x
+}
+
+// CHECK-LABEL: @simple
+// CHECK: fmul double
+
+// The derivative function should be generated normally
+// CHECK-LABEL: diffesimple
+// CHECK: fadd fast double
+
+fn main() {
+    let x = std::hint::black_box(3.0);
+    let output = simple(&x);
+    assert_eq!(6.0, output);
+
+    let mut df_dx = 0.0;
+    let output_ = d_simple(&x, &mut df_dx, 1.0);
+    assert_eq!(output, output_);
+    assert_eq!(2.0, df_dx);
+}
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/nott-flag/nott.check b/tests/run-make/autodiff/type-trees/nott-flag/nott.check
new file mode 100644
index 00000000000..56ef2f0bdf3
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/nott-flag/nott.check
@@ -0,0 +1,3 @@
+// TODO(KMJ-007): Update this test when TypeTree integration is complete
+// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{}
+// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double}
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs
new file mode 100644
index 00000000000..536164192dc
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs
@@ -0,0 +1,38 @@
+//@ needs-enzyme
+//@ ignore-cross-compile
+
+use run_make_support::{llvm_filecheck, rfs, rustc};
+
+fn main() {
+    // Test with NoTT flag - should not generate TypeTree metadata
+    let output_nott = rustc()
+        .input("test.rs")
+        .arg("-Zautodiff=Enable,NoTT,PrintTAFn=square")
+        .arg("-Zautodiff=NoPostopt")
+        .opt_level("3")
+        .arg("-Clto=fat")
+        .arg("-g")
+        .run();
+
+    // Write output for NoTT case
+    rfs::write("nott.stdout", output_nott.stdout_utf8());
+    
+    // Test without NoTT flag - should generate TypeTree metadata
+    let output_with_tt = rustc()
+        .input("test.rs")
+        .arg("-Zautodiff=Enable,PrintTAFn=square")
+        .arg("-Zautodiff=NoPostopt")
+        .opt_level("3")
+        .arg("-Clto=fat")
+        .arg("-g")
+        .run();
+
+    // Write output for TypeTree case  
+    rfs::write("with_tt.stdout", output_with_tt.stdout_utf8());
+
+    // Verify NoTT output has minimal TypeTree info
+    llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.stdout")).run();
+    
+    // Verify normal output will have TypeTree info (once implemented)
+    llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.stdout")).run();
+}
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/nott-flag/test.rs b/tests/run-make/autodiff/type-trees/nott-flag/test.rs
new file mode 100644
index 00000000000..5c634eea035
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/nott-flag/test.rs
@@ -0,0 +1,15 @@
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+#[autodiff_reverse(d_square, Duplicated, Active)]
+#[no_mangle]
+fn square(x: &f64) -> f64 {
+    x * x
+}
+
+fn main() {
+    let x = 2.0;
+    let mut dx = 0.0;
+    let _result = d_square(&x, &mut dx, 1.0);
+}
\ No newline at end of file
diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check
new file mode 100644
index 00000000000..56ef2f0bdf3
--- /dev/null
+++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check
@@ -0,0 +1,3 @@
+// TODO(KMJ-007): Update this test when TypeTree integration is complete
+// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{}
+// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double}
\ No newline at end of file
diff --git a/tests/ui/autodiff/flag_nott.rs b/tests/ui/autodiff/flag_nott.rs
new file mode 100644
index 00000000000..7a97d892cd8
--- /dev/null
+++ b/tests/ui/autodiff/flag_nott.rs
@@ -0,0 +1,19 @@
+//@ compile-flags: -Zautodiff=Enable,NoTT
+//@ needs-enzyme
+//@ check-pass
+
+#![feature(autodiff)]
+
+use std::autodiff::autodiff_reverse;
+
+// Test that NoTT flag is accepted and doesn't cause compilation errors
+#[autodiff_reverse(d_square, Duplicated, Active)]
+fn square(x: &f64) -> f64 {
+    x * x
+}
+
+fn main() {
+    let x = 2.0;
+    let mut dx = 0.0;
+    let result = d_square(&x, &mut dx, 1.0);
+}
\ No newline at end of file