about summary refs log tree commit diff
path: root/compiler/rustc_monomorphize
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_monomorphize')
-rw-r--r--compiler/rustc_monomorphize/Cargo.toml2
-rw-r--r--compiler/rustc_monomorphize/src/collector.rs36
-rw-r--r--compiler/rustc_monomorphize/src/mono_checks/abi_check.rs4
-rw-r--r--compiler/rustc_monomorphize/src/mono_checks/move_check.rs9
-rw-r--r--compiler/rustc_monomorphize/src/partitioning.rs59
-rw-r--r--compiler/rustc_monomorphize/src/partitioning/autodiff.rs121
6 files changed, 194 insertions, 37 deletions
diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml
index 9bdaeb015cd..5462105e5e8 100644
--- a/compiler/rustc_monomorphize/Cargo.toml
+++ b/compiler/rustc_monomorphize/Cargo.toml
@@ -6,6 +6,7 @@ edition = "2021"
 [dependencies]
 # tidy-alphabetical-start
 rustc_abi = { path = "../rustc_abi" }
+rustc_ast = { path = "../rustc_ast" }
 rustc_attr_parsing = { path = "../rustc_attr_parsing" }
 rustc_data_structures = { path = "../rustc_data_structures" }
 rustc_errors = { path = "../rustc_errors" }
@@ -15,6 +16,7 @@ rustc_macros = { path = "../rustc_macros" }
 rustc_middle = { path = "../rustc_middle" }
 rustc_session = { path = "../rustc_session" }
 rustc_span = { path = "../rustc_span" }
+rustc_symbol_mangling = { path = "../rustc_symbol_mangling" }
 rustc_target = { path = "../rustc_target" }
 serde = "1"
 serde_json = "1"
diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs
index bb603df1129..ae31ed59391 100644
--- a/compiler/rustc_monomorphize/src/collector.rs
+++ b/compiler/rustc_monomorphize/src/collector.rs
@@ -257,7 +257,7 @@ struct SharedState<'tcx> {
 
 pub(crate) struct UsageMap<'tcx> {
     // Maps every mono item to the mono items used by it.
-    used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
+    pub used_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
 
     // Maps every mono item to the mono items that use it.
     user_map: UnordMap<MonoItem<'tcx>, Vec<MonoItem<'tcx>>>,
@@ -814,6 +814,9 @@ impl<'a, 'tcx> MirVisitor<'tcx> for MirUsedCollector<'a, 'tcx> {
                 mir::AssertKind::MisalignedPointerDereference { .. } => {
                     push_mono_lang_item(self, LangItem::PanicMisalignedPointerDereference);
                 }
+                mir::AssertKind::NullPointerDereference => {
+                    push_mono_lang_item(self, LangItem::PanicNullPointerDereference);
+                }
                 _ => {
                     push_mono_lang_item(self, msg.panic_function());
                 }
@@ -950,7 +953,7 @@ fn visit_instance_use<'tcx>(
 
 /// Returns `true` if we should codegen an instance in the local crate, or returns `false` if we
 /// can just link to the upstream crate and therefore don't need a mono item.
-fn should_codegen_locally<'tcx>(tcx: TyCtxtAt<'tcx>, instance: Instance<'tcx>) -> bool {
+fn should_codegen_locally<'tcx>(tcx: TyCtxt<'tcx>, instance: Instance<'tcx>) -> bool {
     let Some(def_id) = instance.def.def_id_if_not_guaranteed_local_codegen() else {
         return true;
     };
@@ -973,7 +976,7 @@ fn should_codegen_locally<'tcx>(tcx: TyCtxtAt<'tcx>, instance: Instance<'tcx>) -
         return true;
     }
 
