about summary refs log tree commit diff
path: root/compiler/rustc_builtin_macros
diff options
context:
space:
mode:
authorMarcelo Domínguez <dmmarcelo27@gmail.com>2025-08-14 15:22:45 +0000
committerMarcelo Domínguez <dmmarcelo27@gmail.com>2025-08-14 16:29:58 +0000
commit5c631041aa0b0ad9e161b966b78e6dfdb8011023 (patch)
treecde51a410c02793fd39aa65fc5d68306666b469b /compiler/rustc_builtin_macros
parent30017c36d6b5e3382ee7cf018d330a6a4a937d39 (diff)
downloadrust-5c631041aa0b0ad9e161b966b78e6dfdb8011023.tar.gz
rust-5c631041aa0b0ad9e161b966b78e6dfdb8011023.zip
Basic implementation of `autodiff` intrinsic
Diffstat (limited to 'compiler/rustc_builtin_macros')
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs139
1 files changed, 133 insertions, 6 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index a662840eda5..3f8585d35bc 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -329,17 +329,22 @@ mod llvm_enzyme {
             .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
             .count() as u32;
         let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
-        let d_body = gen_enzyme_body(
+
+        // TODO(Sa4dUs): Remove this and all the related logic
+        let _d_body = gen_enzyme_body(
             ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
             &generics,
         );
 
+        let d_body =
+            call_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig);
+
         // The first element of it is the name of the function to be generated
         let asdf = Box::new(ast::Fn {
             defaultness: ast::Defaultness::Final,
             sig: d_sig,
             ident: first_ident(&meta_item_vec[0]),
-            generics,
+            generics: generics.clone(),
             contract: None,
             body: Some(d_body),
             define_opaque: None,
@@ -428,12 +433,15 @@ mod llvm_enzyme {
             tokens: ts,
         });
 
+        let vis_clone = vis.clone();
+
+        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
         let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
         let d_annotatable = match &item {
             Annotatable::AssocItem(_, _) => {
                 let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
                 let d_fn = Box::new(ast::AssocItem {
-                    attrs: thin_vec![d_attr, inline_never],
+                    attrs: thin_vec![d_attr],
                     id: ast::DUMMY_NODE_ID,
                     span,
                     vis,
@@ -443,13 +451,13 @@ mod llvm_enzyme {
                 Annotatable::AssocItem(d_fn, Impl { of_trait: false })
             }
             Annotatable::Item(_) => {
-                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
+                let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
                 d_fn.vis = vis;
 
                 Annotatable::Item(d_fn)
             }
             Annotatable::Stmt(_) => {
-                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
+                let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
                 d_fn.vis = vis;
 
                 Annotatable::Stmt(Box::new(ast::Stmt {
@@ -463,7 +471,9 @@ mod llvm_enzyme {
             }
         };
 
-        return vec![orig_annotatable, d_annotatable];
+        let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone);
+
+        return vec![orig_annotatable, dummy_const_annotatable, d_annotatable];
     }
 
     // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
@@ -484,6 +494,123 @@ mod llvm_enzyme {
         ty
     }
 
+    // Generate `autodiff` intrinsic call
+    // ```
+    // std::intrinsics::autodiff(source, diff, (args))
+    // ```
+    fn call_autodiff(
+        ecx: &ExtCtxt<'_>,
+        primal: Ident,
+        diff: Ident,
+        span: Span,
+        d_sig: &FnSig,
+    ) -> P<ast::Block> {
+        let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal));
+        let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff));
+
+        let tuple_expr = ecx.expr_tuple(
+            span,
+            d_sig
+                .decl
+                .inputs
+                .iter()
+                .map(|arg| match arg.pat.kind {
+                    PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
+                    _ => todo!(),
+                })
+                .collect::<ThinVec<_>>()
+                .into(),
+        );
+
+        let enzyme_path = ecx.path(
+            span,
+            vec![
+                Ident::from_str("std"),
+                Ident::from_str("intrinsics"),
+                Ident::from_str("autodiff"),
+            ],
+        );
+        let call_expr = ecx.expr_call(
+            span,
+            ecx.expr_path(enzyme_path),
+            vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
+        );
+
+        let block = ecx.block_expr(call_expr);
+
+        block
+    }
+
+    // Generate dummy const to prevent primal function
+    // from being optimized away before applying enzyme
+    // ```
+    // const _: () =
+    // {
+    //     #[used]
+    //     pub static DUMMY_PTR: fn_type = primal_fn;
+    // };
+    // ```
+    fn gen_dummy_const(
+        ecx: &ExtCtxt<'_>,
+        span: Span,
+        primal: Ident,
+        sig: FnSig,
+        generics: Generics,
+        vis: Visibility,
+    ) -> Annotatable {
+        // #[used]
+        let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used)));
+        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
+        let used_attr = outer_normal_attr(&used_attr, new_id, span);
+
+        // static DUMMY_PTR: <fn_type> = <primal_ident>
+        let static_ident = Ident::from_str_and_span("DUMMY_PTR", span);
+        let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy {
+            safety: sig.header.safety,
+            ext: sig.header.ext,
+            generic_params: generics.params,
+            decl: sig.decl,
+            decl_span: sig.span,
+        }));
+        let static_ty = ecx.ty(span, fn_ptr_ty);
+
+        let static_expr = ecx.expr_path(ecx.path(span, vec![primal]));
+        let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem {
+            ident: static_ident,
+            ty: static_ty,
+            safety: ast::Safety::Default,
+            mutability: ast::Mutability::Not,
+            expr: Some(static_expr),
+            define_opaque: None,
+        }));
+
+        let static_item = ast::Item {
+            attrs: thin_vec![used_attr],
+            id: ast::DUMMY_NODE_ID,
+            span,
+            vis,
+            kind: static_item_kind,
+            tokens: None,
+        };
+
+        let block_expr = ecx.expr_block(Box::new(ast::Block {
+            stmts: thin_vec![ecx.stmt_item(span, P(static_item))],
+            id: ast::DUMMY_NODE_ID,
+            rules: ast::BlockCheckMode::Default,
+            span,
+            tokens: None,
+        }));
+
+        let const_item = ecx.item_const(
+            span,
+            Ident::from_str_and_span("_", span),
+            ecx.ty(span, ast::TyKind::Tup(thin_vec![])),
+            block_expr,
+        );
+
+        Annotatable::Item(const_item)
+    }
+
     // Will generate a body of the type:
     // ```
     // {