diff options
| author | Marcelo DomÃnguez <dmmarcelo27@gmail.com> | 2025-08-14 15:27:57 +0000 |
|---|---|---|
| committer | Marcelo DomÃnguez <dmmarcelo27@gmail.com> | 2025-08-14 16:30:15 +0000 |
| commit | 250d77e5d72fde69a6406050a3b037635f685378 (patch) | |
| tree | 67749136fca27852b5fb784c864f7d3564a42a09 /compiler/rustc_builtin_macros/src | |
| parent | 5c631041aa0b0ad9e161b966b78e6dfdb8011023 (diff) | |
| download | rust-250d77e5d72fde69a6406050a3b037635f685378.tar.gz rust-250d77e5d72fde69a6406050a3b037635f685378.zip | |
Complete functionality and general cleanup
Diffstat (limited to 'compiler/rustc_builtin_macros/src')
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 492 |
1 files changed, 91 insertions, 401 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 3f8585d35bc..c260dca87c0 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -15,11 +15,12 @@ mod llvm_enzyme { use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ - self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind, - MetaItemInner, PatKind, QSelf, TyKind, Visibility, + self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode, + FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind, + MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; - use rustc_span::{Ident, Span, Symbol, kw, sym}; + use rustc_span::{Ident, Span, Symbol, sym}; use thin_vec::{ThinVec, thin_vec}; use tracing::{debug, trace}; @@ -179,11 +180,8 @@ mod llvm_enzyme { } /// We expand the autodiff macro to generate a new placeholder function which passes - /// type-checking and can be called by users. The function body of the placeholder function will - /// later be replaced on LLVM-IR level, so the design of the body is less important and for now - /// should just prevent early inlining and optimizations which alter the function signature. - /// The exact signature of the generated function depends on the configuration provided by the - /// user, but here is an example: + /// type-checking and can be called by users. The exact signature of the generated function + /// depends on the configuration provided by the user, but here is an example: /// /// ``` /// #[autodiff(cos_box, Reverse, Duplicated, Active)] @@ -199,14 +197,8 @@ mod llvm_enzyme { /// f32::sin(**x) /// } /// #[rustc_autodiff(Reverse, Duplicated, Active)] - /// #[inline(never)] /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 { - /// unsafe { - /// asm!("NOP"); - /// }; - /// ::core::hint::black_box(sin(x)); - /// ::core::hint::black_box((dx, dret)); - /// ::core::hint::black_box(sin(x)) + /// std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret)) /// } /// ``` /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked @@ -227,16 +219,24 @@ mod llvm_enzyme { // first get information about the annotable item: visibility, signature, name and generic // parameters. // these will be used to generate the differentiated version of the function - let Some((vis, sig, primal, generics)) = (match &item { - Annotatable::Item(iitem) => extract_item_info(iitem), + let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item { + Annotatable::Item(iitem) => { + extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false)) + } Annotatable::Stmt(stmt) => match &stmt.kind { - ast::StmtKind::Item(iitem) => extract_item_info(iitem), + ast::StmtKind::Item(iitem) => { + extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false)) + } _ => None, }, - Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) - } + Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some(( + assoc_item.vis.clone(), + sig.clone(), + ident.clone(), + generics.clone(), + *of_trait, + )), _ => None, }, _ => None, @@ -254,7 +254,6 @@ mod llvm_enzyme { }; let has_ret = has_ret(&sig.decl.output); - let sig_span = ecx.with_call_site_ctxt(sig.span); // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field @@ -323,28 +322,27 @@ mod llvm_enzyme { } let span = ecx.with_def_site_ctxt(expand_span); - let n_active: u32 = x - .input_activity - .iter() - .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) - .count() as u32; - let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - - // TODO(Sa4dUs): Remove this and all the related logic - let _d_body = gen_enzyme_body( - ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, - &generics, - ); + let d_sig = gen_enzyme_decl(ecx, &sig, &x, span); - let d_body = - call_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig); + let d_body = ecx.block( + span, + thin_vec![call_autodiff( + ecx, + primal, + first_ident(&meta_item_vec[0]), + span, + &d_sig, + &generics, + impl_of_trait, + )], + ); // The first element of it is the name of the function to be generated - let asdf = Box::new(ast::Fn { + let d_fn = Box::new(ast::Fn { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics: generics.clone(), + generics, contract: None, body: Some(d_body), define_opaque: None, @@ -433,13 +431,11 @@ mod llvm_enzyme { tokens: ts, }); - let vis_clone = vis.clone(); - let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { - let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); + let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn); let d_fn = Box::new(ast::AssocItem { attrs: thin_vec![d_attr], id: ast::DUMMY_NODE_ID, @@ -451,13 +447,13 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn)); d_fn.vis = vis; Annotatable::Stmt(Box::new(ast::Stmt { @@ -471,9 +467,7 @@ mod llvm_enzyme { } }; - let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone); - - return vec![orig_annotatable, dummy_const_annotatable, d_annotatable]; + return vec![orig_annotatable, d_annotatable]; } // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be @@ -504,9 +498,11 @@ mod llvm_enzyme { diff: Ident, span: Span, d_sig: &FnSig, - ) -> P<ast::Block> { - let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal)); - let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff)); + generics: &Generics, + is_impl: bool, + ) -> rustc_ast::Stmt { + let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl); + let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl); let tuple_expr = ecx.expr_tuple( span, @@ -522,371 +518,65 @@ mod llvm_enzyme { .into(), ); - let enzyme_path = ecx.path( - span, - vec![ - Ident::from_str("std"), - Ident::from_str("intrinsics"), - Ident::from_str("autodiff"), - ], - ); + let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::autodiff]); + let enzyme_path = ecx.path(span, enzyme_path_idents); let call_expr = ecx.expr_call( span, ecx.expr_path(enzyme_path), vec![primal_path_expr, diff_path_expr, tuple_expr].into(), ); - let block = ecx.block_expr(call_expr); - - block - } - - // Generate dummy const to prevent primal function - // from being optimized away before applying enzyme - // ``` - // const _: () = - // { - // #[used] - // pub static DUMMY_PTR: fn_type = primal_fn; - // }; - // ``` - fn gen_dummy_const( - ecx: &ExtCtxt<'_>, - span: Span, - primal: Ident, - sig: FnSig, - generics: Generics, - vis: Visibility, - ) -> Annotatable { - // #[used] - let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used))); - let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); - let used_attr = outer_normal_attr(&used_attr, new_id, span); - - // static DUMMY_PTR: <fn_type> = <primal_ident> - let static_ident = Ident::from_str_and_span("DUMMY_PTR", span); - let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy { - safety: sig.header.safety, - ext: sig.header.ext, - generic_params: generics.params, - decl: sig.decl, - decl_span: sig.span, - })); - let static_ty = ecx.ty(span, fn_ptr_ty); - - let static_expr = ecx.expr_path(ecx.path(span, vec![primal])); - let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem { - ident: static_ident, - ty: static_ty, - safety: ast::Safety::Default, - mutability: ast::Mutability::Not, - expr: Some(static_expr), - define_opaque: None, - })); - - let static_item = ast::Item { - attrs: thin_vec![used_attr], - id: ast::DUMMY_NODE_ID, - span, - vis, - kind: static_item_kind, - tokens: None, - }; - - let block_expr = ecx.expr_block(Box::new(ast::Block { - stmts: thin_vec![ecx.stmt_item(span, P(static_item))], - id: ast::DUMMY_NODE_ID, - rules: ast::BlockCheckMode::Default, - span, - tokens: None, - })); - - let const_item = ecx.item_const( - span, - Ident::from_str_and_span("_", span), - ecx.ty(span, ast::TyKind::Tup(thin_vec![])), - block_expr, - ); - - Annotatable::Item(const_item) + ecx.stmt_expr(call_expr) } - // Will generate a body of the type: - // ``` - // { - // unsafe { - // asm!("NOP"); - // } - // ::core::hint::black_box(primal(args)); - // ::core::hint::black_box((args, ret)); - // <This part remains to be done by following function> - // } - // ``` - fn init_body_helper( + // Generate turbofish expression from fn name and generics + // Given `foo` and `<A, B, C>` params, gen `foo::<A, B, C>` + // We use this expression when passing primal and diff function to the autodiff intrinsic + fn gen_turbofish_expr( ecx: &ExtCtxt<'_>, - span: Span, - primal: Ident, - new_names: &[String], - sig_span: Span, - new_decl_span: Span, - idents: &[Ident], - errored: bool, + ident: Ident, generics: &Generics, - ) -> (Box<ast::Block>, Box<ast::Expr>, Box<ast::Expr>, Box<ast::Expr>) { - let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); - let noop = ast::InlineAsm { - asm_macro: ast::AsmMacro::Asm, - template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())], - template_strs: Box::new([]), - operands: vec![], - clobber_abis: vec![], - options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM, - line_spans: vec![], - }; - let noop_expr = ecx.expr_asm(span, Box::new(noop)); - let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated); - let unsf_block = ast::Block { - stmts: thin_vec![ecx.stmt_semi(noop_expr)], - id: ast::DUMMY_NODE_ID, - tokens: None, - rules: unsf, - span, - }; - let unsf_expr = ecx.expr_block(Box::new(unsf_block)); - let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); - let primal_call = gen_primal_call(ecx, span, primal, idents, generics); - let black_box_primal_call = ecx.expr_call( - new_decl_span, - blackbox_call_expr.clone(), - thin_vec![primal_call.clone()], - ); - let tup_args = new_names - .iter() - .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) - .collect(); - - let black_box_remaining_args = ecx.expr_call( - sig_span, - blackbox_call_expr.clone(), - thin_vec![ecx.expr_tuple(sig_span, tup_args)], - ); - - let mut body = ecx.block(span, ThinVec::new()); - body.stmts.push(ecx.stmt_semi(unsf_expr)); - - // This uses primal args which won't be available if we errored before - if !errored { - body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); - } - body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); - - (body, primal_call, black_box_primal_call, blackbox_call_expr) - } - - /// We only want this function to type-check, since we will replace the body - /// later on llvm level. Using `loop {}` does not cover all return types anymore, - /// so instead we manually build something that should pass the type checker. - /// We also add a inline_asm line, as one more barrier for rustc to prevent inlining - /// or const propagation. inline_asm will also triggers an Enzyme crash if due to another - /// bug would ever try to accidentally differentiate this placeholder function body. - /// Finally, we also add back_box usages of all input arguments, to prevent rustc - /// from optimizing any arguments away. - fn gen_enzyme_body( - ecx: &ExtCtxt<'_>, - x: &AutoDiffAttrs, - n_active: u32, - sig: &ast::FnSig, - d_sig: &ast::FnSig, - primal: Ident, - new_names: &[String], span: Span, - sig_span: Span, - idents: Vec<Ident>, - errored: bool, - generics: &Generics, - ) -> Box<ast::Block> { - let new_decl_span = d_sig.span; - - // Just adding some default inline-asm and black_box usages to prevent early inlining - // and optimizations which alter the function signature. - // - // The bb_primal_call is the black_box call of the primal function. We keep it around, - // since it has the convenient property of returning the type of the primal function, - // Remember, we only care to match types here. - // No matter which return we pick, we always wrap it into a std::hint::black_box call, - // to prevent rustc from propagating it into the caller. - let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper( - ecx, - span, - primal, - new_names, - sig_span, - new_decl_span, - &idents, - errored, - generics, - ); - - if !has_ret(&d_sig.decl.output) { - // there is no return type that we have to match, () works fine. - return body; - } - - // Everything from here onwards just tries to fulfil the return type. Fun! - - // having an active-only return means we'll drop the original return type. - // So that can be treated identical to not having one in the first place. - let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret(); - - if primal_ret && n_active == 0 && x.mode.is_rev() { - // We only have the primal ret. - body.stmts.push(ecx.stmt_expr(bb_primal_call)); - return body; - } - - if !primal_ret && n_active == 1 { - // Again no tuple return, so return default float val. - let ty = match d_sig.decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); + is_impl: bool, + ) -> Box<ast::Expr> { + let generic_args = generics + .params + .iter() + .filter_map(|p| match &p.kind { + GenericParamKind::Type { .. } => { + let path = ast::Path::from_ident(p.ident); + let ty = ecx.ty_path(path); + Some(AngleBracketedArg::Arg(GenericArg::Type(ty))) } - }; - let arg = ty.kind.is_simple_path().unwrap(); - let tmp = ecx.def_site_path(&[arg, kw::Default]); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - body.stmts.push(ecx.stmt_expr(default_call_expr)); - return body; - } - - let mut exprs: Box<ast::Expr> = primal_call; - let d_ret_ty = match d_sig.decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } - }; - if x.mode.is_fwd() { - // Fwd mode is easy. If the return activity is Const, we support arbitrary types. - // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. - // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function. - // In all three cases, we can return `std::hint::black_box(<T>::default())`. - if x.ret_activity == DiffActivity::Const { - // Here we call the primal function, since our dummy function has the same return - // type due to the Const return activity. - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); - } else { - let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 }; - let y = ExprKind::Path( - Some(Box::new(q)), - ecx.path_ident(span, Ident::with_dummy_span(kw::Default)), - ); - let default_call_expr = ecx.expr(span, y); - let default_call_expr = - ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]); - } - } else if x.mode.is_rev() { - if x.width == 1 { - // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`. - match d_ret_ty.kind { - TyKind::Tup(ref args) => { - // We have a tuple return type. We need to create a tuple of the same size - // and fill it with default values. - let mut exprs2 = thin_vec![exprs]; - for arg in args.iter().skip(1) { - let arg = arg.kind.is_simple_path().unwrap(); - let tmp = ecx.def_site_path(&[arg, kw::Default]); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = - ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs2.push(default_call_expr); - } - exprs = ecx.expr_tuple(new_decl_span, exprs2); - } - _ => { - // Interestingly, even the `-> ArbitraryType` case - // ends up getting matched and handled correctly above, - // so we don't have to handle any other case for now. - panic!("Unsupported return type: {:?}", d_ret_ty); - } + GenericParamKind::Const { .. } => { + let expr = ecx.expr_path(ast::Path::from_ident(p.ident)); + let anon_const = AnonConst { id: ast::DUMMY_NODE_ID, value: expr }; + Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const))) } - } - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); - } else { - unreachable!("Unsupported mode: {:?}", x.mode); - } - - body.stmts.push(ecx.stmt_expr(exprs)); + GenericParamKind::Lifetime { .. } => None, + }) + .collect::<ThinVec<_>>(); - body - } + let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args }; - fn gen_primal_call( - ecx: &ExtCtxt<'_>, - span: Span, - primal: Ident, - idents: &[Ident], - generics: &Generics, - ) -> Box<ast::Expr> { - let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; + let segment = PathSegment { + ident, + id: ast::DUMMY_NODE_ID, + args: Some(Box::new(GenericArgs::AngleBracketed(args))), + }; - if has_self { - let args: ThinVec<_> = - idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let self_expr = ecx.expr_self(span); - ecx.expr_method_call(span, self_expr, primal, args) + let segments = if is_impl { + thin_vec![ + PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None }, + segment, + ] } else { - let args: ThinVec<_> = - idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let mut primal_path = ecx.path_ident(span, primal); - - let is_generic = !generics.params.is_empty(); - - match (is_generic, primal_path.segments.last_mut()) { - (true, Some(function_path)) => { - let primal_generic_types = generics - .params - .iter() - .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. })); - - let generated_generic_types = primal_generic_types - .map(|type_param| { - let generic_param = TyKind::Path( - None, - ast::Path { - span, - segments: thin_vec![ast::PathSegment { - ident: type_param.ident, - args: None, - id: ast::DUMMY_NODE_ID, - }], - tokens: None, - }, - ); - - ast::AngleBracketedArg::Arg(ast::GenericArg::Type(Box::new(ast::Ty { - id: type_param.id, - span, - kind: generic_param, - tokens: None, - }))) - }) - .collect(); - - function_path.args = - Some(Box::new(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs { - span, - args: generated_generic_types, - }))); - } - _ => {} - } + thin_vec![segment] + }; - let primal_call_expr = ecx.expr_path(primal_path); - ecx.expr_call(span, primal_call_expr, args) - } + let path = Path { span, segments, tokens: None }; + + ecx.expr_path(path) } // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must @@ -905,7 +595,7 @@ mod llvm_enzyme { sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, - ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) { + ) -> ast::FnSig { let dcx = ecx.sess.dcx(); let has_ret = has_ret(&sig.decl.output); let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 }; @@ -917,7 +607,7 @@ mod llvm_enzyme { found: num_activities, }); // This is not the right signature, but we can continue parsing. - return (sig.clone(), vec![], vec![], true); + return sig.clone(); } assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(has_ret == x.has_ret_activity()); @@ -960,7 +650,7 @@ mod llvm_enzyme { if errors { // This is not the right signature, but we can continue parsing. - return (sig.clone(), new_inputs, idents, true); + return sig.clone(); } let unsafe_activities = x @@ -1174,7 +864,7 @@ mod llvm_enzyme { } let d_sig = FnSig { header: d_header, decl: d_decl, span }; trace!("Generated signature: {:?}", d_sig); - (d_sig, new_inputs, idents, false) + d_sig } } |
