diff options
| author | Manuel Drehwald <git@manuel.drehwald.info> | 2025-01-29 21:31:13 -0500 |
|---|---|---|
| committer | Manuel Drehwald <git@manuel.drehwald.info> | 2025-01-29 21:31:13 -0500 |
| commit | 1f30517d40a9a8fe3b89479891c7a167adb75cbd (patch) | |
| tree | 4da71f0331d185e202236924a854190f913815c7 /compiler/rustc_monomorphize | |
| parent | ebcf860e7345e3387b4c6961338c77424b43cbd5 (diff) | |
| download | rust-1f30517d40a9a8fe3b89479891c7a167adb75cbd.tar.gz rust-1f30517d40a9a8fe3b89479891c7a167adb75cbd.zip | |
upstream rustc_codegen_ssa/rustc_middle changes for enzyme/autodiff
Diffstat (limited to 'compiler/rustc_monomorphize')
| -rw-r--r-- | compiler/rustc_monomorphize/Cargo.toml | 2 | ||||
| -rw-r--r-- | compiler/rustc_monomorphize/src/collector.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_monomorphize/src/partitioning.rs | 32 | ||||
| -rw-r--r-- | compiler/rustc_monomorphize/src/partitioning/autodiff.rs | 121 |
4 files changed, 154 insertions, 3 deletions
diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index 9bdaeb015cd..5462105e5e8 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] # tidy-alphabetical-start rustc_abi = { path = "../rustc_abi" } +rustc_ast = { path = "../rustc_ast" } rustc_attr_parsing = { path = "../rustc_attr_parsing" } rustc_data_structures = { path = "../rustc_data_structures" } rustc_errors = { path = "../rustc_errors" } @@ -15,6 +16,7 @@ rustc_macros = { path = "../rustc_macros" } rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } +rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } rustc_target = { path = "../rustc_target" } serde = "1" serde_json = "1" diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index bb603df1129..f3b5488e975 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -257,7 +257,7 @@ struct SharedState<'tcx> { pub(crate) struct UsageMap<'tcx> { // Maps every mono item to the mono items used by it. - used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>, + pub used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>, // Maps every mono item to the mono items that use it. user_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>, diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index e08c348a64d..c985ea04278 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -92,6 +92,8 @@ //! source-level module, functions from the same module will be available for //! inlining, even when they are not marked `#[inline]`. +mod autodiff; + use std::cmp; use std::collections::hash_map::Entry; use std::fs::{self, File}; @@ -251,7 +253,17 @@ where can_export_generics, always_export_generics, ); - if visibility == Visibility::Hidden && can_be_internalized { + + // We can't differentiate something that got inlined. + let autodiff_active = cfg!(llvm_enzyme) + && cx + .tcx + .codegen_fn_attrs(mono_item.def_id()) + .autodiff_item + .as_ref() + .is_some_and(|ad| ad.is_active()); + + if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); } let size_estimate = mono_item.size_estimate(cx.tcx); @@ -1176,6 +1188,18 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio }) .collect(); + let autodiff_mono_items: Vec<_> = items + .iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance) => Some((item, instance)), + _ => None, + }) + .collect(); + + let autodiff_items = + autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items); + let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); + // Output monomorphization stats per def_id if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats { if let Err(err) = @@ -1236,7 +1260,11 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio } } - MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units } + MonoItemPartitions { + all_mono_items: tcx.arena.alloc(mono_items), + codegen_units, + autodiff_items, + } } /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs new file mode 100644 index 00000000000..bce31bf0748 --- /dev/null +++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs @@ -0,0 +1,121 @@ +use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity}; +use rustc_hir::def_id::LOCAL_CRATE; +use rustc_middle::bug; +use rustc_middle::mir::mono::MonoItem; +use rustc_middle::ty::{self, Instance, Ty, TyCtxt}; +use rustc_symbol_mangling::symbol_name_for_instance_in_crate; +use tracing::{debug, trace}; + +use crate::partitioning::UsageMap; + +fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) { + if !matches!(fn_ty.kind(), ty::FnDef(..)) { + bug!("expected fn def for autodiff, got {:?}", fn_ty); + } + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); + + // If rustc compiles the unmodified primal, we know that this copy of the function + // also has correct lifetimes. We know that Enzyme won't free the shadow too early + // (or actually at all), so let's strip lifetimes when computing the layout. + let x = tcx.instantiate_bound_regions_with_erased(fnc_binder); + let mut new_activities = vec![]; + let mut new_positions = vec![]; + for (i, ty) in x.inputs().iter().enumerate() { + if let Some(inner_ty) = ty.builtin_deref(true) { + if ty.is_fn_ptr() { + // FIXME(ZuseZ4): add a nicer error, or just figure out how to support them, + // since Enzyme itself can handle them. + tcx.dcx().err("function pointers are currently not supported in autodiff"); + } + if inner_ty.is_slice() { + // We know that the length will be passed as extra arg. + if !da.is_empty() { + // We are looking at a slice. The length of that slice will become an + // extra integer on llvm level. Integers are always const. + // However, if the slice get's duplicated, we want to know to later check the + // size. So we mark the new size argument as FakeActivitySize. + let activity = match da[i] { + DiffActivity::DualOnly + | DiffActivity::Dual + | DiffActivity::DuplicatedOnly + | DiffActivity::Duplicated => DiffActivity::FakeActivitySize, + DiffActivity::Const => DiffActivity::Const, + _ => bug!("unexpected activity for ptr/ref"), + }; + new_activities.push(activity); + new_positions.push(i + 1); + } + continue; + } + } + } + // now add the extra activities coming from slices + // Reverse order to not invalidate the indices + for _ in 0..new_activities.len() { + let pos = new_positions.pop().unwrap(); + let activity = new_activities.pop().unwrap(); + da.insert(pos, activity); + } +} + +pub(crate) fn find_autodiff_source_functions<'tcx>( + tcx: TyCtxt<'tcx>, + usage_map: &UsageMap<'tcx>, + autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>, +) -> Vec<AutoDiffItem> { + let mut autodiff_items: Vec<AutoDiffItem> = vec![]; + for (item, instance) in autodiff_mono_items { + let target_id = instance.def_id(); + let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone(); + let Some(target_attrs) = cg_fn_attr else { + continue; + }; + let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone(); + if target_attrs.is_source() { + trace!("source found: {:?}", target_id); + } + if !target_attrs.apply_autodiff() { + continue; + } + + let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + + let source = + usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item { + MonoItem::Fn(ref instance_s) => { + let source_id = instance_s.def_id(); + if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item + && ad.is_active() + { + return Some(instance_s); + } + None + } + _ => None, + }); + let inst = match source { + Some(source) => source, + None => continue, + }; + + debug!("source_id: {:?}", inst.def_id()); + let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized()); + assert!(fn_ty.is_fn()); + adjust_activity_to_abi(tcx, fn_ty, &mut input_activities); + let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); + + let mut new_target_attrs = target_attrs.clone(); + new_target_attrs.input_activity = input_activities; + let itm = new_target_attrs.into_item(symb, target_symbol); + autodiff_items.push(itm); + } + + if !autodiff_items.is_empty() { + trace!("AUTODIFF ITEMS EXIST"); + for item in &mut *autodiff_items { + trace!("{}", &item); + } + } + + autodiff_items +} |
