diff options
| author | Marcelo DomÃnguez <dmmarcelo27@gmail.com> | 2025-05-10 00:52:47 +0000 |
|---|---|---|
| committer | Marcelo DomÃnguez <dmmarcelo27@gmail.com> | 2025-05-21 07:24:42 +0000 |
| commit | f92d84cc6e0770c0d8f3dc775bafdf7f8786db61 (patch) | |
| tree | e53ea47842d523f1ead2ca5d482e802ab18fe189 | |
| parent | 2041de7083c114025646220068d3943be517af4a (diff) | |
| download | rust-f92d84cc6e0770c0d8f3dc775bafdf7f8786db61.tar.gz rust-f92d84cc6e0770c0d8f3dc775bafdf7f8786db61.zip | |
Initial naive implementation using `Symbols` to represent autodiff modes (`Forward`, `Reverse`)
Since the mode is no longer part of `meta_item`, we must insert it manually (otherwise macro expansion with `#[rustc_autodiff]` won't work). This can be revised later if a more structured representation becomes necessary (using enums, annotated structs, etc). Some tests are currently failing. I'll address them next.
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 28 | ||||
| -rw-r--r-- | compiler/rustc_span/src/symbol.rs | 2 |
2 files changed, 22 insertions, 8 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 8073de15925..d60bb0ae5cb 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -259,29 +259,41 @@ mod llvm_enzyme { // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field let mut ts: Vec<TokenTree> = vec![]; - if meta_item_vec.len() < 2 { - // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no - // input and output args. + if meta_item_vec.len() < 1 { + // At the bare minimum, we need a fnc name. dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() }); return vec![item]; } - meta_item_inner_to_ts(&meta_item_vec[1], &mut ts); + let mode_symbol = match mode { + DiffMode::Forward => sym::Forward, + DiffMode::Reverse => sym::Reverse, + _ => unreachable!("Unsupported mode: {:?}", mode), + }; + + // Insert mode token + let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default()); + ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint)); + ts.insert( + 1, + TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone), + ); // Now, if the user gave a width (vector aka batch-mode ad), then we copy it. // If it is not given, we default to 1 (scalar mode). let start_position; let kind: LitKind = LitKind::Integer; let symbol; - if meta_item_vec.len() >= 3 - && let Some(width) = width(&meta_item_vec[2]) + if meta_item_vec.len() >= 2 + && let Some(width) = width(&meta_item_vec[1]) { - start_position = 3; + start_position = 2; symbol = Symbol::intern(&width.to_string()); } else { - start_position = 2; + start_position = 1; symbol = sym::integer(1); } + let l: Lit = Lit { kind, symbol, suffix: None }; let t = Token::new(TokenKind::Literal(l), Span::default()); let comma = Token::new(TokenKind::Comma, Span::default()); diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index bb19b5761bb..9dff87e0a00 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -253,6 +253,7 @@ symbols! { FnMut, FnOnce, Formatter, + Forward, From, FromIterator, FromResidual, @@ -348,6 +349,7 @@ symbols! { Result, ResumeTy, Return, + Reverse, Right, Rust, RustaceansAreAwesome, |