-    if tcx.is_reachable_non_generic(def_id) || instance.upstream_monomorphization(*tcx).is_some() {
+    if tcx.is_reachable_non_generic(def_id) || instance.upstream_monomorphization(tcx).is_some() {
         // We can link to the item in question, no instance needed in this crate.
         return false;
     }
@@ -1138,11 +1141,12 @@ fn create_mono_items_for_vtable_methods<'tcx>(
         bug!("create_mono_items_for_vtable_methods: {trait_ty:?} not a trait type");
     };
     if let Some(principal) = trait_ty.principal() {
-        let poly_trait_ref = principal.with_self_ty(tcx, impl_ty);
-        assert!(!poly_trait_ref.has_escaping_bound_vars());
+        let trait_ref =
+            tcx.instantiate_bound_regions_with_erased(principal.with_self_ty(tcx, impl_ty));
+        assert!(!trait_ref.has_escaping_bound_vars());
 
         // Walk all methods of the trait, including those of its supertraits
-        let entries = tcx.vtable_entries(poly_trait_ref);
+        let entries = tcx.vtable_entries(trait_ref);
         debug!(?entries);
         let methods = entries
             .iter()
@@ -1197,7 +1201,12 @@ fn collect_alloc<'tcx>(tcx: TyCtxt<'tcx>, alloc_id: AllocId, output: &mut MonoIt
             }
         }
         GlobalAlloc::VTable(ty, dyn_ty) => {
-            let alloc_id = tcx.vtable_allocation((ty, dyn_ty.principal()));
+            let alloc_id = tcx.vtable_allocation((
+                ty,
+                dyn_ty
+                    .principal()
+                    .map(|principal| tcx.instantiate_bound_regions_with_erased(principal)),
+            ));
             collect_alloc(tcx, alloc_id, output)
         }
     }
@@ -1213,7 +1222,7 @@ fn collect_items_of_instance<'tcx>(
     mode: CollectionMode,
 ) -> (MonoItems<'tcx>, MonoItems<'tcx>) {
     // This item is getting monomorphized, do mono-time checks.
-    tcx.ensure().check_mono_item(instance);
+    tcx.ensure_ok().check_mono_item(instance);
 
     let body = tcx.instance_mir(instance.def);
     // Naively, in "used" collection mode, all functions get added to *both* `used_items` and
@@ -1454,11 +1463,14 @@ impl<'v> RootCollector<'_, 'v> {
                 self.output.push(dummy_spanned(MonoItem::Static(def_id)));
             }
             DefKind::Const => {
-                // const items only generate mono items if they are
-                // actually used somewhere. Just declaring them is insufficient.
+                // Const items only generate mono items if they are actually used somewhere.
+                // Just declaring them is insufficient.
 
-                // but even just declaring them must collect the items they refer to
-                if let Ok(val) = self.tcx.const_eval_poly(id.owner_id.to_def_id()) {
+                // But even just declaring them must collect the items they refer to
+                // unless their generics require monomorphization.
+                if !self.tcx.generics_of(id.owner_id).requires_monomorphization(self.tcx)
+                    && let Ok(val) = self.tcx.const_eval_poly(id.owner_id.to_def_id())
+                {
                     collect_const_value(self.tcx, val, self.output);
                 }
             }
diff --git a/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs b/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
index 30e634d8252..a0be7f11d70 100644
--- a/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
+++ b/compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
@@ -1,5 +1,6 @@
 //! This module ensures that if a function's ABI requires a particular target feature,
 //! that target feature is enabled both on the callee and all callers.
+use rustc_abi::{BackendRepr, RegKind};
 use rustc_hir::CRATE_HIR_ID;
 use rustc_middle::mir::{self, traversal};
 use rustc_middle::ty::inherent::*;
@@ -7,8 +8,7 @@ use rustc_middle::ty::{self, Instance, InstanceKind, Ty, TyCtxt};
 use rustc_session::lint::builtin::ABI_UNSUPPORTED_VECTOR_TYPES;
 use rustc_span::def_id::DefId;
 use rustc_span::{DUMMY_SP, Span, Symbol};
-use rustc_target::abi::call::{FnAbi, PassMode};
-use rustc_target::abi::{BackendRepr, RegKind};
+use rustc_target::callconv::{FnAbi, PassMode};
 
 use crate::errors::{
     AbiErrorDisabledVectorTypeCall, AbiErrorDisabledVectorTypeDef,
diff --git a/compiler/rustc_monomorphize/src/mono_checks/move_check.rs b/compiler/rustc_monomorphize/src/mono_checks/move_check.rs
index 02b9397f9a9..838bfdab1ea 100644
--- a/compiler/rustc_monomorphize/src/mono_checks/move_check.rs
+++ b/compiler/rustc_monomorphize/src/mono_checks/move_check.rs
@@ -162,11 +162,12 @@ impl<'tcx> MoveCheckVisitor<'tcx> {
             // but correct span? This would make the lint at least accept crate-level lint attributes.
             return;
         };
-        self.tcx.emit_node_span_lint(LARGE_ASSIGNMENTS, lint_root, span, LargeAssignmentsLint {
+        self.tcx.emit_node_span_lint(
+            LARGE_ASSIGNMENTS,
+            lint_root,
             span,
-            size: too_large_size.bytes(),
-            limit: limit as u64,
-        });
+            LargeAssignmentsLint { span, size: too_large_size.bytes(), limit: limit as u64 },
+        );
         self.move_size_spans.push(span);
     }
 }
diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs
index 7b179663430..d826d03918e 100644
--- a/compiler/rustc_monomorphize/src/partitioning.rs
+++ b/compiler/rustc_monomorphize/src/partitioning.rs
@@ -92,6 +92,8 @@
 //! source-level module, functions from the same module will be available for
 //! inlining, even when they are not marked `#[inline]`.
 
+mod autodiff;
+
 use std::cmp;
 use std::collections::hash_map::Entry;
 use std::fs::{self, File};
@@ -110,7 +112,7 @@ use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags;
 use rustc_middle::middle::exported_symbols::{SymbolExportInfo, SymbolExportLevel};
 use rustc_middle::mir::mono::{
     CodegenUnit, CodegenUnitNameBuilder, InstantiationMode, Linkage, MonoItem, MonoItemData,
-    Visibility,
+    MonoItemPartitions, Visibility,
 };
 use rustc_middle::ty::print::{characteristic_def_id_of_type, with_no_trimmed_paths};
 use rustc_middle::ty::{self, InstanceKind, TyCtxt};
@@ -251,17 +253,23 @@ where
             can_export_generics,
             always_export_generics,
         );
-        if visibility == Visibility::Hidden && can_be_internalized {
+
+        // We can't differentiate something that got inlined.
+        let autodiff_active = cfg!(llvm_enzyme)
+            && cx
+                .tcx
+                .codegen_fn_attrs(mono_item.def_id())
+                .autodiff_item
+                .as_ref()
+                .is_some_and(|ad| ad.is_active());
+
+        if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized {
             internalization_candidates.insert(mono_item);
         }
         let size_estimate = mono_item.size_estimate(cx.tcx);
 
-        cgu.items_mut().insert(mono_item, MonoItemData {
-            inlined: false,
-            linkage,
-            visibility,
-            size_estimate,
-        });
+        cgu.items_mut()
+            .insert(mono_item, MonoItemData { inlined: false, linkage, visibility, size_estimate });
 
         // Get all inlined items that are reachable from `mono_item` without
         // going via another root item. This includes drop-glue, functions from
@@ -1114,7 +1122,7 @@ where
     }
 }
 
-fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[CodegenUnit<'_>]) {
+fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitions<'_> {
     let collection_strategy = match tcx.sess.opts.unstable_opts.print_mono_items {
         Some(ref s) => {
             let mode = s.to_lowercase();
@@ -1167,15 +1175,27 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co
         }
     }
 
+    #[cfg(not(llvm_enzyme))]
+    let autodiff_mono_items: Vec<_> = vec![];
+    #[cfg(llvm_enzyme)]
+    let mut autodiff_mono_items: Vec<_> = vec![];
     let mono_items: DefIdSet = items
         .iter()
         .filter_map(|mono_item| match *mono_item {
-            MonoItem::Fn(ref instance) => Some(instance.def_id()),
+            MonoItem::Fn(ref instance) => {
+                #[cfg(llvm_enzyme)]
+                autodiff_mono_items.push((mono_item, instance));
+                Some(instance.def_id())
+            }
             MonoItem::Static(def_id) => Some(def_id),
             _ => None,
         })
         .collect();
 
+    let autodiff_items =
+        autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items);
+    let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items);
+
     // Output monomorphization stats per def_id
     if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats {
         if let Err(err) =
@@ -1214,9 +1234,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co
                         Linkage::LinkOnceODR => "OnceODR",
                         Linkage::WeakAny => "WeakAny",
                         Linkage::WeakODR => "WeakODR",
-                        Linkage::Appending => "Appending",
                         Linkage::Internal => "Internal",
-                        Linkage::Private => "Private",
                         Linkage::ExternalWeak => "ExternalWeak",
                         Linkage::Common => "Common",
                     };
@@ -1236,7 +1254,11 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co
         }
     }
 
-    (tcx.arena.alloc(mono_items), codegen_units)
+    MonoItemPartitions {
+        all_mono_items: tcx.arena.alloc(mono_items),
+        codegen_units,
+        autodiff_items,
+    }
 }
 
 /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s
