about summary refs log tree commit diff
path: root/compiler/rustc_builtin_macros/src/autodiff.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_builtin_macros/src/autodiff.rs')
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs339
1 files changed, 213 insertions, 126 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 8937d35d53a..7f99f75b2b9 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -12,12 +12,12 @@ mod llvm_enzyme {
         valid_ty_for_activity,
     };
     use rustc_ast::ptr::P;
-    use rustc_ast::token::{Token, TokenKind};
+    use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
     use rustc_ast::tokenstream::*;
     use rustc_ast::visit::AssocCtxt::*;
     use rustc_ast::{
-        self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
-        PatKind, TyKind,
+        self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
+        MetaItemInner, PatKind, QSelf, TyKind,
     };
     use rustc_expand::base::{Annotatable, ExtCtxt};
     use rustc_span::{Ident, Span, Symbol, kw, sym};
@@ -45,6 +45,16 @@ mod llvm_enzyme {
         }
     }
     fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
+        if let Some(l) = x.lit() {
+            match l.kind {
+                ast::LitKind::Int(val, _) => {
+                    // get an Ident from a lit
+                    return rustc_span::Ident::from_str(val.get().to_string().as_str());
+                }
+                _ => {}
+            }
+        }
+
         let segments = &x.meta_item().unwrap().path.segments;
         assert!(segments.len() == 1);
         segments[0].ident
@@ -54,6 +64,14 @@ mod llvm_enzyme {
         first_ident(x).name.to_string()
     }
 
