diff options
Diffstat (limited to 'compiler/rustc_builtin_macros/src/autodiff.rs')
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 140 |
1 files changed, 109 insertions, 31 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 8c5c20c7af4..dc3bb8ab52a 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -73,10 +73,10 @@ mod llvm_enzyme { } // Get information about the function the macro is applied to - fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> { + fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> { match &iitem.kind { - ItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((iitem.vis.clone(), sig.clone(), ident.clone())) + ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { + Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, } @@ -86,27 +86,23 @@ mod llvm_enzyme { ecx: &mut ExtCtxt<'_>, meta_item: &ThinVec<MetaItemInner>, has_ret: bool, + mode: DiffMode, ) -> AutoDiffAttrs { let dcx = ecx.sess.dcx(); - let mode = name(&meta_item[1]); - let Ok(mode) = DiffMode::from_str(&mode) else { - dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode }); - return AutoDiffAttrs::error(); - }; // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode. // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1. - let mut first_activity = 2; + let mut first_activity = 1; - let width = if let [_, _, x, ..] = &meta_item[..] + let width = if let [_, x, ..] = &meta_item[..] && let Some(x) = width(x) { - first_activity = 3; + first_activity = 2; match x.try_into() { Ok(x) => x, Err(_) => { dcx.emit_err(errors::AutoDiffInvalidWidth { - span: meta_item[2].span(), + span: meta_item[1].span(), width: x, }); return AutoDiffAttrs::error(); @@ -165,6 +161,24 @@ mod llvm_enzyme { ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); } + pub(crate) fn expand_forward( + ecx: &mut ExtCtxt<'_>, + expand_span: Span, + meta_item: &ast::MetaItem, + item: Annotatable, + ) -> Vec<Annotatable> { + expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward) + } + + pub(crate) fn expand_reverse( + ecx: &mut ExtCtxt<'_>, + expand_span: Span, + meta_item: &ast::MetaItem, + item: Annotatable, + ) -> Vec<Annotatable> { + expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse) + } + /// We expand the autodiff macro to generate a new placeholder function which passes /// type-checking and can be called by users. The function body of the placeholder function will /// later be replaced on LLVM-IR level, so the design of the body is less important and for now @@ -198,11 +212,12 @@ mod llvm_enzyme { /// ``` /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked /// in CI. - pub(crate) fn expand( + pub(crate) fn expand_with_mode( ecx: &mut ExtCtxt<'_>, expand_span: Span, meta_item: &ast::MetaItem, mut item: Annotatable, + mode: DiffMode, ) -> Vec<Annotatable> { if cfg!(not(llvm_enzyme)) { ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span }); @@ -210,16 +225,18 @@ mod llvm_enzyme { } let dcx = ecx.sess.dcx(); - // first get information about the annotable item: - let Some((vis, sig, primal)) = (match &item { + // first get information about the annotable item: visibility, signature, name and generic + // parameters. + // these will be used to generate the differentiated version of the function + let Some((vis, sig, primal, generics)) = (match &item { Annotatable::Item(iitem) => extract_item_info(iitem), Annotatable::Stmt(stmt) => match &stmt.kind { ast::StmtKind::Item(iitem) => extract_item_info(iitem), _ => None, }, Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone())) + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { + Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, }, @@ -243,29 +260,41 @@ mod llvm_enzyme { // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field let mut ts: Vec<TokenTree> = vec![]; - if meta_item_vec.len() < 2 { - // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no - // input and output args. + if meta_item_vec.len() < 1 { + // At the bare minimum, we need a fnc name. dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() }); return vec![item]; } - meta_item_inner_to_ts(&meta_item_vec[1], &mut ts); + let mode_symbol = match mode { + DiffMode::Forward => sym::Forward, + DiffMode::Reverse => sym::Reverse, + _ => unreachable!("Unsupported mode: {:?}", mode), + }; + + // Insert mode token + let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default()); + ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint)); + ts.insert( + 1, + TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone), + ); // Now, if the user gave a width (vector aka batch-mode ad), then we copy it. // If it is not given, we default to 1 (scalar mode). let start_position; let kind: LitKind = LitKind::Integer; let symbol; - if meta_item_vec.len() >= 3 - && let Some(width) = width(&meta_item_vec[2]) + if meta_item_vec.len() >= 2 + && let Some(width) = width(&meta_item_vec[1]) { - start_position = 3; + start_position = 2; symbol = Symbol::intern(&width.to_string()); } else { - start_position = 2; + start_position = 1; symbol = sym::integer(1); } + let l: Lit = Lit { kind, symbol, suffix: None }; let t = Token::new(TokenKind::Literal(l), Span::default()); let comma = Token::new(TokenKind::Comma, Span::default()); @@ -287,7 +316,7 @@ mod llvm_enzyme { ts.pop(); let ts: TokenStream = TokenStream::from_iter(ts); - let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret); + let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode); if !x.is_active() { // We encountered an error, so we return the original item. // This allows us to potentially parse other attributes. @@ -303,6 +332,7 @@ mod llvm_enzyme { let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); let d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, + &generics, ); // The first element of it is the name of the function to be generated @@ -310,7 +340,7 @@ mod llvm_enzyme { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics: Generics::default(), + generics, contract: None, body: Some(d_body), define_opaque: None, @@ -475,6 +505,7 @@ mod llvm_enzyme { new_decl_span: Span, idents: &[Ident], errored: bool, + generics: &Generics, ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) { let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); let noop = ast::InlineAsm { @@ -497,7 +528,7 @@ mod llvm_enzyme { }; let unsf_expr = ecx.expr_block(P(unsf_block)); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); - let primal_call = gen_primal_call(ecx, span, primal, idents); + let primal_call = gen_primal_call(ecx, span, primal, idents, generics); let black_box_primal_call = ecx.expr_call( new_decl_span, blackbox_call_expr.clone(), @@ -546,6 +577,7 @@ mod llvm_enzyme { sig_span: Span, idents: Vec<Ident>, errored: bool, + generics: &Generics, ) -> P<ast::Block> { let new_decl_span = d_sig.span; @@ -566,6 +598,7 @@ mod llvm_enzyme { new_decl_span, &idents, errored, + generics, ); if !has_ret(&d_sig.decl.output) { @@ -608,7 +641,6 @@ mod llvm_enzyme { panic!("Did not expect Default ret ty: {:?}", span); } }; - if x.mode.is_fwd() { // Fwd mode is easy. If the return activity is Const, we support arbitrary types. // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. @@ -668,8 +700,10 @@ mod llvm_enzyme { span: Span, primal: Ident, idents: &[Ident], + generics: &Generics, ) -> P<ast::Expr> { let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; + if has_self { let args: ThinVec<_> = idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); @@ -678,7 +712,51 @@ mod llvm_enzyme { } else { let args: ThinVec<_> = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let mut primal_path = ecx.path_ident(span, primal); + + let is_generic = !generics.params.is_empty(); + + match (is_generic, primal_path.segments.last_mut()) { + (true, Some(function_path)) => { + let primal_generic_types = generics + .params + .iter() + .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. })); + + let generated_generic_types = primal_generic_types + .map(|type_param| { + let generic_param = TyKind::Path( + None, + ast::Path { + span, + segments: thin_vec![ast::PathSegment { + ident: type_param.ident, + args: None, + id: ast::DUMMY_NODE_ID, + }], + tokens: None, + }, + ); + + ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty { + id: type_param.id, + span, + kind: generic_param, + tokens: None, + }))) + }) + .collect(); + + function_path.args = + Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs { + span, + args: generated_generic_types, + }))); + } + _ => {} + } + + let primal_call_expr = ecx.expr_path(primal_path); ecx.expr_call(span, primal_call_expr, args) } } @@ -966,4 +1044,4 @@ mod llvm_enzyme { } } -pub(crate) use llvm_enzyme::expand; +pub(crate) use llvm_enzyme::{expand_forward, expand_reverse}; |
