diff options
| author | Karan Janthe <karanjanthe@gmail.com> | 2025-08-23 20:17:32 +0000 |
|---|---|---|
| committer | Karan Janthe <karanjanthe@gmail.com> | 2025-09-19 04:02:19 +0000 |
| commit | e1258e79d6cb709b26ded97d32de6c55f355e2aa (patch) | |
| tree | 120c3432caefd49bd9b4227318c70c2519c346dc | |
| parent | 2f4dfc753fd86c672aa4145940db075a8a149f17 (diff) | |
| download | rust-e1258e79d6cb709b26ded97d32de6c55f355e2aa.tar.gz rust-e1258e79d6cb709b26ded97d32de6c55f355e2aa.zip | |
autodiff: Add basic TypeTree with NoTT flag
Signed-off-by: Karan Janthe <karanjanthe@gmail.com>
| -rw-r--r-- | compiler/rustc_ast/src/expand/autodiff_attrs.rs | 17 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/lto.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_interface/src/tests.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/error.rs | 1 | ||||
| -rw-r--r-- | compiler/rustc_middle/src/ty/mod.rs | 80 | ||||
| -rw-r--r-- | compiler/rustc_session/src/config.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_session/src/options.rs | 3 | ||||
| -rw-r--r-- | tests/codegen-llvm/autodiff/typetree.rs | 33 | ||||
| -rw-r--r-- | tests/run-make/autodiff/type-trees/nott-flag/nott.check | 3 | ||||
| -rw-r--r-- | tests/run-make/autodiff/type-trees/nott-flag/rmake.rs | 38 | ||||
| -rw-r--r-- | tests/run-make/autodiff/type-trees/nott-flag/test.rs | 15 | ||||
| -rw-r--r-- | tests/run-make/autodiff/type-trees/nott-flag/with_tt.check | 3 | ||||
| -rw-r--r-- | tests/ui/autodiff/flag_nott.rs | 19 |
13 files changed, 212 insertions, 5 deletions
diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 33451f99748..90f15753e99 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -6,6 +6,7 @@ use std::fmt::{self, Display, Formatter}; use std::str::FromStr; +use crate::expand::typetree::TypeTree; use crate::expand::{Decodable, Encodable, HashStable_Generic}; use crate::{Ty, TyKind}; @@ -84,6 +85,8 @@ pub struct AutoDiffItem { /// The name of the function being generated pub target: String, pub attrs: AutoDiffAttrs, + pub inputs: Vec<TypeTree>, + pub output: TypeTree, } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] @@ -275,14 +278,22 @@ impl AutoDiffAttrs { !matches!(self.mode, DiffMode::Error | DiffMode::Source) } - pub fn into_item(self, source: String, target: String) -> AutoDiffItem { - AutoDiffItem { source, target, attrs: self } + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec<TypeTree>, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } } } impl fmt::Display for AutoDiffItem { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Differentiating {} -> {}", self.source, self.target)?; - write!(f, " with attributes: {:?}", self.attrs) + write!(f, " with attributes: {:?}", self.attrs)?; + write!(f, " with inputs: {:?}", self.inputs)?; + write!(f, " with output: {:?}", self.output) } } diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 78107d95e5a..5ac3a87c158 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -563,6 +563,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) { config::AutoDiff::Enable => {} // We handle this below config::AutoDiff::NoPostopt => {} + // Disables TypeTree generation + config::AutoDiff::NoTT => {} } } // This helps with handling enums for now. diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 7730bddc0f1..837acdadd57 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -766,6 +766,7 @@ fn test_unstable_options_tracking_hash() { tracked!(always_encode_mir, true); tracked!(assume_incomplete_release, true); tracked!(autodiff, vec![AutoDiff::Enable]); + tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]); tracked!(binary_dep_depinfo, true); tracked!(box_noalias, false); tracked!( diff --git a/compiler/rustc_middle/src/error.rs b/compiler/rustc_middle/src/error.rs index e3e1393b5f9..ef014af7beb 100644 --- a/compiler/rustc_middle/src/error.rs +++ b/compiler/rustc_middle/src/error.rs @@ -37,7 +37,6 @@ pub(crate) struct OpaqueHiddenTypeMismatch<'tcx> { pub sub: TypeMismatchReason, } -// FIXME(autodiff): I should get used somewhere #[derive(Diagnostic)] #[diag(middle_unsupported_union)] pub struct UnsupportedUnion { diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 0ffef393a33..df42c400317 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -25,6 +25,7 @@ pub use generic_args::{GenericArgKind, TermKind, *}; pub use generics::*; pub use intrinsic::IntrinsicDef; use rustc_abi::{Align, FieldIdx, Integer, IntegerType, ReprFlags, ReprOptions, VariantIdx}; +use rustc_ast::expand::typetree::{FncTree, Kind, Type, TypeTree}; use rustc_ast::node_id::NodeMap; pub use rustc_ast_ir::{Movability, Mutability, try_visit}; use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap, FxIndexSet}; @@ -2216,3 +2217,82 @@ pub struct DestructuredConst<'tcx> { pub variant: Option<VariantIdx>, pub fields: &'tcx [ty::Const<'tcx>], } + +/// Generate TypeTree information for autodiff. +/// This function creates TypeTree metadata that describes the memory layout +/// of function parameters and return types for Enzyme autodiff. +pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { + // Check if TypeTrees are disabled via NoTT flag + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + + // Check if this is actually a function type + if !fn_ty.is_fn() { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + + // Get the function signature + let fn_sig = fn_ty.fn_sig(tcx); + let sig = tcx.instantiate_bound_regions_with_erased(fn_sig); + + // Create TypeTrees for each input parameter + let mut args = vec![]; + for ty in sig.inputs().iter() { + let type_tree = typetree_from_ty(tcx, *ty); + args.push(type_tree); + } + + // Create TypeTree for return type + let ret = typetree_from_ty(tcx, sig.output()); + + FncTree { args, ret } +} + +/// Generate TypeTree for a specific type. +/// This function analyzes a Rust type and creates appropriate TypeTree metadata. +fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + // Handle basic scalar types + if ty.is_scalar() { + let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { + (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) + } else if ty.is_floating_point() { + match ty { + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + _ => return TypeTree::new(), // Unknown float type + } + } else { + // TODO(KMJ-007): Handle other scalar types if needed + return TypeTree::new(); + }; + + return TypeTree(vec![Type { + offset: -1, + size, + kind, + child: TypeTree::new() + }]); + } + + // Handle references and pointers + if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { + let inner_ty = if let Some(inner) = ty.builtin_deref(true) { + inner + } else { + // TODO(KMJ-007): Handle complex pointer types + return TypeTree::new(); + }; + + let child = typetree_from_ty(tcx, inner_ty); + return TypeTree(vec![Type { + offset: -1, + size: 8, // TODO(KMJ-007): Get actual pointer size from target + kind: Kind::Pointer, + child, + }]); + } + + // TODO(KMJ-007): Handle arrays, slices, structs, and other complex types + TypeTree::new() +} diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 297df7c2c97..3a035fcf9e8 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -257,6 +257,8 @@ pub enum AutoDiff { LooseTypes, /// Runs Enzyme's aggressive inlining Inline, + /// Disable Type Tree + NoTT, } /// Settings for `-Z instrument-xray` flag. diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 69facde6936..ee53dd39bc8 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -792,7 +792,7 @@ mod desc { pub(crate) const parse_list: &str = "a space-separated list of strings"; pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; - pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`"; + pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`, `NoTT`"; pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; @@ -1479,6 +1479,7 @@ pub mod parse { "PrintPasses" => AutoDiff::PrintPasses, "LooseTypes" => AutoDiff::LooseTypes, "Inline" => AutoDiff::Inline, + "NoTT" => AutoDiff::NoTT, _ => { // FIXME(ZuseZ4): print an error saying which value is not recognized return false; diff --git a/tests/codegen-llvm/autodiff/typetree.rs b/tests/codegen-llvm/autodiff/typetree.rs new file mode 100644 index 00000000000..3ad38d581b9 --- /dev/null +++ b/tests/codegen-llvm/autodiff/typetree.rs @@ -0,0 +1,33 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// Test that basic autodiff still works with our TypeTree infrastructure +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_simple, Duplicated, Active)] +#[no_mangle] +#[inline(never)] +fn simple(x: &f64) -> f64 { + 2.0 * x +} + +// CHECK-LABEL: @simple +// CHECK: fmul double + +// The derivative function should be generated normally +// CHECK-LABEL: diffesimple +// CHECK: fadd fast double + +fn main() { + let x = std::hint::black_box(3.0); + let output = simple(&x); + assert_eq!(6.0, output); + + let mut df_dx = 0.0; + let output_ = d_simple(&x, &mut df_dx, 1.0); + assert_eq!(output, output_); + assert_eq!(2.0, df_dx); +} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/nott.check b/tests/run-make/autodiff/type-trees/nott-flag/nott.check new file mode 100644 index 00000000000..56ef2f0bdf3 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/nott.check @@ -0,0 +1,3 @@ +// TODO(KMJ-007): Update this test when TypeTree integration is complete +// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{} +// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs new file mode 100644 index 00000000000..536164192dc --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/rmake.rs @@ -0,0 +1,38 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Test with NoTT flag - should not generate TypeTree metadata + let output_nott = rustc() + .input("test.rs") + .arg("-Zautodiff=Enable,NoTT,PrintTAFn=square") + .arg("-Zautodiff=NoPostopt") + .opt_level("3") + .arg("-Clto=fat") + .arg("-g") + .run(); + + // Write output for NoTT case + rfs::write("nott.stdout", output_nott.stdout_utf8()); + + // Test without NoTT flag - should generate TypeTree metadata + let output_with_tt = rustc() + .input("test.rs") + .arg("-Zautodiff=Enable,PrintTAFn=square") + .arg("-Zautodiff=NoPostopt") + .opt_level("3") + .arg("-Clto=fat") + .arg("-g") + .run(); + + // Write output for TypeTree case + rfs::write("with_tt.stdout", output_with_tt.stdout_utf8()); + + // Verify NoTT output has minimal TypeTree info + llvm_filecheck().patterns("nott.check").stdin_buf(rfs::read("nott.stdout")).run(); + + // Verify normal output will have TypeTree info (once implemented) + llvm_filecheck().patterns("with_tt.check").stdin_buf(rfs::read("with_tt.stdout")).run(); +} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/test.rs b/tests/run-make/autodiff/type-trees/nott-flag/test.rs new file mode 100644 index 00000000000..5c634eea035 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/test.rs @@ -0,0 +1,15 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +#[autodiff_reverse(d_square, Duplicated, Active)] +#[no_mangle] +fn square(x: &f64) -> f64 { + x * x +} + +fn main() { + let x = 2.0; + let mut dx = 0.0; + let _result = d_square(&x, &mut dx, 1.0); +} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check new file mode 100644 index 00000000000..56ef2f0bdf3 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/nott-flag/with_tt.check @@ -0,0 +1,3 @@ +// TODO(KMJ-007): Update this test when TypeTree integration is complete +// CHECK: square - {[-1]:Float@double} |{[-1]:Pointer}:{} +// CHECK: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double} \ No newline at end of file diff --git a/tests/ui/autodiff/flag_nott.rs b/tests/ui/autodiff/flag_nott.rs new file mode 100644 index 00000000000..7a97d892cd8 --- /dev/null +++ b/tests/ui/autodiff/flag_nott.rs @@ -0,0 +1,19 @@ +//@ compile-flags: -Zautodiff=Enable,NoTT +//@ needs-enzyme +//@ check-pass + +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; + +// Test that NoTT flag is accepted and doesn't cause compilation errors +#[autodiff_reverse(d_square, Duplicated, Active)] +fn square(x: &f64) -> f64 { + x * x +} + +fn main() { + let x = 2.0; + let mut dx = 0.0; + let result = d_square(&x, &mut dx, 1.0); +} \ No newline at end of file |
