about summary refs log tree commit diff
path: root/compiler/rustc_monomorphize/src
diff options
context:
space:
mode:
authorMarcelo Domínguez <dmmarcelo27@gmail.com>2025-08-14 15:27:57 +0000
committerMarcelo Domínguez <dmmarcelo27@gmail.com>2025-08-14 16:30:15 +0000
commit250d77e5d72fde69a6406050a3b037635f685378 (patch)
tree67749136fca27852b5fb784c864f7d3564a42a09 /compiler/rustc_monomorphize/src
parent5c631041aa0b0ad9e161b966b78e6dfdb8011023 (diff)
downloadrust-250d77e5d72fde69a6406050a3b037635f685378.tar.gz
rust-250d77e5d72fde69a6406050a3b037635f685378.zip
Complete functionality and general cleanup
Diffstat (limited to 'compiler/rustc_monomorphize/src')
-rw-r--r--compiler/rustc_monomorphize/src/collector.rs5
-rw-r--r--compiler/rustc_monomorphize/src/collector/autodiff.rs48
-rw-r--r--compiler/rustc_monomorphize/src/partitioning.rs34
-rw-r--r--compiler/rustc_monomorphize/src/partitioning/autodiff.rs143
4 files changed, 56 insertions, 174 deletions
diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs
index 26ca8518434..af2c3177067 100644
--- a/compiler/rustc_monomorphize/src/collector.rs
+++ b/compiler/rustc_monomorphize/src/collector.rs
@@ -205,6 +205,8 @@
 //! this is not implemented however: a mono item will be produced
 //! regardless of whether it is actually needed or not.
 
+mod autodiff;
+
 use std::cell::OnceCell;
 
 use rustc_data_structures::fx::FxIndexMap;
@@ -235,6 +237,7 @@ use rustc_span::source_map::{Spanned, dummy_spanned, respan};
 use rustc_span::{DUMMY_SP, Span};
 use tracing::{debug, instrument, trace};
 
+use crate::collector::autodiff::collect_autodiff_fn;
 use crate::errors::{
     self, EncounteredErrorWhileInstantiating, EncounteredErrorWhileInstantiatingGlobalAsm,
     NoOptimizedMir, RecursionLimit,
@@ -911,6 +914,8 @@ fn visit_instance_use<'tcx>(
         return;
     }
     if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) {
+        collect_autodiff_fn(tcx, instance, intrinsic, output);
+
         if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) {
             // The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will
             // be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any
diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs
new file mode 100644
index 00000000000..13868cca944
--- /dev/null
+++ b/compiler/rustc_monomorphize/src/collector/autodiff.rs
@@ -0,0 +1,48 @@
+use rustc_middle::bug;
+use rustc_middle::ty::{self, GenericArg, IntrinsicDef, TyCtxt};
+
+use crate::collector::{MonoItems, create_fn_mono_item};
+
+// Here, we force both primal and diff function to be collected in
+// mono so this does not interfere in `autodiff` intrinsics
+// codegen process. If they are unused, LLVM will remove them when
+// compiling with O3.
+pub(crate) fn collect_autodiff_fn<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    instance: ty::Instance<'tcx>,
+    intrinsic: IntrinsicDef,
+    output: &mut MonoItems<'tcx>,
+) {
+    if intrinsic.name != rustc_span::sym::autodiff {
+        return;
+    };
+
+    collect_autodiff_fn_from_arg(instance.args[0], tcx, output);
+}
+
+fn collect_autodiff_fn_from_arg<'tcx>(
+    arg: GenericArg<'tcx>,
+    tcx: TyCtxt<'tcx>,
+    output: &mut MonoItems<'tcx>,
+) {
+    let (instance, span) = match arg.kind() {
+        ty::GenericArgKind::Type(ty) => match ty.kind() {
+            ty::FnDef(def_id, substs) => {
+                let span = tcx.def_span(def_id);
+                let instance = ty::Instance::expect_resolve(
+                    tcx,
+                    ty::TypingEnv::non_body_analysis(tcx, def_id),
+                    *def_id,
+                    substs,
+                    span,
+                );
+
+                (instance, span)
+            }
+            _ => bug!("expected autodiff function"),
+        },
+        _ => bug!("expected type when matching autodiff arg"),
+    };
+
+    output.push(create_fn_mono_item(tcx, instance, span));
+}
diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs
index 628ea2b63de..d784d3540c4 100644
--- a/compiler/rustc_monomorphize/src/partitioning.rs
+++ b/compiler/rustc_monomorphize/src/partitioning.rs
@@ -92,8 +92,6 @@
 //! source-level module, functions from the same module will be available for
 //! inlining, even when they are not marked `#[inline]`.
 
-mod autodiff;
-
 use std::cmp;
 use std::collections::hash_map::Entry;
 use std::fs::{self, File};
@@ -251,17 +249,7 @@ where
             always_export_generics,
         );
 
-        // We can't differentiate a function that got inlined.
-        let autodiff_active = cfg!(llvm_enzyme)
-            && matches!(mono_item, MonoItem::Fn(_))
-            && cx
-                .tcx
-                .codegen_fn_attrs(mono_item.def_id())
-                .autodiff_item
-                .as_ref()
-                .is_some_and(|ad| ad.is_active());
-
-        if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
+        if visibility == Visibility::Hidden && can_be_internalized {
             internalization_candidates.insert(mono_item);
         }
         let size_estimate = mono_item.size_estimate(cx.tcx);
@@ -1157,27 +1145,15 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
         }
     }
 
