about summary refs log tree commit diff
path: root/compiler/rustc_middle/src/ty
diff options
context:
space:
mode:
authorKaran Janthe <karanjanthe@gmail.com>2025-08-23 20:17:32 +0000
committerKaran Janthe <karanjanthe@gmail.com>2025-09-19 04:02:19 +0000
commite1258e79d6cb709b26ded97d32de6c55f355e2aa (patch)
tree120c3432caefd49bd9b4227318c70c2519c346dc /compiler/rustc_middle/src/ty
parent2f4dfc753fd86c672aa4145940db075a8a149f17 (diff)
downloadrust-e1258e79d6cb709b26ded97d32de6c55f355e2aa.tar.gz
rust-e1258e79d6cb709b26ded97d32de6c55f355e2aa.zip
autodiff: Add basic TypeTree with NoTT flag
Signed-off-by: Karan Janthe <karanjanthe@gmail.com>
Diffstat (limited to 'compiler/rustc_middle/src/ty')
-rw-r--r--compiler/rustc_middle/src/ty/mod.rs80
1 files changed, 80 insertions, 0 deletions
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 0ffef393a33..df42c400317 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *};
 pub use generics::*;
 pub use intrinsic::IntrinsicDef;
 use rustc_abi::{Align, FieldIdx, Integer, IntegerType, ReprFlags, ReprOptions, VariantIdx};
+use rustc_ast::expand::typetree::{FncTree, Kind, Type, TypeTree};
 use rustc_ast::node_id::NodeMap;
 pub use rustc_ast_ir::{Movability, Mutability, try_visit};
 use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet};
@@ -2216,3 +2217,82 @@ pub struct DestructuredConst<'tcx> {
     pub variant: Option<VariantIdx>,
     pub fields: &'tcx [ty::Const<'tcx>],
 }
+
+/// Generate TypeTree information for autodiff.
+/// This function creates TypeTree metadata that describes the memory layout
+/// of function parameters and return types for Enzyme autodiff.
+pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree {
+    // Check if TypeTrees are disabled via NoTT flag
+    if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) {
+        return FncTree { args: vec![], ret: TypeTree::new() };
+    }
+
+    // Check if this is actually a function type
+    if !fn_ty.is_fn() {
+        return FncTree { args: vec![], ret: TypeTree::new() };
+    }
+
+    // Get the function signature
+    let fn_sig = fn_ty.fn_sig(tcx);
+    let sig = tcx.instantiate_bound_regions_with_erased(fn_sig);
+
+    // Create TypeTrees for each input parameter
+    let mut args = vec![];
+    for ty in sig.inputs().iter() {
+        let type_tree = typetree_from_ty(tcx, *ty);
+        args.push(type_tree);
+    }
+
+    // Create TypeTree for return type
+    let ret = typetree_from_ty(tcx, sig.output());
+
+    FncTree { args, ret }
+}
+
+/// Generate TypeTree for a specific type.
+/// This function analyzes a Rust type and creates appropriate TypeTree metadata.
+fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
+    // Handle basic scalar types
+    if ty.is_scalar() {
+        let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
+            (Kind::Integer, ty.primitive_size(tcx).bytes_usize())
+        } else if ty.is_floating_point() {
+            match ty {
+                x if x == tcx.types.f32 => (Kind::Float, 4),
+                x if x == tcx.types.f64 => (Kind::Double, 8),
+                _ => return TypeTree::new(), // Unknown float type
+            }
+        } else {
+            // TODO(KMJ-007): Handle other scalar types if needed
+            return TypeTree::new();
+        };
+        
+        return TypeTree(vec![Type { 
+            offset: -1, 
+            size, 
+            kind, 
+            child: TypeTree::new() 
+        }]);
+    }
+
+    // Handle references and pointers
+    if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() {
+        let inner_ty = if let Some(inner) = ty.builtin_deref(true) {
+            inner
+        } else {
+            // TODO(KMJ-007): Handle complex pointer types
+            return TypeTree::new();
+        };
+
+        let child = typetree_from_ty(tcx, inner_ty);
+        return TypeTree(vec![Type {
+            offset: -1,
+            size: 8, // TODO(KMJ-007): Get actual pointer size from target
+            kind: Kind::Pointer,
+            child,
+        }]);
+    }
+
+    // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
+    TypeTree::new()
+}