+    fn width(x: &MetaItemInner) -> Option<u128> {
+        let lit = x.lit()?;
+        match lit.kind {
+            ast::LitKind::Int(x, _) => Some(x.get()),
+            _ => return None,
+        }
+    }
+
     pub(crate) fn from_ast(
         ecx: &mut ExtCtxt<'_>,
         meta_item: &ThinVec<MetaItemInner>,
@@ -65,9 +83,32 @@ mod llvm_enzyme {
             dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
             return AutoDiffAttrs::error();
         };
+
+        // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
+        // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
+        let mut first_activity = 2;
+
+        let width = if let [_, _, x, ..] = &meta_item[..]
+            && let Some(x) = width(x)
+        {
+            first_activity = 3;
+            match x.try_into() {
+                Ok(x) => x,
+                Err(_) => {
+                    dcx.emit_err(errors::AutoDiffInvalidWidth {
+                        span: meta_item[2].span(),
+                        width: x,
+                    });
+                    return AutoDiffAttrs::error();
+                }
+            }
+        } else {
+            1
+        };
+
         let mut activities: Vec<DiffActivity> = vec![];
         let mut errors = false;
-        for x in &meta_item[2..] {
+        for x in &meta_item[first_activity..] {
             let activity_str = name(&x);
             let res = DiffActivity::from_str(&activity_str);
             match res {
@@ -98,7 +139,20 @@ mod llvm_enzyme {
             (&DiffActivity::None, activities.as_slice())
         };
 
-        AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
+        AutoDiffAttrs {
+            mode,
+            width,
+            ret_activity: *ret_activity,
+            input_activity: input_activity.to_vec(),
+        }
+    }
+
+    fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
+        let comma: Token = Token::new(TokenKind::Comma, Span::default());
+        let val = first_ident(t);
+        let t = Token::from_ast_ident(val);
+        ts.push(TokenTree::Token(t, Spacing::Joint));
+        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
     }
 
     /// We expand the autodiff macro to generate a new placeholder function which passes
@@ -195,27 +249,49 @@ mod llvm_enzyme {
 
         // create TokenStream from vec elemtents:
         // meta_item doesn't have a .tokens field
-        let comma: Token = Token::new(TokenKind::Comma, Span::default());
         let mut ts: Vec<TokenTree> = vec![];
         if meta_item_vec.len() < 2 {
             // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
             // input and output args.
             dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
             return vec![item];
+        }
+
+        meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
+
+        // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
+        // If it is not given, we default to 1 (scalar mode).
+        let start_position;
+        let kind: LitKind = LitKind::Integer;
+        let symbol;
+        if meta_item_vec.len() >= 3
+            && let Some(width) = width(&meta_item_vec[2])
+        {
+            start_position = 3;
+            symbol = Symbol::intern(&width.to_string());
         } else {
-            for t in meta_item_vec.clone()[1..].iter() {
-                let val = first_ident(t);
-                let t = Token::from_ast_ident(val);
-                ts.push(TokenTree::Token(t, Spacing::Joint));
-                ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
-            }
+            start_position = 2;
+            symbol = sym::integer(1);
         }
+        let l: Lit = Lit { kind, symbol, suffix: None };
+        let t = Token::new(TokenKind::Literal(l), Span::default());
+        let comma = Token::new(TokenKind::Comma, Span::default());
+        ts.push(TokenTree::Token(t, Spacing::Joint));
+        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
+
+        for t in meta_item_vec.clone()[start_position..].iter() {
+            meta_item_inner_to_ts(t, &mut ts);
+        }
+
         if !has_ret {
             // We don't want users to provide a return activity if the function doesn't return anything.
             // For simplicity, we just add a dummy token to the end of the list.
             let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
             ts.push(TokenTree::Token(t, Spacing::Joint));
+            ts.push(TokenTree::Token(comma, Spacing::Alone));
         }
+        // We remove the last, trailing comma.
+        ts.pop();
         let ts: TokenStream = TokenStream::from_iter(ts);
 
         let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
@@ -470,6 +546,8 @@ mod llvm_enzyme {
             return body;
         }
 
+        // Everything from here onwards just tries to fullfil the return type. Fun!
+
         // having an active-only return means we'll drop the original return type.
         // So that can be treated identical to not having one in the first place.
         let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
@@ -497,86 +575,65 @@ mod llvm_enzyme {
             return body;
         }
 
-        let mut exprs = ThinVec::<P<ast::Expr>>::new();
-        if primal_ret {
-            // We have both primal ret and active floats.
-            // primal ret is first, by construction.
-            exprs.push(primal_call);
-        }
-
-        // Now construct default placeholder for each active float.
-        // Is there something nicer than f32::default() and f64::default()?
+        let mut exprs: P<ast::Expr> = primal_call.clone();
         let d_ret_ty = match d_sig.decl.output {
             FnRetTy::Ty(ref ty) => ty.clone(),
             FnRetTy::Default(span) => {
                 panic!("Did not expect Default ret ty: {:?}", span);
             }
         };
-        let mut d_ret_ty = match d_ret_ty.kind.clone() {
-            TyKind::Tup(ref tys) => tys.clone(),
-            TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
-                if let [segment] = &segments[..]
-                    && segment.args.is_none()
-                {
-                    let id = vec![segments[0].ident];
-                    let kind = TyKind::Path(None, ecx.path(span, id));
-                    let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
-                    thin_vec![ty]
-                } else {
-                    panic!("Expected tuple or simple path return type");
-                }
-            }
-            _ => {
-                // We messed up construction of d_sig
-                panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
-            }
-        };
-
-        if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
-            assert!(d_ret_ty.len() == 2);
-            // both should be identical, by construction
-            let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
-            let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
-            assert!(arg == arg2);
-            let sl: Vec<Symbol> = vec![arg, kw::Default];
-            let tmp = ecx.def_site_path(&sl);
-            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
-            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
-            exprs.push(default_call_expr);
-        } else if x.mode.is_rev() {
-            if primal_ret {
-                // We have extra handling above for the primal ret
-                d_ret_ty = d_ret_ty[1..].to_vec().into();
-            }
 
-            for arg in d_ret_ty.iter() {
-                let arg = arg.kind.is_simple_path().unwrap();
-                let sl: Vec<Symbol> = vec![arg, kw::Default];
-                let tmp = ecx.def_site_path(&sl);
-                let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
+        if x.mode.is_fwd() {
+            // Fwd mode is easy. If the return activity is Const, we support arbitrary types.
+            // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
+            // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
+            // In all three cases, we can return `std::hint::black_box(<T>::default())`.
+            if x.ret_activity == DiffActivity::Const {
+                // Here we call the primal function, since our dummy function has the same return
+                // type due to the Const return activity.
+                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
+            } else {
+                let q = QSelf { ty: d_ret_ty.clone(), path_span: span, position: 0 };
+                let y =
+                    ExprKind::Path(Some(P(q)), ecx.path_ident(span, Ident::from_str("default")));
+                let default_call_expr = ecx.expr(span, y);
                 let default_call_expr =
                     ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
-                exprs.push(default_call_expr);
-            }
-        }
-
-        let ret: P<ast::Expr>;
-        match &exprs[..] {
-            [] => {
-                assert!(!has_ret(&d_sig.decl.output));
-                // We don't have to match the return type.
-                return body;
-            }
-            [arg] => {
-                ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
+                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
             }
-            args => {
-                let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
-                ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
+        } else if x.mode.is_rev() {
+            if x.width == 1 {
+                // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
+                match d_ret_ty.kind {
+                    TyKind::Tup(ref args) => {
+                        // We have a tuple return type. We need to create a tuple of the same size
+                        // and fill it with default values.
+                        let mut exprs2 = thin_vec![exprs];
+                        for arg in args.iter().skip(1) {
+                            let arg = arg.kind.is_simple_path().unwrap();
+                            let sl: Vec<Symbol> = vec![arg, kw::Default];
+                            let tmp = ecx.def_site_path(&sl);
+                            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
+                            let default_call_expr =
+                                ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
+                            exprs2.push(default_call_expr);
+                        }
+                        exprs = ecx.expr_tuple(new_decl_span, exprs2);
+                    }
+                    _ => {
+                        // Interestingly, even the `-> ArbitraryType` case
+                        // ends up getting matched and handled correctly above,
+                        // so we don't have to handle any other case for now.
+                        panic!("Unsupported return type: {:?}", d_ret_ty);
+                    }
+                }
             }
+            exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
+        } else {
+            unreachable!("Unsupported mode: {:?}", x.mode);
         }
-        assert!(has_ret(&d_sig.decl.output));
-        body.stmts.push(ecx.stmt_expr(ret));
+
+        body.stmts.push(ecx.stmt_expr(exprs));
 
         body
     }
@@ -684,50 +741,55 @@ mod llvm_enzyme {
             match activity {
                 DiffActivity::Active => {
                     act_ret.push(arg.ty.clone());
+                    // if width =/= 1, then push [arg.ty; width] to act_ret
                 }
                 DiffActivity::ActiveOnly => {
                     // We will add the active scalar to the return type.
                     // This is handled later.
                 }
                 DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
-                    let mut shadow_arg = arg.clone();
-                    // We += into the shadow in reverse mode.
-                    shadow_arg.ty = P(assure_mut_ref(&arg.ty));
-                    let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
-                        ident.name
-                    } else {
-                        debug!("{:#?}", &shadow_arg.pat);
-                        panic!("not an ident?");
-                    };
-                    let name: String = format!("d{}", old_name);
-                    new_inputs.push(name.clone());
-                    let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
-                    shadow_arg.pat = P(ast::Pat {
-                        id: ast::DUMMY_NODE_ID,
-                        kind: PatKind::Ident(BindingMode::NONE, ident, None),
-                        span: shadow_arg.pat.span,
-                        tokens: shadow_arg.pat.tokens.clone(),
-                    });
-                    d_inputs.push(shadow_arg);
+                    for i in 0..x.width {
+                        let mut shadow_arg = arg.clone();
+                        // We += into the shadow in reverse mode.
+                        shadow_arg.ty = P(assure_mut_ref(&arg.ty));
+                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
+                            ident.name
+                        } else {
+                            debug!("{:#?}", &shadow_arg.pat);
+                            panic!("not an ident?");
+                        };
+                        let name: String = format!("d{}_{}", old_name, i);
+                        new_inputs.push(name.clone());
+                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
+                        shadow_arg.pat = P(ast::Pat {
+                            id: ast::DUMMY_NODE_ID,
+                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
+                            span: shadow_arg.pat.span,
+                            tokens: shadow_arg.pat.tokens.clone(),
+                        });
+                        d_inputs.push(shadow_arg.clone());
+                    }
                 }
                 DiffActivity::Dual | DiffActivity::DualOnly => {
-                    let mut shadow_arg = arg.clone();
-                    let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
-                        ident.name
-                    } else {
-                        debug!("{:#?}", &shadow_arg.pat);
-                        panic!("not an ident?");
-                    };
-                    let name: String = format!("b{}", old_name);
-                    new_inputs.push(name.clone());
-                    let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
-                    shadow_arg.pat = P(ast::Pat {
-                        id: ast::DUMMY_NODE_ID,
-                        kind: PatKind::Ident(BindingMode::NONE, ident, None),
-                        span: shadow_arg.pat.span,
-                        tokens: shadow_arg.pat.tokens.clone(),
-                    });
-                    d_inputs.push(shadow_arg);
+                    for i in 0..x.width {
+                        let mut shadow_arg = arg.clone();
+                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
+                            ident.name
+                        } else {
+                            debug!("{:#?}", &shadow_arg.pat);
+                            panic!("not an ident?");
+                        };
+                        let name: String = format!("b{}_{}", old_name, i);
+                        new_inputs.push(name.clone());
+                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
+                        shadow_arg.pat = P(ast::Pat {
+                            id: ast::DUMMY_NODE_ID,
+                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
+                            span: shadow_arg.pat.span,
+                            tokens: shadow_arg.pat.tokens.clone(),
+                        });
+                        d_inputs.push(shadow_arg.clone());
+                    }
                 }
                 DiffActivity::Const => {
                     // Nothing to do here.
@@ -783,23 +845,48 @@ mod llvm_enzyme {
         d_decl.inputs = d_inputs.into();
 
         if x.mode.is_fwd() {
+            let ty = match d_decl.output {
+                FnRetTy::Ty(ref ty) => ty.clone(),
+                FnRetTy::Default(span) => {
+                    // We want to return std::hint::black_box(()).
+                    let kind = TyKind::Tup(ThinVec::new());
+                    let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
+                    d_decl.output = FnRetTy::Ty(ty.clone());
+                    assert!(matches!(x.ret_activity, DiffActivity::None));
+                    // this won't be used below, so any type would be fine.
+                    ty
+                }
+            };
+
             if let DiffActivity::Dual = x.ret_activity {
-                let ty = match d_decl.output {
-                    FnRetTy::Ty(ref ty) => ty.clone(),
-                    FnRetTy::Default(span) => {
-                        panic!("Did not expect Default ret ty: {:?}", span);
-                    }
+                let kind = if x.width == 1 {
+                    // Dual can only be used for f32/f64 ret.
+                    // In that case we return now a tuple with two floats.
+                    TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
+                } else {
+                    // We have to return [T; width+1], +1 for the primal return.
+                    let anon_const = rustc_ast::AnonConst {
+                        id: ast::DUMMY_NODE_ID,
+                        value: ecx.expr_usize(span, 1 + x.width as usize),
+                    };
+                    TyKind::Array(ty.clone(), anon_const)
                 };
-                // Dual can only be used for f32/f64 ret.
-                // In that case we return now a tuple with two floats.
-                let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
                 let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
                 d_decl.output = FnRetTy::Ty(ty);
             }
             if let DiffActivity::DualOnly = x.ret_activity {
                 // No need to change the return type,
-                // we will just return the shadow in place
-                // of the primal return.
+                // we will just return the shadow in place of the primal return.
+                // However, if we have a width > 1, then we don't return -> T, but -> [T; width]
+                if x.width > 1 {
+                    let anon_const = rustc_ast::AnonConst {
+                        id: ast::DUMMY_NODE_ID,
+                        value: ecx.expr_usize(span, x.width as usize),
+                    };
+                    let kind = TyKind::Array(ty.clone(), anon_const);
+                    let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
+                    d_decl.output = FnRetTy::Ty(ty);
+                }
             }
         }