about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs55
-rw-r--r--tests/pretty/autodiff/autodiff_forward.pp4
2 files changed, 54 insertions, 5 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 36f84c5d74b..1ff4fc6aaab 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -305,6 +305,7 @@ mod llvm_enzyme {
         let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
         let d_body = gen_enzyme_body(
             ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
+            &generics,
         );
 
         // The first element of it is the name of the function to be generated
@@ -477,6 +478,7 @@ mod llvm_enzyme {
         new_decl_span: Span,
         idents: &[Ident],
         errored: bool,
+        generics: &Generics,
     ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
         let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
         let noop = ast::InlineAsm {
@@ -499,7 +501,7 @@ mod llvm_enzyme {
         };
         let unsf_expr = ecx.expr_block(P(unsf_block));
         let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
-        let primal_call = gen_primal_call(ecx, span, primal, idents);
+        let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
         let black_box_primal_call = ecx.expr_call(
             new_decl_span,
             blackbox_call_expr.clone(),
@@ -548,6 +550,7 @@ mod llvm_enzyme {
         sig_span: Span,
         idents: Vec<Ident>,
         errored: bool,
+        generics: &Generics,
     ) -> P<ast::Block> {
         let new_decl_span = d_sig.span;
 
@@ -568,6 +571,7 @@ mod llvm_enzyme {
             new_decl_span,
             &idents,
             errored,
+            generics,
         );
 
         if !has_ret(&d_sig.decl.output) {
@@ -610,7 +614,6 @@ mod llvm_enzyme {
                 panic!("Did not expect Default ret ty: {:?}", span);
             }
         };
-
         if x.mode.is_fwd() {
             // Fwd mode is easy. If the return activity is Const, we support arbitrary types.
             // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
@@ -670,8 +673,10 @@ mod llvm_enzyme {
         span: Span,
         primal: Ident,
         idents: &[Ident],
+        generics: &Generics,
     ) -> P<ast::Expr> {
         let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
+
         if has_self {
             let args: ThinVec<_> =
                 idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
@@ -680,7 +685,51 @@ mod llvm_enzyme {
         } else {
             let args: ThinVec<_> =
                 idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
-            let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
+            let mut primal_path = ecx.path_ident(span, primal);
+
+            let is_generic = !generics.params.is_empty();
+
+            match (is_generic, primal_path.segments.last_mut()) {
+                (true, Some(function_path)) => {
+                    let primal_generic_types = generics
+                        .params
+                        .iter()
+                        .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
+
+                    let generated_generic_types = primal_generic_types
+                        .map(|type_param| {
+                            let generic_param = TyKind::Path(
+                                None,
+                                ast::Path {
+                                    span,
+                                    segments: thin_vec![ast::PathSegment {
+                                        ident: type_param.ident,
+                                        args: None,
+                                        id: ast::DUMMY_NODE_ID,
+                                    }],
+                                    tokens: None,
+                                },
+                            );
+
+                            ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
+                                id: type_param.id,
+                                span,
+                                kind: generic_param,
+                                tokens: None,
+                            })))
+                        })
+                        .collect();
+
+                    function_path.args =
+                        Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
+                            span,
+                            args: generated_generic_types,
+                        })));
+                }
+                _ => {}
+            }
+
+            let primal_call_expr = ecx.expr_path(primal_path);
             ecx.expr_call(span, primal_call_expr, args)
         }
     }
diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp
index 2fa122a618d..8253603e807 100644
--- a/tests/pretty/autodiff/autodiff_forward.pp
+++ b/tests/pretty/autodiff/autodiff_forward.pp
@@ -191,8 +191,8 @@ pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
 pub fn d_square<T: std::ops::Mul<Output = T> +
     Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
     unsafe { asm!("NOP", options(pure, nomem)); };
-    ::core::hint::black_box(f10(x));
+    ::core::hint::black_box(f10::<T>(x));
     ::core::hint::black_box((dx_0, dret));
-    ::core::hint::black_box(f10(x))
+    ::core::hint::black_box(f10::<T>(x))
 }
 fn main() {}