diff options
| author | Karan Janthe <karanjanthe@gmail.com> | 2025-08-23 20:17:32 +0000 |
|---|---|---|
| committer | Karan Janthe <karanjanthe@gmail.com> | 2025-09-19 04:02:19 +0000 |
| commit | e1258e79d6cb709b26ded97d32de6c55f355e2aa (patch) | |
| tree | 120c3432caefd49bd9b4227318c70c2519c346dc /compiler/rustc_middle/src/ty | |
| parent | 2f4dfc753fd86c672aa4145940db075a8a149f17 (diff) | |
| download | rust-e1258e79d6cb709b26ded97d32de6c55f355e2aa.tar.gz rust-e1258e79d6cb709b26ded97d32de6c55f355e2aa.zip | |
autodiff: Add basic TypeTree with NoTT flag
Signed-off-by: Karan Janthe <karanjanthe@gmail.com>
Diffstat (limited to 'compiler/rustc_middle/src/ty')
| -rw-r--r-- | compiler/rustc_middle/src/ty/mod.rs | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 0ffef393a33..df42c400317 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *}; pub use generics::*; pub use intrinsic::IntrinsicDef; use rustc_abi::{Align, FieldIdx, Integer, IntegerType, ReprFlags, ReprOptions, VariantIdx}; +use rustc_ast::expand::typetree::{FncTree, Kind, Type, TypeTree}; use rustc_ast::node_id::NodeMap; pub use rustc_ast_ir::{Movability, Mutability, try_visit}; use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet}; @@ -2216,3 +2217,82 @@ pub struct DestructuredConst<'tcx> { pub variant: Option<VariantIdx>, pub fields: &'tcx [ty::Const<'tcx>], } + +/// Generate TypeTree information for autodiff. +/// This function creates TypeTree metadata that describes the memory layout +/// of function parameters and return types for Enzyme autodiff. +pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { + // Check if TypeTrees are disabled via NoTT flag + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + + // Check if this is actually a function type + if !fn_ty.is_fn() { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + + // Get the function signature + let fn_sig = fn_ty.fn_sig(tcx); + let sig = tcx.instantiate_bound_regions_with_erased(fn_sig); + + // Create TypeTrees for each input parameter + let mut args = vec![]; + for ty in sig.inputs().iter() { + let type_tree = typetree_from_ty(tcx, *ty); + args.push(type_tree); + } + + // Create TypeTree for return type + let ret = typetree_from_ty(tcx, sig.output()); + + FncTree { args, ret } +} + +/// Generate TypeTree for a specific type. +/// This function analyzes a Rust type and creates appropriate TypeTree metadata. +fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + // Handle basic scalar types + if ty.is_scalar() { + let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { + (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) + } else if ty.is_floating_point() { + match ty { + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + _ => return TypeTree::new(), // Unknown float type + } + } else { + // TODO(KMJ-007): Handle other scalar types if needed + return TypeTree::new(); + }; + + return TypeTree(vec![Type { + offset: -1, + size, + kind, + child: TypeTree::new() + }]); + } + + // Handle references and pointers + if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { + let inner_ty = if let Some(inner) = ty.builtin_deref(true) { + inner + } else { + // TODO(KMJ-007): Handle complex pointer types + return TypeTree::new(); + }; + + let child = typetree_from_ty(tcx, inner_ty); + return TypeTree(vec![Type { + offset: -1, + size: 8, // TODO(KMJ-007): Get actual pointer size from target + kind: Kind::Pointer, + child, + }]); + } + + // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types + TypeTree::new() +} |
