diff options
| author | Marcelo DomÃnguez <dmmarcelo27@gmail.com> | 2025-08-14 15:22:45 +0000 |
|---|---|---|
| committer | Marcelo DomÃnguez <dmmarcelo27@gmail.com> | 2025-08-14 16:29:58 +0000 |
| commit | 5c631041aa0b0ad9e161b966b78e6dfdb8011023 (patch) | |
| tree | cde51a410c02793fd39aa65fc5d68306666b469b /compiler/rustc_builtin_macros | |
| parent | 30017c36d6b5e3382ee7cf018d330a6a4a937d39 (diff) | |
| download | rust-5c631041aa0b0ad9e161b966b78e6dfdb8011023.tar.gz rust-5c631041aa0b0ad9e161b966b78e6dfdb8011023.zip | |
Basic implementation of `autodiff` intrinsic
Diffstat (limited to 'compiler/rustc_builtin_macros')
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 139 |
1 files changed, 133 insertions, 6 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index a662840eda5..3f8585d35bc 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -329,17 +329,22 @@ mod llvm_enzyme { .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) .count() as u32; let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - let d_body = gen_enzyme_body( + + // TODO(Sa4dUs): Remove this and all the related logic + let _d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, &generics, ); + let d_body = + call_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig); + // The first element of it is the name of the function to be generated let asdf = Box::new(ast::Fn { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics, + generics: generics.clone(), contract: None, body: Some(d_body), define_opaque: None, @@ -428,12 +433,15 @@ mod llvm_enzyme { tokens: ts, }); + let vis_clone = vis.clone(); + + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = Box::new(ast::AssocItem { - attrs: thin_vec![d_attr, inline_never], + attrs: thin_vec![d_attr], id: ast::DUMMY_NODE_ID, span, vis, @@ -443,13 +451,13 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(Box::new(ast::Stmt { @@ -463,7 +471,9 @@ mod llvm_enzyme { } }; - return vec![orig_annotatable, d_annotatable]; + let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone); + + return vec![orig_annotatable, dummy_const_annotatable, d_annotatable]; } // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be @@ -484,6 +494,123 @@ mod llvm_enzyme { ty } + // Generate `autodiff` intrinsic call + // ``` + // std::intrinsics::autodiff(source, diff, (args)) + // ``` + fn call_autodiff( + ecx: &ExtCtxt<'_>, + primal: Ident, + diff: Ident, + span: Span, + d_sig: &FnSig, + ) -> P<ast::Block> { + let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff)); + + let tuple_expr = ecx.expr_tuple( + span, + d_sig + .decl + .inputs + .iter() + .map(|arg| match arg.pat.kind { + PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)), + _ => todo!(), + }) + .collect::<ThinVec<_>>() + .into(), + ); + + let enzyme_path = ecx.path( + span, + vec![ + Ident::from_str("std"), + Ident::from_str("intrinsics"), + Ident::from_str("autodiff"), + ], + ); + let call_expr = ecx.expr_call( + span, + ecx.expr_path(enzyme_path), + vec![primal_path_expr, diff_path_expr, tuple_expr].into(), + ); + + let block = ecx.block_expr(call_expr); + + block + } + + // Generate dummy const to prevent primal function + // from being optimized away before applying enzyme + // ``` + // const _: () = + // { + // #[used] + // pub static DUMMY_PTR: fn_type = primal_fn; + // }; + // ``` + fn gen_dummy_const( + ecx: &ExtCtxt<'_>, + span: Span, + primal: Ident, + sig: FnSig, + generics: Generics, + vis: Visibility, + ) -> Annotatable { + // #[used] + let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used))); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let used_attr = outer_normal_attr(&used_attr, new_id, span); + + // static DUMMY_PTR: <fn_type> = <primal_ident> + let static_ident = Ident::from_str_and_span("DUMMY_PTR", span); + let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy { + safety: sig.header.safety, + ext: sig.header.ext, + generic_params: generics.params, + decl: sig.decl, + decl_span: sig.span, + })); + let static_ty = ecx.ty(span, fn_ptr_ty); + + let static_expr = ecx.expr_path(ecx.path(span, vec![primal])); + let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem { + ident: static_ident, + ty: static_ty, + safety: ast::Safety::Default, + mutability: ast::Mutability::Not, + expr: Some(static_expr), + define_opaque: None, + })); + + let static_item = ast::Item { + attrs: thin_vec![used_attr], + id: ast::DUMMY_NODE_ID, + span, + vis, + kind: static_item_kind, + tokens: None, + }; + + let block_expr = ecx.expr_block(Box::new(ast::Block { + stmts: thin_vec![ecx.stmt_item(span, P(static_item))], + id: ast::DUMMY_NODE_ID, + rules: ast::BlockCheckMode::Default, + span, + tokens: None, + })); + + let const_item = ecx.item_const( + span, + Ident::from_str_and_span("_", span), + ecx.ty(span, ast::TyKind::Tup(thin_vec![])), + block_expr, + ); + + Annotatable::Item(const_item) + } + // Will generate a body of the type: // ``` // { |
