From 56a0c7dfea60a721948e29458fc714b6303e8e4a Mon Sep 17 00:00:00 2001 From: HaeNoe Date: Sat, 19 Apr 2025 19:17:22 +0200 Subject: feat: propagate generics to generated function --- compiler/rustc_builtin_macros/src/autodiff.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) (limited to 'compiler/rustc_builtin_macros/src') diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 8c5c20c7af4..36f84c5d74b 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -73,10 +73,10 @@ mod llvm_enzyme { } // Get information about the function the macro is applied to - fn extract_item_info(iitem: &P) -> Option<(Visibility, FnSig, Ident)> { + fn extract_item_info(iitem: &P) -> Option<(Visibility, FnSig, Ident, Generics)> { match &iitem.kind { - ItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((iitem.vis.clone(), sig.clone(), ident.clone())) + ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { + Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, } @@ -210,16 +210,18 @@ mod llvm_enzyme { } let dcx = ecx.sess.dcx(); - // first get information about the annotable item: - let Some((vis, sig, primal)) = (match &item { + // first get information about the annotable item: visibility, signature, name and generic + // parameters. + // these will be used to generate the differentiated version of the function + let Some((vis, sig, primal, generics)) = (match &item { Annotatable::Item(iitem) => extract_item_info(iitem), Annotatable::Stmt(stmt) => match &stmt.kind { ast::StmtKind::Item(iitem) => extract_item_info(iitem), _ => None, }, Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone())) + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { + Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, }, @@ -310,7 +312,7 @@ mod llvm_enzyme { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics: Generics::default(), + generics, contract: None, body: Some(d_body), define_opaque: None, -- cgit 1.4.1-3-g733a5