diff options
| author | HaeNoe <git@haenoe.party> | 2025-04-19 19:17:22 +0200 |
|---|---|---|
| committer | HaeNoe <git@haenoe.party> | 2025-05-11 17:54:57 +0200 |
| commit | 56a0c7dfea60a721948e29458fc714b6303e8e4a (patch) | |
| tree | 6ab35e9b8bf482f914d0211da5c5bffd85497953 | |
| parent | 16c1c54a2921d5ace22e4a71c0ba7d4ef4b8aec7 (diff) | |
| download | rust-56a0c7dfea60a721948e29458fc714b6303e8e4a.tar.gz rust-56a0c7dfea60a721948e29458fc714b6303e8e4a.zip | |
feat: propagate generics to generated function
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 8c5c20c7af4..36f84c5d74b 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -73,10 +73,10 @@ mod llvm_enzyme { } // Get information about the function the macro is applied to - fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> { + fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> { match &iitem.kind { - ItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((iitem.vis.clone(), sig.clone(), ident.clone())) + ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { + Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, } @@ -210,16 +210,18 @@ mod llvm_enzyme { } let dcx = ecx.sess.dcx(); - // first get information about the annotable item: - let Some((vis, sig, primal)) = (match &item { + // first get information about the annotable item: visibility, signature, name and generic + // parameters. + // these will be used to generate the differentiated version of the function + let Some((vis, sig, primal, generics)) = (match &item { Annotatable::Item(iitem) => extract_item_info(iitem), Annotatable::Stmt(stmt) => match &stmt.kind { ast::StmtKind::Item(iitem) => extract_item_info(iitem), _ => None, }, Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone())) + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { + Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, }, @@ -310,7 +312,7 @@ mod llvm_enzyme { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics: Generics::default(), + generics, contract: None, body: Some(d_body), define_opaque: None, |
