about summary refs log tree commit diff
path: root/compiler/rustc_codegen_ssa/src/codegen_attrs.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_ssa/src/codegen_attrs.rs')
-rw-r--r--compiler/rustc_codegen_ssa/src/codegen_attrs.rs118
1 files changed, 117 insertions, 1 deletions
diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs
index a0bc2d4ea48..4166387dad0 100644
--- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs
+++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs
@@ -1,5 +1,10 @@
+use std::str::FromStr;
+
 use rustc_ast::attr::list_contains_name;
-use rustc_ast::{MetaItemInner, attr};
+use rustc_ast::expand::autodiff_attrs::{
+    AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
+};
+use rustc_ast::{MetaItem, MetaItemInner, attr};
 use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
 use rustc_data_structures::fx::FxHashMap;
 use rustc_errors::codes::*;
@@ -13,6 +18,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
 };
 use rustc_middle::mir::mono::Linkage;
 use rustc_middle::query::Providers;
+use rustc_middle::span_bug;
 use rustc_middle::ty::{self as ty, TyCtxt};
 use rustc_session::parse::feature_err;
 use rustc_session::{Session, lint};
@@ -65,6 +71,13 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs {
         codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER;
     }
 
+    // If our rustc version supports autodiff/enzyme, then we call our handler
+    // to check for any `#[rustc_autodiff(...)]` attributes.
+    if cfg!(llvm_enzyme) {
+        let ad = autodiff_attrs(tcx, did.into());
+        codegen_fn_attrs.autodiff_item = ad;
+    }
+
     // When `no_builtins` is applied at the crate level, we should add the
     // `no-builtins` attribute to each function to ensure it takes effect in LTO.
     let crate_attrs = tcx.hir().attrs(rustc_hir::CRATE_HIR_ID);
@@ -856,6 +869,109 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
     }
 }
 
+/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)]
+/// macros. There are two forms. The pure one without args to mark primal functions (the functions
+/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
+/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
+/// panic, unless we introduced a bug when parsing the autodiff macro.
+fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
+    let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
+
+    let attrs =
+        attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>();
+
+    // check for exactly one autodiff attribute on placeholder functions.
+    // There should only be one, since we generate a new placeholder per ad macro.
+    // FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but
+    // looks strange e.g. under cargo-expand.
+    let attr = match &attrs[..] {
+        [] => return None,
+        [attr] => attr,
+        // These two attributes are the same and unfortunately duplicated due to a previous bug.
+        [attr, _attr2] => attr,
+        _ => {
+            //FIXME(ZuseZ4): Once we fixed our parser, we should also prohibit the two-attribute
+            //branch above.
+            span_bug!(attrs[1].span, "cg_ssa: rustc_autodiff should only exist once per source");
+        }
+    };
+
+    let list = attr.meta_item_list().unwrap_or_default();
+
+    // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
+    if list.is_empty() {
+        return Some(AutoDiffAttrs::source());
+    }
+
+    let [mode, input_activities @ .., ret_activity] = &list[..] else {
+        span_bug!(attr.span, "rustc_autodiff attribute must contain mode and activities");
+    };
+    let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
+        p1.segments.first().unwrap().ident
+    } else {
+        span_bug!(attr.span, "rustc_autodiff attribute must contain mode");
+    };
+
+    // parse mode
+    let mode = match mode.as_str() {
+        "Forward" => DiffMode::Forward,
+        "Reverse" => DiffMode::Reverse,
+        "ForwardFirst" => DiffMode::ForwardFirst,
+        "ReverseFirst" => DiffMode::ReverseFirst,
+        _ => {
+            span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
+        }
+    };
+
+    // First read the ret symbol from the attribute
+    let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
+        p1.segments.first().unwrap().ident
+    } else {
+        span_bug!(attr.span, "rustc_autodiff attribute must contain the return activity");
+    };
+
+    // Then parse it into an actual DiffActivity
+    let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
+        span_bug!(ret_symbol.span, "invalid return activity");
+    };
+
+    // Now parse all the intermediate (input) activities
+    let mut arg_activities: Vec<DiffActivity> = vec![];
+    for arg in input_activities {
+        let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
+            match p2.segments.first() {
+                Some(x) => x.ident,
+                None => {
+                    span_bug!(
+                        arg.span(),
+                        "rustc_autodiff attribute must contain the input activity"
+                    );
+                }
+            }
+        } else {
+            span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity");
+        };
+
+        match DiffActivity::from_str(arg_symbol.as_str()) {
+            Ok(arg_activity) => arg_activities.push(arg_activity),
+            Err(_) => {
+                span_bug!(arg_symbol.span, "invalid input activity");
+            }
+        }
+    }
+
+    for &input in &arg_activities {
+        if !valid_input_activity(mode, input) {
+            span_bug!(attr.span, "Invalid input activity {} for {} mode", input, mode);
+        }
+    }
+    if !valid_ret_activity(mode, ret_activity) {
+        span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode);
+    }
+
+    Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
+}
+
 pub(crate) fn provide(providers: &mut Providers) {
     *providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
 }