about summary refs log tree commit diff
path: root/compiler/rustc_monomorphize/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_monomorphize/src')
-rw-r--r--compiler/rustc_monomorphize/src/collector.rs2
-rw-r--r--compiler/rustc_monomorphize/src/partitioning.rs32
-rw-r--r--compiler/rustc_monomorphize/src/partitioning/autodiff.rs121
3 files changed, 152 insertions, 3 deletions
diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs
index 150594ab94d..9a4b9fda3eb 100644
--- a/compiler/rustc_monomorphize/src/collector.rs
+++ b/compiler/rustc_monomorphize/src/collector.rs
@@ -257,7 +257,7 @@ struct SharedState<'tcx> {
 
 pub(crate) struct UsageMap<'tcx> {
     // Maps every mono item to the mono items used by it.
-    used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
+    pub used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
 
     // Maps every mono item to the mono items that use it.
     user_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs
index e08c348a64d..c985ea04278 100644
--- a/compiler/rustc_monomorphize/src/partitioning.rs
+++ b/compiler/rustc_monomorphize/src/partitioning.rs
@@ -92,6 +92,8 @@
 //! source-level module, functions from the same module will be available for
 //! inlining, even when they are not marked `#[inline]`.
 
+mod autodiff;
+
 use std::cmp;
 use std::collections::hash_map::Entry;
 use std::fs::{self, File};
@@ -251,7 +253,17 @@ where
             can_export_generics,
             always_export_generics,
         );
-        if visibility == Visibility::Hidden && can_be_internalized {
+
+        // We can't differentiate something that got inlined.
+        let autodiff_active = cfg!(llvm_enzyme)
+            && cx
+                .tcx
+                .codegen_fn_attrs(mono_item.def_id())
+                .autodiff_item
+                .as_ref()
+                .is_some_and(|ad| ad.is_active());
+
+        if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
             internalization_candidates.insert(mono_item);
         }
         let size_estimate = mono_item.size_estimate(cx.tcx);
@@ -1176,6 +1188,18 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
         })
         .collect();
 
+    let autodiff_mono_items: Vec<_> = items
+        .iter()
+        .filter_map(|item| match *item {
+            MonoItem::Fn(ref instance) => Some((item, instance)),
+            _ => None,
+        })
+        .collect();
+
+    let autodiff_items =
+        autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items);
+    let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items);
+
     // Output monomorphization stats per def_id
     if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats {
         if let Err(err) =
@@ -1236,7 +1260,11 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
         }
     }
 
-    MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units }
+    MonoItemPartitions {
+        all_mono_items: tcx.arena.alloc(mono_items),
+        codegen_units,
+        autodiff_items,
+    }
 }
 
 /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s
diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs
new file mode 100644
index 00000000000..bce31bf0748
--- /dev/null
+++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs
@@ -0,0 +1,121 @@
+use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
+use rustc_hir::def_id::LOCAL_CRATE;
+use rustc_middle::bug;
+use rustc_middle::mir::mono::MonoItem;
+use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
+use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
+use tracing::{debug, trace};
+
+use crate::partitioning::UsageMap;
+
+fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
+    if !matches!(fn_ty.kind(), ty::FnDef(..)) {
+        bug!("expected fn def for autodiff, got {:?}", fn_ty);
+    }
+    let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
+
+    // If rustc compiles the unmodified primal, we know that this copy of the function
+    // also has correct lifetimes. We know that Enzyme won't free the shadow too early
+    // (or actually at all), so let's strip lifetimes when computing the layout.
+    let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);
+    let mut new_activities = vec![];
+    let mut new_positions = vec![];
+    for (i, ty) in x.inputs().iter().enumerate() {
+        if let Some(inner_ty) = ty.builtin_deref(true) {
+            if ty.is_fn_ptr() {
+                // FIXME(ZuseZ4): add a nicer error, or just figure out how to support them,
+                // since Enzyme itself can handle them.
+                tcx.dcx().err("function pointers are currently not supported in autodiff");
+            }
+            if inner_ty.is_slice() {
+                // We know that the length will be passed as extra arg.
+                if !da.is_empty() {
+                    // We are looking at a slice. The length of that slice will become an
+                    // extra integer on llvm level. Integers are always const.
+                    // However, if the slice get's duplicated, we want to know to later check the
+                    // size. So we mark the new size argument as FakeActivitySize.
+                    let activity = match da[i] {
+                        DiffActivity::DualOnly
+                        | DiffActivity::Dual
+                        | DiffActivity::DuplicatedOnly
+                        | DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
+                        DiffActivity::Const => DiffActivity::Const,
+                        _ => bug!("unexpected activity for ptr/ref"),
+                    };
+                    new_activities.push(activity);
+                    new_positions.push(i + 1);
+                }
+                continue;
+            }
+        }
+    }
+    // now add the extra activities coming from slices
+    // Reverse order to not invalidate the indices
+    for _ in 0..new_activities.len() {
+        let pos = new_positions.pop().unwrap();
+        let activity = new_activities.pop().unwrap();
+        da.insert(pos, activity);
+    }
+}
+
+pub(crate) fn find_autodiff_source_functions<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    usage_map: &UsageMap<'tcx>,
+    autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
+) -> Vec<AutoDiffItem> {
+    let mut autodiff_items: Vec<AutoDiffItem> = vec![];
+    for (item, instance) in autodiff_mono_items {
+        let target_id = instance.def_id();
+        let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone();
+        let Some(target_attrs) = cg_fn_attr else {
+            continue;
+        };
+        let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
+        if target_attrs.is_source() {
+            trace!("source found: {:?}", target_id);
+        }
+        if !target_attrs.apply_autodiff() {
+            continue;
+        }
+
+        let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
+
+        let source =
+            usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
+                MonoItem::Fn(ref instance_s) => {
+                    let source_id = instance_s.def_id();
+                    if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
+                        && ad.is_active()
+                    {
+                        return Some(instance_s);
+                    }
+                    None
+                }
+                _ => None,
+            });
+        let inst = match source {
+            Some(source) => source,
+            None => continue,
+        };
+
+        debug!("source_id: {:?}", inst.def_id());
+        let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
+        assert!(fn_ty.is_fn());
+        adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
+        let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
+
+        let mut new_target_attrs = target_attrs.clone();
+        new_target_attrs.input_activity = input_activities;
+        let itm = new_target_attrs.into_item(symb, target_symbol);
+        autodiff_items.push(itm);
+    }
+
+    if !autodiff_items.is_empty() {
+        trace!("AUTODIFF ITEMS EXIST");
+        for item in &mut *autodiff_items {
+            trace!("{}", &item);
+        }
+    }
+
+    autodiff_items
+}