@@ -1319,14 +1341,13 @@ fn dump_mono_items_stats<'tcx>(
 pub(crate) fn provide(providers: &mut Providers) {
     providers.collect_and_partition_mono_items = collect_and_partition_mono_items;
 
-    providers.is_codegened_item = |tcx, def_id| {
-        let (all_mono_items, _) = tcx.collect_and_partition_mono_items(());
-        all_mono_items.contains(&def_id)
-    };
+    providers.is_codegened_item =
+        |tcx, def_id| tcx.collect_and_partition_mono_items(()).all_mono_items.contains(&def_id);
 
     providers.codegen_unit = |tcx, name| {
-        let (_, all) = tcx.collect_and_partition_mono_items(());
-        all.iter()
+        tcx.collect_and_partition_mono_items(())
+            .codegen_units
+            .iter()
             .find(|cgu| cgu.name() == name)
             .unwrap_or_else(|| panic!("failed to find cgu with name {name:?}"))
     };
diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs
new file mode 100644
index 00000000000..bce31bf0748
--- /dev/null
+++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs
@@ -0,0 +1,121 @@
+use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
+use rustc_hir::def_id::LOCAL_CRATE;
+use rustc_middle::bug;
+use rustc_middle::mir::mono::MonoItem;
+use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
+use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
+use tracing::{debug, trace};
+
+use crate::partitioning::UsageMap;
+
+fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
+    if !matches!(fn_ty.kind(), ty::FnDef(..)) {
+        bug!("expected fn def for autodiff, got {:?}", fn_ty);
+    }
+    let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
+
+    // If rustc compiles the unmodified primal, we know that this copy of the function
+    // also has correct lifetimes. We know that Enzyme won't free the shadow too early
+    // (or actually at all), so let's strip lifetimes when computing the layout.
+    let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);
+    let mut new_activities = vec![];
+    let mut new_positions = vec![];
+    for (i, ty) in x.inputs().iter().enumerate() {
+        if let Some(inner_ty) = ty.builtin_deref(true) {
+            if ty.is_fn_ptr() {
+                // FIXME(ZuseZ4): add a nicer error, or just figure out how to support them,
+                // since Enzyme itself can handle them.
+                tcx.dcx().err("function pointers are currently not supported in autodiff");
+            }
+            if inner_ty.is_slice() {
+                // We know that the length will be passed as extra arg.
+                if !da.is_empty() {
+                    // We are looking at a slice. The length of that slice will become an
+                    // extra integer on llvm level. Integers are always const.
+                    // However, if the slice get's duplicated, we want to know to later check the
+                    // size. So we mark the new size argument as FakeActivitySize.
+                    let activity = match da[i] {
+                        DiffActivity::DualOnly
+                        | DiffActivity::Dual
+                        | DiffActivity::DuplicatedOnly
+                        | DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
+                        DiffActivity::Const => DiffActivity::Const,
+                        _ => bug!("unexpected activity for ptr/ref"),
+                    };
+                    new_activities.push(activity);
+                    new_positions.push(i + 1);
+                }
+                continue;
+            }
+        }
+    }
+    // now add the extra activities coming from slices
+    // Reverse order to not invalidate the indices
+    for _ in 0..new_activities.len() {
+        let pos = new_positions.pop().unwrap();
+        let activity = new_activities.pop().unwrap();
+        da.insert(pos, activity);
+    }
+}
+
+pub(crate) fn find_autodiff_source_functions<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    usage_map: &UsageMap<'tcx>,
+    autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
+) -> Vec<AutoDiffItem> {
+    let mut autodiff_items: Vec<AutoDiffItem> = vec![];
+    for (item, instance) in autodiff_mono_items {
+        let target_id = instance.def_id();
+        let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone();
+        let Some(target_attrs) = cg_fn_attr else {
+            continue;
+        };
+        let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
+        if target_attrs.is_source() {
+            trace!("source found: {:?}", target_id);
+        }
+        if !target_attrs.apply_autodiff() {
+            continue;
+        }
+
+        let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
+
+        let source =
+            usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
+                MonoItem::Fn(ref instance_s) => {
+                    let source_id = instance_s.def_id();
+                    if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
+                        && ad.is_active()
+                    {
+                        return Some(instance_s);
+                    }
+                    None
+                }
+                _ => None,
+            });
+        let inst = match source {
+            Some(source) => source,
+            None => continue,
+        };
+
+        debug!("source_id: {:?}", inst.def_id());
+        let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
+        assert!(fn_ty.is_fn());
+        adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
+        let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
+
+        let mut new_target_attrs = target_attrs.clone();
+        new_target_attrs.input_activity = input_activities;
+        let itm = new_target_attrs.into_item(symb, target_symbol);
+        autodiff_items.push(itm);
+    }
+
+    if !autodiff_items.is_empty() {
+        trace!("AUTODIFF ITEMS EXIST");
+        for item in &mut *autodiff_items {
+            trace!("{}", &item);
+        }
+    }
+
+    autodiff_items
+}