-    #[cfg(not(llvm_enzyme))]
-    let autodiff_mono_items: Vec<_> = vec![];
-    #[cfg(llvm_enzyme)]
-    let mut autodiff_mono_items: Vec<_> = vec![];
     let mono_items: DefIdSet = items
         .iter()
         .filter_map(|mono_item| match *mono_item {
-            MonoItem::Fn(ref instance) => {
-                #[cfg(llvm_enzyme)]
-                autodiff_mono_items.push((mono_item, instance));
-                Some(instance.def_id())
-            }
+            MonoItem::Fn(ref instance) => Some(instance.def_id()),
             MonoItem::Static(def_id) => Some(def_id),
             _ => None,
         })
         .collect();
 
-    let autodiff_items =
-        autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items);
-    let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items);
-
     // Output monomorphization stats per def_id
     if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats
         && let Err(err) =
@@ -1235,11 +1211,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio
         }
     }
 
-    MonoItemPartitions {
-        all_mono_items: tcx.arena.alloc(mono_items),
-        codegen_units,
-        autodiff_items,
-    }
+    MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units }
 }
 
 /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s
diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs
deleted file mode 100644
index 22d593b80b8..00000000000
--- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs
+++ /dev/null
@@ -1,143 +0,0 @@
-use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
-use rustc_hir::def_id::LOCAL_CRATE;
-use rustc_middle::bug;
-use rustc_middle::mir::mono::MonoItem;
-use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
-use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
-use tracing::{debug, trace};
-
-use crate::partitioning::UsageMap;
-
-fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
-    if !matches!(fn_ty.kind(), ty::FnDef(..)) {
-        bug!("expected fn def for autodiff, got {:?}", fn_ty);
-    }
-
-    // We don't actually pass the types back into the type system.
-    // All we do is decide how to handle the arguments.
-    let sig = fn_ty.fn_sig(tcx).skip_binder();
-
-    let mut new_activities = vec![];
-    let mut new_positions = vec![];
-    for (i, ty) in sig.inputs().iter().enumerate() {
-        if let Some(inner_ty) = ty.builtin_deref(true) {
-            if inner_ty.is_slice() {
-                // Now we need to figure out the size of each slice element in memory to allow
-                // safety checks and usability improvements in the backend.
-                let sty = match inner_ty.builtin_index() {
-                    Some(sty) => sty,
-                    None => {
-                        panic!("slice element type unknown");
-                    }
-                };
-                let pci = PseudoCanonicalInput {
-                    typing_env: TypingEnv::fully_monomorphized(),
-                    value: sty,
-                };
-
-                let layout = tcx.layout_of(pci);
-                let elem_size = match layout {
-                    Ok(layout) => layout.size,
-                    Err(_) => {
-                        bug!("autodiff failed to compute slice element size");
-                    }
-                };
-                let elem_size: u32 = elem_size.bytes() as u32;
-
-                // We know that the length will be passed as extra arg.
-                if !da.is_empty() {
-                    // We are looking at a slice. The length of that slice will become an
-                    // extra integer on llvm level. Integers are always const.
-                    // However, if the slice get's duplicated, we want to know to later check the
-                    // size. So we mark the new size argument as FakeActivitySize.
-                    // There is one FakeActivitySize per slice, so for convenience we store the
-                    // slice element size in bytes in it. We will use the size in the backend.
-                    let activity = match da[i] {
-                        DiffActivity::DualOnly
-                        | DiffActivity::Dual
-                        | DiffActivity::Dualv
-                        | DiffActivity::DuplicatedOnly
-                        | DiffActivity::Duplicated => {
-                            DiffActivity::FakeActivitySize(Some(elem_size))
-                        }
-                        DiffActivity::Const => DiffActivity::Const,
-                        _ => bug!("unexpected activity for ptr/ref"),
-                    };
-                    new_activities.push(activity);
-                    new_positions.push(i + 1);
-                }
-
-                continue;
-            }
-        }
-    }
-    // now add the extra activities coming from slices
-    // Reverse order to not invalidate the indices
-    for _ in 0..new_activities.len() {
-        let pos = new_positions.pop().unwrap();
-        let activity = new_activities.pop().unwrap();
-        da.insert(pos, activity);
-    }
-}
-
-pub(crate) fn find_autodiff_source_functions<'tcx>(
-    tcx: TyCtxt<'tcx>,
-    usage_map: &UsageMap<'tcx>,
-    autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
-) -> Vec<AutoDiffItem> {
-    let mut autodiff_items: Vec<AutoDiffItem> = vec![];
-    for (item, instance) in autodiff_mono_items {
-        let target_id = instance.def_id();
-        let cg_fn_attr = &tcx.codegen_fn_attrs(target_id).autodiff_item;
-        let Some(target_attrs) = cg_fn_attr else {
-            continue;
-        };
-        let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
-        if target_attrs.is_source() {
-            trace!("source found: {:?}", target_id);
-        }
-        if !target_attrs.apply_autodiff() {
-            continue;
-        }
-
-        let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
-
-        let source =
-            usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
-                MonoItem::Fn(ref instance_s) => {
-                    let source_id = instance_s.def_id();
-                    if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
-                        && ad.is_active()
-                    {
-                        return Some(instance_s);
-                    }
-                    None
-                }
-                _ => None,
-            });
-        let inst = match source {
-            Some(source) => source,
-            None => continue,
-        };
-
-        debug!("source_id: {:?}", inst.def_id());
-        let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
-        assert!(fn_ty.is_fn());
-        adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
-        let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
-
-        let mut new_target_attrs = target_attrs.clone();
-        new_target_attrs.input_activity = input_activities;
-        let itm = new_target_attrs.into_item(symb, target_symbol);
-        autodiff_items.push(itm);
-    }
-
-    if !autodiff_items.is_empty() {
-        trace!("AUTODIFF ITEMS EXIST");
-        for item in &mut *autodiff_items {
-            trace!("{}", &item);
-        }
-    }
-
-    autodiff_items
-}