diff options
Diffstat (limited to 'compiler/rustc_codegen_ssa/src/codegen_attrs.rs')
| -rw-r--r-- | compiler/rustc_codegen_ssa/src/codegen_attrs.rs | 118 |
1 files changed, 117 insertions, 1 deletions
diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index a0bc2d4ea48..4166387dad0 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,5 +1,10 @@ +use std::str::FromStr; + use rustc_ast::attr::list_contains_name; -use rustc_ast::{MetaItemInner, attr}; +use rustc_ast::expand::autodiff_attrs::{ + AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, +}; +use rustc_ast::{MetaItem, MetaItemInner, attr}; use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_data_structures::fx::FxHashMap; use rustc_errors::codes::*; @@ -13,6 +18,7 @@ use rustc_middle::middle::codegen_fn_attrs::{ }; use rustc_middle::mir::mono::Linkage; use rustc_middle::query::Providers; +use rustc_middle::span_bug; use rustc_middle::ty::{self as ty, TyCtxt}; use rustc_session::parse::feature_err; use rustc_session::{Session, lint}; @@ -65,6 +71,13 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs { codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER; } + // If our rustc version supports autodiff/enzyme, then we call our handler + // to check for any `#[rustc_autodiff(...)]` attributes. + if cfg!(llvm_enzyme) { + let ad = autodiff_attrs(tcx, did.into()); + codegen_fn_attrs.autodiff_item = ad; + } + // When `no_builtins` is applied at the crate level, we should add the // `no-builtins` attribute to each function to ensure it takes effect in LTO. let crate_attrs = tcx.hir().attrs(rustc_hir::CRATE_HIR_ID); @@ -856,6 +869,109 @@ impl<'a> MixedExportNameAndNoMangleState<'a> { } } +/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)] +/// macros. There are two forms. The pure one without args to mark primal functions (the functions +/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the +/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never +/// panic, unless we introduced a bug when parsing the autodiff macro. +fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> { + let attrs = tcx.get_attrs(id, sym::rustc_autodiff); + + let attrs = + attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>(); + + // check for exactly one autodiff attribute on placeholder functions. + // There should only be one, since we generate a new placeholder per ad macro. + // FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but + // looks strange e.g. under cargo-expand. + let attr = match &attrs[..] { + [] => return None, + [attr] => attr, + // These two attributes are the same and unfortunately duplicated due to a previous bug. + [attr, _attr2] => attr, + _ => { + //FIXME(ZuseZ4): Once we fixed our parser, we should also prohibit the two-attribute + //branch above. + span_bug!(attrs[1].span, "cg_ssa: rustc_autodiff should only exist once per source"); + } + }; + + let list = attr.meta_item_list().unwrap_or_default(); + + // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions + if list.is_empty() { + return Some(AutoDiffAttrs::source()); + } + + let [mode, input_activities @ .., ret_activity] = &list[..] else { + span_bug!(attr.span, "rustc_autodiff attribute must contain mode and activities"); + }; + let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode { + p1.segments.first().unwrap().ident + } else { + span_bug!(attr.span, "rustc_autodiff attribute must contain mode"); + }; + + // parse mode + let mode = match mode.as_str() { + "Forward" => DiffMode::Forward, + "Reverse" => DiffMode::Reverse, + "ForwardFirst" => DiffMode::ForwardFirst, + "ReverseFirst" => DiffMode::ReverseFirst, + _ => { + span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode"); + } + }; + + // First read the ret symbol from the attribute + let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity { + p1.segments.first().unwrap().ident + } else { + span_bug!(attr.span, "rustc_autodiff attribute must contain the return activity"); + }; + + // Then parse it into an actual DiffActivity + let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else { + span_bug!(ret_symbol.span, "invalid return activity"); + }; + + // Now parse all the intermediate (input) activities + let mut arg_activities: Vec<DiffActivity> = vec![]; + for arg in input_activities { + let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg { + match p2.segments.first() { + Some(x) => x.ident, + None => { + span_bug!( + arg.span(), + "rustc_autodiff attribute must contain the input activity" + ); + } + } + } else { + span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity"); + }; + + match DiffActivity::from_str(arg_symbol.as_str()) { + Ok(arg_activity) => arg_activities.push(arg_activity), + Err(_) => { + span_bug!(arg_symbol.span, "invalid input activity"); + } + } + } + + for &input in &arg_activities { + if !valid_input_activity(mode, input) { + span_bug!(attr.span, "Invalid input activity {} for {} mode", input, mode); + } + } + if !valid_ret_activity(mode, ret_activity) { + span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode); + } + + Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities }) +} + pub(crate) fn provide(providers: &mut Providers) { *providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers }; } |
