about summary refs log tree commit diff
path: root/compiler/rustc_builtin_macros/src/autodiff.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_builtin_macros/src/autodiff.rs')
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs73
1 files changed, 62 insertions, 11 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 8c5c20c7af4..1ff4fc6aaab 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -73,10 +73,10 @@ mod llvm_enzyme {
     }
 
     // Get information about the function the macro is applied to
-    fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
+    fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
         match &iitem.kind {
-            ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
-                Some((iitem.vis.clone(), sig.clone(), ident.clone()))
+            ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
+                Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
             }
             _ => None,
         }
@@ -210,16 +210,18 @@ mod llvm_enzyme {
         }
         let dcx = ecx.sess.dcx();
 
-        // first get information about the annotable item:
-        let Some((vis, sig, primal)) = (match &item {
+        // first get information about the annotable item: visibility, signature, name and generic
+        // parameters.
+        // these will be used to generate the differentiated version of the function
+        let Some((vis, sig, primal, generics)) = (match &item {
             Annotatable::Item(iitem) => extract_item_info(iitem),
             Annotatable::Stmt(stmt) => match &stmt.kind {
                 ast::StmtKind::Item(iitem) => extract_item_info(iitem),
                 _ => None,
             },
             Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
-                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
-                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
+                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
+                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
                 }
                 _ => None,
             },
@@ -303,6 +305,7 @@ mod llvm_enzyme {
         let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
         let d_body = gen_enzyme_body(
             ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
+            &generics,
         );
 
         // The first element of it is the name of the function to be generated
@@ -310,7 +313,7 @@ mod llvm_enzyme {
             defaultness: ast::Defaultness::Final,
             sig: d_sig,
             ident: first_ident(&meta_item_vec[0]),
-            generics: Generics::default(),
+            generics,
             contract: None,
             body: Some(d_body),
             define_opaque: None,
@@ -475,6 +478,7 @@ mod llvm_enzyme {
         new_decl_span: Span,
         idents: &[Ident],
         errored: bool,
+        generics: &Generics,
     ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
         let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
         let noop = ast::InlineAsm {
@@ -497,7 +501,7 @@ mod llvm_enzyme {
         };
         let unsf_expr = ecx.expr_block(P(unsf_block));
         let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
-        let primal_call = gen_primal_call(ecx, span, primal, idents);
+        let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
         let black_box_primal_call = ecx.expr_call(
             new_decl_span,
             blackbox_call_expr.clone(),
@@ -546,6 +550,7 @@ mod llvm_enzyme {
         sig_span: Span,
         idents: Vec<Ident>,
         errored: bool,
+        generics: &Generics,
     ) -> P<ast::Block> {
         let new_decl_span = d_sig.span;
 
@@ -566,6 +571,7 @@ mod llvm_enzyme {
             new_decl_span,
             &idents,
             errored,
+            generics,
         );
 
         if !has_ret(&d_sig.decl.output) {
@@ -608,7 +614,6 @@ mod llvm_enzyme {
                 panic!("Did not expect Default ret ty: {:?}", span);
             }
         };
-
         if x.mode.is_fwd() {
             // Fwd mode is easy. If the return activity is Const, we support arbitrary types.
             // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
@@ -668,8 +673,10 @@ mod llvm_enzyme {
         span: Span,
         primal: Ident,
         idents: &[Ident],
+        generics: &Generics,
     ) -> P<ast::Expr> {
         let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
+
         if has_self {
             let args: ThinVec<_> =
                 idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
@@ -678,7 +685,51 @@ mod llvm_enzyme {
         } else {
             let args: ThinVec<_> =
                 idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
-            let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
+            let mut primal_path = ecx.path_ident(span, primal);
+
+            let is_generic = !generics.params.is_empty();
+
+            match (is_generic, primal_path.segments.last_mut()) {
+                (true, Some(function_path)) => {
+                    let primal_generic_types = generics
+                        .params
+                        .iter()
+                        .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
+
+                    let generated_generic_types = primal_generic_types
+                        .map(|type_param| {
+                            let generic_param = TyKind::Path(
+                                None,
+                                ast::Path {
+                                    span,
+                                    segments: thin_vec![ast::PathSegment {
+                                        ident: type_param.ident,
+                                        args: None,
+                                        id: ast::DUMMY_NODE_ID,
+                                    }],
+                                    tokens: None,
+                                },
+                            );
+
+                            ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
+                                id: type_param.id,
+                                span,
+                                kind: generic_param,
+                                tokens: None,
+                            })))
+                        })
+                        .collect();
+
+                    function_path.args =
+                        Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
+                            span,
+                            args: generated_generic_types,
+                        })));
+                }
+                _ => {}
+            }
+
+            let primal_call_expr = ecx.expr_path(primal_path);
             ecx.expr_call(span, primal_call_expr, args)
         }
     }