diff options
Diffstat (limited to 'compiler/rustc_builtin_macros/src/autodiff.rs')
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 339 |
1 files changed, 213 insertions, 126 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 8937d35d53a..7f99f75b2b9 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -12,12 +12,12 @@ mod llvm_enzyme { valid_ty_for_activity, }; use rustc_ast::ptr::P; - use rustc_ast::token::{Token, TokenKind}; + use rustc_ast::token::{Lit, LitKind, Token, TokenKind}; use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ - self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner, - PatKind, TyKind, + self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind, + MetaItemInner, PatKind, QSelf, TyKind, }; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::{Ident, Span, Symbol, kw, sym}; @@ -45,6 +45,16 @@ mod llvm_enzyme { } } fn first_ident(x: &MetaItemInner) -> rustc_span::Ident { + if let Some(l) = x.lit() { + match l.kind { + ast::LitKind::Int(val, _) => { + // get an Ident from a lit + return rustc_span::Ident::from_str(val.get().to_string().as_str()); + } + _ => {} + } + } + let segments = &x.meta_item().unwrap().path.segments; assert!(segments.len() == 1); segments[0].ident @@ -54,6 +64,14 @@ mod llvm_enzyme { first_ident(x).name.to_string() } + fn width(x: &MetaItemInner) -> Option<u128> { + let lit = x.lit()?; + match lit.kind { + ast::LitKind::Int(x, _) => Some(x.get()), + _ => return None, + } + } + pub(crate) fn from_ast( ecx: &mut ExtCtxt<'_>, meta_item: &ThinVec<MetaItemInner>, @@ -65,9 +83,32 @@ mod llvm_enzyme { dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode }); return AutoDiffAttrs::error(); }; + + // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode. + // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1. + let mut first_activity = 2; + + let width = if let [_, _, x, ..] = &meta_item[..] + && let Some(x) = width(x) + { + first_activity = 3; + match x.try_into() { + Ok(x) => x, + Err(_) => { + dcx.emit_err(errors::AutoDiffInvalidWidth { + span: meta_item[2].span(), + width: x, + }); + return AutoDiffAttrs::error(); + } + } + } else { + 1 + }; + let mut activities: Vec<DiffActivity> = vec![]; let mut errors = false; - for x in &meta_item[2..] { + for x in &meta_item[first_activity..] { let activity_str = name(&x); let res = DiffActivity::from_str(&activity_str); match res { @@ -98,7 +139,20 @@ mod llvm_enzyme { (&DiffActivity::None, activities.as_slice()) }; - AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() } + AutoDiffAttrs { + mode, + width, + ret_activity: *ret_activity, + input_activity: input_activity.to_vec(), + } + } + + fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) { + let comma: Token = Token::new(TokenKind::Comma, Span::default()); + let val = first_ident(t); + let t = Token::from_ast_ident(val); + ts.push(TokenTree::Token(t, Spacing::Joint)); + ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); } /// We expand the autodiff macro to generate a new placeholder function which passes @@ -195,27 +249,49 @@ mod llvm_enzyme { // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field - let comma: Token = Token::new(TokenKind::Comma, Span::default()); let mut ts: Vec<TokenTree> = vec![]; if meta_item_vec.len() < 2 { // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no // input and output args. dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() }); return vec![item]; + } + + meta_item_inner_to_ts(&meta_item_vec[1], &mut ts); + + // Now, if the user gave a width (vector aka batch-mode ad), then we copy it. + // If it is not given, we default to 1 (scalar mode). + let start_position; + let kind: LitKind = LitKind::Integer; + let symbol; + if meta_item_vec.len() >= 3 + && let Some(width) = width(&meta_item_vec[2]) + { + start_position = 3; + symbol = Symbol::intern(&width.to_string()); } else { - for t in meta_item_vec.clone()[1..].iter() { - let val = first_ident(t); - let t = Token::from_ast_ident(val); - ts.push(TokenTree::Token(t, Spacing::Joint)); - ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); - } + start_position = 2; + symbol = sym::integer(1); } + let l: Lit = Lit { kind, symbol, suffix: None }; + let t = Token::new(TokenKind::Literal(l), Span::default()); + let comma = Token::new(TokenKind::Comma, Span::default()); + ts.push(TokenTree::Token(t, Spacing::Joint)); + ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); + + for t in meta_item_vec.clone()[start_position..].iter() { + meta_item_inner_to_ts(t, &mut ts); + } + if !has_ret { // We don't want users to provide a return activity if the function doesn't return anything. // For simplicity, we just add a dummy token to the end of the list. let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default()); ts.push(TokenTree::Token(t, Spacing::Joint)); + ts.push(TokenTree::Token(comma, Spacing::Alone)); } + // We remove the last, trailing comma. + ts.pop(); let ts: TokenStream = TokenStream::from_iter(ts); let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret); @@ -470,6 +546,8 @@ mod llvm_enzyme { return body; } + // Everything from here onwards just tries to fullfil the return type. Fun! + // having an active-only return means we'll drop the original return type. // So that can be treated identical to not having one in the first place. let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret(); @@ -497,86 +575,65 @@ mod llvm_enzyme { return body; } - let mut exprs = ThinVec::<P<ast::Expr>>::new(); - if primal_ret { - // We have both primal ret and active floats. - // primal ret is first, by construction. - exprs.push(primal_call); - } - - // Now construct default placeholder for each active float. - // Is there something nicer than f32::default() and f64::default()? + let mut exprs: P<ast::Expr> = primal_call.clone(); let d_ret_ty = match d_sig.decl.output { FnRetTy::Ty(ref ty) => ty.clone(), FnRetTy::Default(span) => { panic!("Did not expect Default ret ty: {:?}", span); } }; - let mut d_ret_ty = match d_ret_ty.kind.clone() { - TyKind::Tup(ref tys) => tys.clone(), - TyKind::Path(_, rustc_ast::Path { segments, .. }) => { - if let [segment] = &segments[..] - && segment.args.is_none() - { - let id = vec![segments[0].ident]; - let kind = TyKind::Path(None, ecx.path(span, id)); - let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None }); - thin_vec![ty] - } else { - panic!("Expected tuple or simple path return type"); - } - } - _ => { - // We messed up construction of d_sig - panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty); - } - }; - - if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual { - assert!(d_ret_ty.len() == 2); - // both should be identical, by construction - let arg = d_ret_ty[0].kind.is_simple_path().unwrap(); - let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap(); - assert!(arg == arg2); - let sl: Vec<Symbol> = vec![arg, kw::Default]; - let tmp = ecx.def_site_path(&sl); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs.push(default_call_expr); - } else if x.mode.is_rev() { - if primal_ret { - // We have extra handling above for the primal ret - d_ret_ty = d_ret_ty[1..].to_vec().into(); - } - for arg in d_ret_ty.iter() { - let arg = arg.kind.is_simple_path().unwrap(); - let sl: Vec<Symbol> = vec![arg, kw::Default]; - let tmp = ecx.def_site_path(&sl); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + if x.mode.is_fwd() { + // Fwd mode is easy. If the return activity is Const, we support arbitrary types. + // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. + // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function. + // In all three cases, we can return `std::hint::black_box(<T>::default())`. + if x.ret_activity == DiffActivity::Const { + // Here we call the primal function, since our dummy function has the same return + // type due to the Const return activity. + exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); + } else { + let q = QSelf { ty: d_ret_ty.clone(), path_span: span, position: 0 }; + let y = + ExprKind::Path(Some(P(q)), ecx.path_ident(span, Ident::from_str("default"))); + let default_call_expr = ecx.expr(span, y); let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs.push(default_call_expr); - } - } - - let ret: P<ast::Expr>; - match &exprs[..] { - [] => { - assert!(!has_ret(&d_sig.decl.output)); - // We don't have to match the return type. - return body; - } - [arg] => { - ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]); + exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]); } - args => { - let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into()); - ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]); + } else if x.mode.is_rev() { + if x.width == 1 { + // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`. + match d_ret_ty.kind { + TyKind::Tup(ref args) => { + // We have a tuple return type. We need to create a tuple of the same size + // and fill it with default values. + let mut exprs2 = thin_vec![exprs]; + for arg in args.iter().skip(1) { + let arg = arg.kind.is_simple_path().unwrap(); + let sl: Vec<Symbol> = vec![arg, kw::Default]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = + ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + exprs2.push(default_call_expr); + } + exprs = ecx.expr_tuple(new_decl_span, exprs2); + } + _ => { + // Interestingly, even the `-> ArbitraryType` case + // ends up getting matched and handled correctly above, + // so we don't have to handle any other case for now. + panic!("Unsupported return type: {:?}", d_ret_ty); + } + } } + exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); + } else { + unreachable!("Unsupported mode: {:?}", x.mode); } - assert!(has_ret(&d_sig.decl.output)); - body.stmts.push(ecx.stmt_expr(ret)); + + body.stmts.push(ecx.stmt_expr(exprs)); body } @@ -684,50 +741,55 @@ mod llvm_enzyme { match activity { DiffActivity::Active => { act_ret.push(arg.ty.clone()); + // if width =/= 1, then push [arg.ty; width] to act_ret } DiffActivity::ActiveOnly => { // We will add the active scalar to the return type. // This is handled later. } DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => { - let mut shadow_arg = arg.clone(); - // We += into the shadow in reverse mode. - shadow_arg.ty = P(assure_mut_ref(&arg.ty)); - let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { - ident.name - } else { - debug!("{:#?}", &shadow_arg.pat); - panic!("not an ident?"); - }; - let name: String = format!("d{}", old_name); - new_inputs.push(name.clone()); - let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); - shadow_arg.pat = P(ast::Pat { - id: ast::DUMMY_NODE_ID, - kind: PatKind::Ident(BindingMode::NONE, ident, None), - span: shadow_arg.pat.span, - tokens: shadow_arg.pat.tokens.clone(), - }); - d_inputs.push(shadow_arg); + for i in 0..x.width { + let mut shadow_arg = arg.clone(); + // We += into the shadow in reverse mode. + shadow_arg.ty = P(assure_mut_ref(&arg.ty)); + let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { + ident.name + } else { + debug!("{:#?}", &shadow_arg.pat); + panic!("not an ident?"); + }; + let name: String = format!("d{}_{}", old_name, i); + new_inputs.push(name.clone()); + let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); + shadow_arg.pat = P(ast::Pat { + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingMode::NONE, ident, None), + span: shadow_arg.pat.span, + tokens: shadow_arg.pat.tokens.clone(), + }); + d_inputs.push(shadow_arg.clone()); + } } DiffActivity::Dual | DiffActivity::DualOnly => { - let mut shadow_arg = arg.clone(); - let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { - ident.name - } else { - debug!("{:#?}", &shadow_arg.pat); - panic!("not an ident?"); - }; - let name: String = format!("b{}", old_name); - new_inputs.push(name.clone()); - let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); - shadow_arg.pat = P(ast::Pat { - id: ast::DUMMY_NODE_ID, - kind: PatKind::Ident(BindingMode::NONE, ident, None), - span: shadow_arg.pat.span, - tokens: shadow_arg.pat.tokens.clone(), - }); - d_inputs.push(shadow_arg); + for i in 0..x.width { + let mut shadow_arg = arg.clone(); + let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { + ident.name + } else { + debug!("{:#?}", &shadow_arg.pat); + panic!("not an ident?"); + }; + let name: String = format!("b{}_{}", old_name, i); + new_inputs.push(name.clone()); + let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); + shadow_arg.pat = P(ast::Pat { + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingMode::NONE, ident, None), + span: shadow_arg.pat.span, + tokens: shadow_arg.pat.tokens.clone(), + }); + d_inputs.push(shadow_arg.clone()); + } } DiffActivity::Const => { // Nothing to do here. @@ -783,23 +845,48 @@ mod llvm_enzyme { d_decl.inputs = d_inputs.into(); if x.mode.is_fwd() { + let ty = match d_decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + // We want to return std::hint::black_box(()). + let kind = TyKind::Tup(ThinVec::new()); + let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None }); + d_decl.output = FnRetTy::Ty(ty.clone()); + assert!(matches!(x.ret_activity, DiffActivity::None)); + // this won't be used below, so any type would be fine. + ty + } + }; + if let DiffActivity::Dual = x.ret_activity { - let ty = match d_decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } + let kind = if x.width == 1 { + // Dual can only be used for f32/f64 ret. + // In that case we return now a tuple with two floats. + TyKind::Tup(thin_vec![ty.clone(), ty.clone()]) + } else { + // We have to return [T; width+1], +1 for the primal return. + let anon_const = rustc_ast::AnonConst { + id: ast::DUMMY_NODE_ID, + value: ecx.expr_usize(span, 1 + x.width as usize), + }; + TyKind::Array(ty.clone(), anon_const) }; - // Dual can only be used for f32/f64 ret. - // In that case we return now a tuple with two floats. - let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]); let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }); d_decl.output = FnRetTy::Ty(ty); } if let DiffActivity::DualOnly = x.ret_activity { // No need to change the return type, - // we will just return the shadow in place - // of the primal return. + // we will just return the shadow in place of the primal return. + // However, if we have a width > 1, then we don't return -> T, but -> [T; width] + if x.width > 1 { + let anon_const = rustc_ast::AnonConst { + id: ast::DUMMY_NODE_ID, + value: ecx.expr_usize(span, x.width as usize), + }; + let kind = TyKind::Array(ty.clone(), anon_const); + let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }); + d_decl.output = FnRetTy::Ty(ty); + } } } |
