about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMarcelo Domínguez <dmmarcelo27@gmail.com>2025-05-10 00:52:47 +0000
committerMarcelo Domínguez <dmmarcelo27@gmail.com>2025-05-21 07:24:42 +0000
commitf92d84cc6e0770c0d8f3dc775bafdf7f8786db61 (patch)
treee53ea47842d523f1ead2ca5d482e802ab18fe189
parent2041de7083c114025646220068d3943be517af4a (diff)
downloadrust-f92d84cc6e0770c0d8f3dc775bafdf7f8786db61.tar.gz
rust-f92d84cc6e0770c0d8f3dc775bafdf7f8786db61.zip
Initial naive implementation using `Symbols` to represent autodiff modes (`Forward`, `Reverse`)
Since the mode is no longer part of `meta_item`, we must insert it manually (otherwise macro expansion with `#[rustc_autodiff]` won't work).

This can be revised later if a more structured representation becomes necessary (using enums, annotated structs, etc).

Some tests are currently failing. I'll address them next.
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs28
-rw-r--r--compiler/rustc_span/src/symbol.rs2
2 files changed, 22 insertions, 8 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 8073de15925..d60bb0ae5cb 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -259,29 +259,41 @@ mod llvm_enzyme {
         // create TokenStream from vec elemtents:
         // meta_item doesn't have a .tokens field
         let mut ts: Vec<TokenTree> = vec![];
-        if meta_item_vec.len() < 2 {
-            // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
-            // input and output args.
+        if meta_item_vec.len() < 1 {
+            // At the bare minimum, we need a fnc name.
             dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
             return vec![item];
         }
 
-        meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
+        let mode_symbol = match mode {
+            DiffMode::Forward => sym::Forward,
+            DiffMode::Reverse => sym::Reverse,
+            _ => unreachable!("Unsupported mode: {:?}", mode),
+        };
+
+        // Insert mode token
+        let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
+        ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
+        ts.insert(
+            1,
+            TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
+        );
 
         // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
         // If it is not given, we default to 1 (scalar mode).
         let start_position;
         let kind: LitKind = LitKind::Integer;
         let symbol;
-        if meta_item_vec.len() >= 3
-            && let Some(width) = width(&meta_item_vec[2])
+        if meta_item_vec.len() >= 2
+            && let Some(width) = width(&meta_item_vec[1])
         {
-            start_position = 3;
+            start_position = 2;
             symbol = Symbol::intern(&width.to_string());
         } else {
-            start_position = 2;
+            start_position = 1;
             symbol = sym::integer(1);
         }
+
         let l: Lit = Lit { kind, symbol, suffix: None };
         let t = Token::new(TokenKind::Literal(l), Span::default());
         let comma = Token::new(TokenKind::Comma, Span::default());
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index bb19b5761bb..9dff87e0a00 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -253,6 +253,7 @@ symbols! {
         FnMut,
         FnOnce,
         Formatter,
+        Forward,
         From,
         FromIterator,
         FromResidual,
@@ -348,6 +349,7 @@ symbols! {
         Result,
         ResumeTy,
         Return,
+        Reverse,
         Right,
         Rust,
         RustaceansAreAwesome,