From b21c9e7bfb0180b67b486013a7137fb200cb1076 Mon Sep 17 00:00:00 2001 From: Marcelo Domínguez <69964857+Sa4dUs@users.noreply.github.com> Date: Tue, 6 May 2025 09:19:33 +0200 Subject: Split `autodiff` into `autodiff_forward` and `autodiff_reverse` Pending fix. ``` error: cannot find a built-in macro with name `autodiff_forward` --> library\core\src\macros\mod.rs:1542:5 | 1542 | / pub macro autodiff_forward($item:item) { 1543 | | /* compiler built-in */ 1544 | | } | |_____^ error: cannot find a built-in macro with name `autodiff_reverse` --> library\core\src\macros\mod.rs:1549:5 | 1549 | / pub macro autodiff_reverse($item:item) { 1550 | | /* compiler built-in */ 1551 | | } | |_____^ error: could not compile `core` (lib) due to 2 previous errors ``` --- compiler/rustc_span/src/symbol.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'compiler/rustc_span/src') diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index efae6250b07..bb19b5761bb 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -531,7 +531,8 @@ symbols! { audit_that, augmented_assignments, auto_traits, - autodiff, + autodiff_forward, + autodiff_reverse, automatically_derived, avx, avx10_target_feature, -- cgit 1.4.1-3-g733a5 From f92d84cc6e0770c0d8f3dc775bafdf7f8786db61 Mon Sep 17 00:00:00 2001 From: Marcelo Domínguez Date: Sat, 10 May 2025 00:52:47 +0000 Subject: Initial naive implementation using `Symbols` to represent autodiff modes (`Forward`, `Reverse`) Since the mode is no longer part of `meta_item`, we must insert it manually (otherwise macro expansion with `#[rustc_autodiff]` won't work). This can be revised later if a more structured representation becomes necessary (using enums, annotated structs, etc). Some tests are currently failing. I'll address them next. --- compiler/rustc_builtin_macros/src/autodiff.rs | 28 +++++++++++++++++++-------- compiler/rustc_span/src/symbol.rs | 2 ++ 2 files changed, 22 insertions(+), 8 deletions(-) (limited to 'compiler/rustc_span/src') diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 8073de15925..d60bb0ae5cb 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -259,29 +259,41 @@ mod llvm_enzyme { // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field let mut ts: Vec = vec![]; - if meta_item_vec.len() < 2 { - // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no - // input and output args. + if meta_item_vec.len() < 1 { + // At the bare minimum, we need a fnc name. dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() }); return vec![item]; } - meta_item_inner_to_ts(&meta_item_vec[1], &mut ts); + let mode_symbol = match mode { + DiffMode::Forward => sym::Forward, + DiffMode::Reverse => sym::Reverse, + _ => unreachable!("Unsupported mode: {:?}", mode), + }; + + // Insert mode token + let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default()); + ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint)); + ts.insert( + 1, + TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone), + ); // Now, if the user gave a width (vector aka batch-mode ad), then we copy it. // If it is not given, we default to 1 (scalar mode). let start_position; let kind: LitKind = LitKind::Integer; let symbol; - if meta_item_vec.len() >= 3 - && let Some(width) = width(&meta_item_vec[2]) + if meta_item_vec.len() >= 2 + && let Some(width) = width(&meta_item_vec[1]) { - start_position = 3; + start_position = 2; symbol = Symbol::intern(&width.to_string()); } else { - start_position = 2; + start_position = 1; symbol = sym::integer(1); } + let l: Lit = Lit { kind, symbol, suffix: None }; let t = Token::new(TokenKind::Literal(l), Span::default()); let comma = Token::new(TokenKind::Comma, Span::default()); diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index bb19b5761bb..9dff87e0a00 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -253,6 +253,7 @@ symbols! { FnMut, FnOnce, Formatter, + Forward, From, FromIterator, FromResidual, @@ -348,6 +349,7 @@ symbols! { Result, ResumeTy, Return, + Reverse, Right, Rust, RustaceansAreAwesome, -- cgit 1.4.1-3-g733a5