diff options
| author | Manuel Drehwald <git@manuel.drehwald.info> | 2025-04-03 17:21:21 -0400 |
|---|---|---|
| committer | Manuel Drehwald <git@manuel.drehwald.info> | 2025-04-03 17:21:21 -0400 |
| commit | e0c8ead8802fcc314ec852a6641e08fd09307708 (patch) | |
| tree | e1b7f37f5ef404b061adf4608f63a038123176ee /compiler/rustc_codegen_ssa/src | |
| parent | 087ffd73bf43671db22087570c97e0afdde20f31 (diff) | |
| download | rust-e0c8ead8802fcc314ec852a6641e08fd09307708.tar.gz rust-e0c8ead8802fcc314ec852a6641e08fd09307708.zip | |
add autodiff batching middle-end
Diffstat (limited to 'compiler/rustc_codegen_ssa/src')
| -rw-r--r-- | compiler/rustc_codegen_ssa/src/codegen_attrs.rs | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 4b2fd61442b..02f631c7b39 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -2,7 +2,7 @@ use std::str::FromStr; use rustc_abi::ExternAbi; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; -use rustc_ast::{MetaItem, MetaItemInner, attr}; +use rustc_ast::{LitKind, MetaItem, MetaItemInner, attr}; use rustc_attr_parsing::ReprAttr::ReprAlign; use rustc_attr_parsing::{AttributeKind, InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_data_structures::fx::FxHashMap; @@ -805,8 +805,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> { return Some(AutoDiffAttrs::source()); } - let [mode, input_activities @ .., ret_activity] = &list[..] else { - span_bug!(attr.span(), "rustc_autodiff attribute must contain mode and activities"); + let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else { + span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities"); }; let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode { p1.segments.first().unwrap().ident @@ -823,6 +823,30 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> { } }; + let width: u32 = match width_meta { + MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => { + let w = p1.segments.first().unwrap().ident; + match w.as_str().parse() { + Ok(val) => val, + Err(_) => { + span_bug!(w.span, "rustc_autodiff width should fit u32"); + } + } + } + MetaItemInner::Lit(lit) => { + if let LitKind::Int(val, _) = lit.kind { + match val.get().try_into() { + Ok(val) => val, + Err(_) => { + span_bug!(lit.span, "rustc_autodiff width should fit u32"); + } + } + } else { + span_bug!(lit.span, "rustc_autodiff width should be an integer"); + } + } + }; + // First read the ret symbol from the attribute let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity { p1.segments.first().unwrap().ident @@ -860,7 +884,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> { } } - Some(AutoDiffAttrs { mode, width: 1, ret_activity, input_activity: arg_activities }) + Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities }) } pub(crate) fn provide(providers: &mut Providers) { |
