diff options
Diffstat (limited to 'compiler/rustc_builtin_macros/src/autodiff.rs')
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 73 |
1 files changed, 62 insertions, 11 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 8c5c20c7af4..1ff4fc6aaab 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -73,10 +73,10 @@ mod llvm_enzyme { } // Get information about the function the macro is applied to - fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> { + fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> { match &iitem.kind { - ItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((iitem.vis.clone(), sig.clone(), ident.clone())) + ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { + Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, } @@ -210,16 +210,18 @@ mod llvm_enzyme { } let dcx = ecx.sess.dcx(); - // first get information about the annotable item: - let Some((vis, sig, primal)) = (match &item { + // first get information about the annotable item: visibility, signature, name and generic + // parameters. + // these will be used to generate the differentiated version of the function + let Some((vis, sig, primal, generics)) = (match &item { Annotatable::Item(iitem) => extract_item_info(iitem), Annotatable::Stmt(stmt) => match &stmt.kind { ast::StmtKind::Item(iitem) => extract_item_info(iitem), _ => None, }, Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone())) + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { + Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, }, @@ -303,6 +305,7 @@ mod llvm_enzyme { let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); let d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, + &generics, ); // The first element of it is the name of the function to be generated @@ -310,7 +313,7 @@ mod llvm_enzyme { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics: Generics::default(), + generics, contract: None, body: Some(d_body), define_opaque: None, @@ -475,6 +478,7 @@ mod llvm_enzyme { new_decl_span: Span, idents: &[Ident], errored: bool, + generics: &Generics, ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) { let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); let noop = ast::InlineAsm { @@ -497,7 +501,7 @@ mod llvm_enzyme { }; let unsf_expr = ecx.expr_block(P(unsf_block)); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); - let primal_call = gen_primal_call(ecx, span, primal, idents); + let primal_call = gen_primal_call(ecx, span, primal, idents, generics); let black_box_primal_call = ecx.expr_call( new_decl_span, blackbox_call_expr.clone(), @@ -546,6 +550,7 @@ mod llvm_enzyme { sig_span: Span, idents: Vec<Ident>, errored: bool, + generics: &Generics, ) -> P<ast::Block> { let new_decl_span = d_sig.span; @@ -566,6 +571,7 @@ mod llvm_enzyme { new_decl_span, &idents, errored, + generics, ); if !has_ret(&d_sig.decl.output) { @@ -608,7 +614,6 @@ mod llvm_enzyme { panic!("Did not expect Default ret ty: {:?}", span); } }; - if x.mode.is_fwd() { // Fwd mode is easy. If the return activity is Const, we support arbitrary types. // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. @@ -668,8 +673,10 @@ mod llvm_enzyme { span: Span, primal: Ident, idents: &[Ident], + generics: &Generics, ) -> P<ast::Expr> { let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; + if has_self { let args: ThinVec<_> = idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); @@ -678,7 +685,51 @@ mod llvm_enzyme { } else { let args: ThinVec<_> = idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let mut primal_path = ecx.path_ident(span, primal); + + let is_generic = !generics.params.is_empty(); + + match (is_generic, primal_path.segments.last_mut()) { + (true, Some(function_path)) => { + let primal_generic_types = generics + .params + .iter() + .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. })); + + let generated_generic_types = primal_generic_types + .map(|type_param| { + let generic_param = TyKind::Path( + None, + ast::Path { + span, + segments: thin_vec![ast::PathSegment { + ident: type_param.ident, + args: None, + id: ast::DUMMY_NODE_ID, + }], + tokens: None, + }, + ); + + ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty { + id: type_param.id, + span, + kind: generic_param, + tokens: None, + }))) + }) + .collect(); + + function_path.args = + Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs { + span, + args: generated_generic_types, + }))); + } + _ => {} + } + + let primal_call_expr = ecx.expr_path(primal_path); ecx.expr_call(span, primal_call_expr, args) } } |
