about summary refs log tree commit diff
path: root/compiler/rustc_builtin_macros/src/autodiff.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_builtin_macros/src/autodiff.rs')
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs140
1 files changed, 109 insertions, 31 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 8c5c20c7af4..dc3bb8ab52a 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -73,10 +73,10 @@ mod llvm_enzyme {
     }
 
     // Get information about the function the macro is applied to
-    fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
+    fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
         match &iitem.kind {
-            ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
-                Some((iitem.vis.clone(), sig.clone(), ident.clone()))
+            ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
+                Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
             }
             _ => None,
         }
@@ -86,27 +86,23 @@ mod llvm_enzyme {
         ecx: &mut ExtCtxt<'_>,
         meta_item: &ThinVec<MetaItemInner>,
         has_ret: bool,
+        mode: DiffMode,
     ) -> AutoDiffAttrs {
         let dcx = ecx.sess.dcx();
-        let mode = name(&meta_item[1]);
-        let Ok(mode) = DiffMode::from_str(&mode) else {
-            dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
-            return AutoDiffAttrs::error();
-        };
 
         // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
         // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
-        let mut first_activity = 2;
+        let mut first_activity = 1;
 
-        let width = if let [_, _, x, ..] = &meta_item[..]
+        let width = if let [_, x, ..] = &meta_item[..]
             && let Some(x) = width(x)
         {
-            first_activity = 3;
+            first_activity = 2;
             match x.try_into() {
                 Ok(x) => x,
                 Err(_) => {
                     dcx.emit_err(errors::AutoDiffInvalidWidth {
-                        span: meta_item[2].span(),
+                        span: meta_item[1].span(),
                         width: x,
                     });
                     return AutoDiffAttrs::error();
@@ -165,6 +161,24 @@ mod llvm_enzyme {
         ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
     }
 
+    pub(crate) fn expand_forward(
+        ecx: &mut ExtCtxt<'_>,
+        expand_span: Span,
+        meta_item: &ast::MetaItem,
+        item: Annotatable,
+    ) -> Vec<Annotatable> {
+        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
+    }
+
+    pub(crate) fn expand_reverse(
+        ecx: &mut ExtCtxt<'_>,
+        expand_span: Span,
+        meta_item: &ast::MetaItem,
+        item: Annotatable,
+    ) -> Vec<Annotatable> {
+        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
+    }
+
     /// We expand the autodiff macro to generate a new placeholder function which passes
     /// type-checking and can be called by users. The function body of the placeholder function will
     /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +212,12 @@ mod llvm_enzyme {
     /// ```
     /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
     /// in CI.
-    pub(crate) fn expand(
+    pub(crate) fn expand_with_mode(
         ecx: &mut ExtCtxt<'_>,
         expand_span: Span,
         meta_item: &ast::MetaItem,
         mut item: Annotatable,
+        mode: DiffMode,
     ) -> Vec<Annotatable> {
         if cfg!(not(llvm_enzyme)) {
             ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
@@ -210,16 +225,18 @@ mod llvm_enzyme {
         }
         let dcx = ecx.sess.dcx();
 
-        // first get information about the annotable item:
-        let Some((vis, sig, primal)) = (match &item {
+        // first get information about the annotable item: visibility, signature, name and generic
+        // parameters.
+        // these will be used to generate the differentiated version of the function
+        let Some((vis, sig, primal, generics)) = (match &item {
             Annotatable::Item(iitem) => extract_item_info(iitem),
             Annotatable::Stmt(stmt) => match &stmt.kind {
                 ast::StmtKind::Item(iitem) => extract_item_info(iitem),
                 _ => None,
             },
             Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
-                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
-                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
+                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
+                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
                 }
                 _ => None,
             },
@@ -243,29 +260,41 @@ mod llvm_enzyme {
         // create TokenStream from vec elemtents:
         // meta_item doesn't have a .tokens field
         let mut ts: Vec<TokenTree> = vec![];
-        if meta_item_vec.len() < 2 {
-            // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
-            // input and output args.
+        if meta_item_vec.len() < 1 {
+            // At the bare minimum, we need a fnc name.
             dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
             return vec![item];
         }
 
-        meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
+        let mode_symbol = match mode {
+            DiffMode::Forward => sym::Forward,
+            DiffMode::Reverse => sym::Reverse,
+            _ => unreachable!("Unsupported mode: {:?}", mode),
+        };
+
+        // Insert mode token
+        let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
+        ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
+        ts.insert(
+            1,
+            TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
+        );
 
         // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
         // If it is not given, we default to 1 (scalar mode).
         let start_position;
         let kind: LitKind = LitKind::Integer;
         let symbol;
-        if meta_item_vec.len() >= 3
-            && let Some(width) = width(&meta_item_vec[2])
+        if meta_item_vec.len() >= 2
+            && let Some(width) = width(&meta_item_vec[1])
         {
-            start_position = 3;
+            start_position = 2;
             symbol = Symbol::intern(&width.to_string());
         } else {
-            start_position = 2;
+            start_position = 1;
             symbol = sym::integer(1);
         }
+
         let l: Lit = Lit { kind, symbol, suffix: None };
         let t = Token::new(TokenKind::Literal(l), Span::default());
         let comma = Token::new(TokenKind::Comma, Span::default());
@@ -287,7 +316,7 @@ mod llvm_enzyme {
         ts.pop();
         let ts: TokenStream = TokenStream::from_iter(ts);
 
-        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
+        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
         if !x.is_active() {
             // We encountered an error, so we return the original item.
             // This allows us to potentially parse other attributes.
@@ -303,6 +332,7 @@ mod llvm_enzyme {
         let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
         let d_body = gen_enzyme_body(
             ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
+            &generics,
         );
 
         // The first element of it is the name of the function to be generated
@@ -310,7 +340,7 @@ mod llvm_enzyme {
             defaultness: ast::Defaultness::Final,
             sig: d_sig,
             ident: first_ident(&meta_item_vec[0]),
-            generics: Generics::default(),
+            generics,
             contract: None,
             body: Some(d_body),
             define_opaque: None,
@@ -475,6 +505,7 @@ mod llvm_enzyme {
         new_decl_span: Span,
         idents: &[Ident],
         errored: bool,
+        generics: &Generics,
     ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
         let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
         let noop = ast::InlineAsm {
@@ -497,7 +528,7 @@ mod llvm_enzyme {
         };
         let unsf_expr = ecx.expr_block(P(unsf_block));
         let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
-        let primal_call = gen_primal_call(ecx, span, primal, idents);
+        let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
         let black_box_primal_call = ecx.expr_call(
             new_decl_span,
             blackbox_call_expr.clone(),
@@ -546,6 +577,7 @@ mod llvm_enzyme {
         sig_span: Span,
         idents: Vec<Ident>,
         errored: bool,
+        generics: &Generics,
     ) -> P<ast::Block> {
         let new_decl_span = d_sig.span;
 
@@ -566,6 +598,7 @@ mod llvm_enzyme {
             new_decl_span,
             &idents,
             errored,
+            generics,
         );
 
         if !has_ret(&d_sig.decl.output) {
@@ -608,7 +641,6 @@ mod llvm_enzyme {
                 panic!("Did not expect Default ret ty: {:?}", span);
             }
         };
-
         if x.mode.is_fwd() {
             // Fwd mode is easy. If the return activity is Const, we support arbitrary types.
             // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
@@ -668,8 +700,10 @@ mod llvm_enzyme {
         span: Span,
         primal: Ident,
         idents: &[Ident],
+        generics: &Generics,
     ) -> P<ast::Expr> {
         let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
+
         if has_self {
             let args: ThinVec<_> =
                 idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
@@ -678,7 +712,51 @@ mod llvm_enzyme {
         } else {
             let args: ThinVec<_> =
                 idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
-            let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
+            let mut primal_path = ecx.path_ident(span, primal);
+
+            let is_generic = !generics.params.is_empty();
+
+            match (is_generic, primal_path.segments.last_mut()) {
+                (true, Some(function_path)) => {
+                    let primal_generic_types = generics
+                        .params
+                        .iter()
+                        .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
+
+                    let generated_generic_types = primal_generic_types
+                        .map(|type_param| {
+                            let generic_param = TyKind::Path(
+                                None,
+                                ast::Path {
+                                    span,
+                                    segments: thin_vec![ast::PathSegment {
+                                        ident: type_param.ident,
+                                        args: None,
+                                        id: ast::DUMMY_NODE_ID,
+                                    }],
+                                    tokens: None,
+                                },
+                            );
+
+                            ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
+                                id: type_param.id,
+                                span,
+                                kind: generic_param,
+                                tokens: None,
+                            })))
+                        })
+                        .collect();
+
+                    function_path.args =
+                        Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
+                            span,
+                            args: generated_generic_types,
+                        })));
+                }
+                _ => {}
+            }
+
+            let primal_call_expr = ecx.expr_path(primal_path);
             ecx.expr_call(span, primal_call_expr, args)
         }
     }
@@ -966,4 +1044,4 @@ mod llvm_enzyme {
     }
 }
 
-pub(crate) use llvm_enzyme::expand;
+pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};