about summary refs log tree commit diff
diff options
context:
space:
mode:
authorHaeNoe <git@haenoe.party>2025-04-19 19:17:22 +0200
committerHaeNoe <git@haenoe.party>2025-05-11 17:54:57 +0200
commit56a0c7dfea60a721948e29458fc714b6303e8e4a (patch)
tree6ab35e9b8bf482f914d0211da5c5bffd85497953
parent16c1c54a2921d5ace22e4a71c0ba7d4ef4b8aec7 (diff)
downloadrust-56a0c7dfea60a721948e29458fc714b6303e8e4a.tar.gz
rust-56a0c7dfea60a721948e29458fc714b6303e8e4a.zip
feat: propagate generics to generated function
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs18
1 files changed, 10 insertions, 8 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 8c5c20c7af4..36f84c5d74b 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -73,10 +73,10 @@ mod llvm_enzyme {
     }
 
     // Get information about the function the macro is applied to
-    fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
+    fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
         match &iitem.kind {
-            ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
-                Some((iitem.vis.clone(), sig.clone(), ident.clone()))
+            ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
+                Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
             }
             _ => None,
         }
@@ -210,16 +210,18 @@ mod llvm_enzyme {
         }
         let dcx = ecx.sess.dcx();
 
-        // first get information about the annotable item:
-        let Some((vis, sig, primal)) = (match &item {
+        // first get information about the annotable item: visibility, signature, name and generic
+        // parameters.
+        // these will be used to generate the differentiated version of the function
+        let Some((vis, sig, primal, generics)) = (match &item {
             Annotatable::Item(iitem) => extract_item_info(iitem),
             Annotatable::Stmt(stmt) => match &stmt.kind {
                 ast::StmtKind::Item(iitem) => extract_item_info(iitem),
                 _ => None,
             },
             Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
-                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
-                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
+                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
+                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
                 }
                 _ => None,
             },
@@ -310,7 +312,7 @@ mod llvm_enzyme {
             defaultness: ast::Defaultness::Final,
             sig: d_sig,
             ident: first_ident(&meta_item_vec[0]),
-            generics: Generics::default(),
+            generics,
             contract: None,
             body: Some(d_body),
             define_opaque: None,