about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMarcelo Domínguez <69964857+Sa4dUs@users.noreply.github.com>2025-05-06 09:19:33 +0200
committerMarcelo Domínguez <dmmarcelo27@gmail.com>2025-05-20 11:58:26 +0000
commitb21c9e7bfb0180b67b486013a7137fb200cb1076 (patch)
tree769762e6206c9f8618026eb6d213dbdbcd35af2f
parentf8e9e7636aabcbc29345d9614432d15b3c0c4ec7 (diff)
downloadrust-b21c9e7bfb0180b67b486013a7137fb200cb1076.tar.gz
rust-b21c9e7bfb0180b67b486013a7137fb200cb1076.zip
Split `autodiff` into `autodiff_forward` and `autodiff_reverse`
Pending fix.
```
error: cannot find a built-in macro with name `autodiff_forward`
    --> library\core\src\macros\mod.rs:1542:5
     |
1542 | /     pub macro autodiff_forward($item:item) {
1543 | |         /* compiler built-in */
1544 | |     }
     | |_____^

error: cannot find a built-in macro with name `autodiff_reverse`
    --> library\core\src\macros\mod.rs:1549:5
     |
1549 | /     pub macro autodiff_reverse($item:item) {
1550 | |         /* compiler built-in */
1551 | |     }
     | |_____^

error: could not compile `core` (lib) due to 2 previous errors
```
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs41
-rw-r--r--compiler/rustc_builtin_macros/src/lib.rs3
-rw-r--r--compiler/rustc_passes/src/check_attr.rs2
-rw-r--r--compiler/rustc_span/src/symbol.rs3
-rw-r--r--library/core/src/lib.rs2
-rw-r--r--library/core/src/macros/mod.rs14
6 files changed, 48 insertions, 17 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 1ff4fc6aaab..8073de15925 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -88,25 +88,20 @@ mod llvm_enzyme {
         has_ret: bool,
     ) -> AutoDiffAttrs {
         let dcx = ecx.sess.dcx();
-        let mode = name(&meta_item[1]);
-        let Ok(mode) = DiffMode::from_str(&mode) else {
-            dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
-            return AutoDiffAttrs::error();
-        };
 
         // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
         // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
-        let mut first_activity = 2;
+        let mut first_activity = 1;
 
-        let width = if let [_, _, x, ..] = &meta_item[..]
+        let width = if let [_, x, ..] = &meta_item[..]
             && let Some(x) = width(x)
         {
-            first_activity = 3;
+            first_activity = 2;
             match x.try_into() {
                 Ok(x) => x,
                 Err(_) => {
                     dcx.emit_err(errors::AutoDiffInvalidWidth {
-                        span: meta_item[2].span(),
+                        span: meta_item[1].span(),
                         width: x,
                     });
                     return AutoDiffAttrs::error();
@@ -150,7 +145,7 @@ mod llvm_enzyme {
         };
 
         AutoDiffAttrs {
-            mode,
+            mode: DiffMode::Error,
             width,
             ret_activity: *ret_activity,
             input_activity: input_activity.to_vec(),
@@ -165,6 +160,24 @@ mod llvm_enzyme {
         ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
     }
 
+    pub(crate) fn expand_forward(
+        ecx: &mut ExtCtxt<'_>,
+        expand_span: Span,
+        meta_item: &ast::MetaItem,
+        item: Annotatable,
+    ) -> Vec<Annotatable> {
+        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
+    }
+
+    pub(crate) fn expand_reverse(
+        ecx: &mut ExtCtxt<'_>,
+        expand_span: Span,
+        meta_item: &ast::MetaItem,
+        item: Annotatable,
+    ) -> Vec<Annotatable> {
+        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
+    }
+
     /// We expand the autodiff macro to generate a new placeholder function which passes
     /// type-checking and can be called by users. The function body of the placeholder function will
     /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +211,12 @@ mod llvm_enzyme {
     /// ```
     /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
     /// in CI.
-    pub(crate) fn expand(
+    pub(crate) fn expand_with_mode(
         ecx: &mut ExtCtxt<'_>,
         expand_span: Span,
         meta_item: &ast::MetaItem,
         mut item: Annotatable,
+        mode: DiffMode,
     ) -> Vec<Annotatable> {
         if cfg!(not(llvm_enzyme)) {
             ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
@@ -289,7 +303,8 @@ mod llvm_enzyme {
         ts.pop();
         let ts: TokenStream = TokenStream::from_iter(ts);
 
-        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
+        let mut x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
+        x.mode = mode;
         if !x.is_active() {
             // We encountered an error, so we return the original item.
             // This allows us to potentially parse other attributes.
@@ -1017,4 +1032,4 @@ mod llvm_enzyme {
     }
 }
 
-pub(crate) use llvm_enzyme::expand;
+pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};
diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs
index 9cd4d17059a..a89b3642f7e 100644
--- a/compiler/rustc_builtin_macros/src/lib.rs
+++ b/compiler/rustc_builtin_macros/src/lib.rs
@@ -112,7 +112,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
 
     register_attr! {
         alloc_error_handler: alloc_error_handler::expand,
-        autodiff: autodiff::expand,
+        autodiff_forward: autodiff::expand_forward,
+        autodiff_reverse: autodiff::expand_reverse,
         bench: test::expand_bench,
         cfg_accessible: cfg_accessible::Expander,
         cfg_eval: cfg_eval::expand,
diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs
index 5c0d0cf4796..5aff24f7aa0 100644
--- a/compiler/rustc_passes/src/check_attr.rs
+++ b/compiler/rustc_passes/src/check_attr.rs
@@ -255,7 +255,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
                             self.check_generic_attr(hir_id, attr, target, Target::Fn);
                             self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
                         }
-                        [sym::autodiff, ..] => {
+                        [sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
                             self.check_autodiff(hir_id, attr, span, target)
                         }
                         [sym::coroutine, ..] => {
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index efae6250b07..bb19b5761bb 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -531,7 +531,8 @@ symbols! {
         audit_that,
         augmented_assignments,
         auto_traits,
-        autodiff,
+        autodiff_forward,
+        autodiff_reverse,
         automatically_derived,
         avx,
         avx10_target_feature,
diff --git a/library/core/src/lib.rs b/library/core/src/lib.rs
index e605d7e0d78..aaa8c872f98 100644
--- a/library/core/src/lib.rs
+++ b/library/core/src/lib.rs
@@ -229,7 +229,7 @@ pub mod assert_matches {
 /// Unstable module containing the unstable `autodiff` macro.
 pub mod autodiff {
     #[unstable(feature = "autodiff", issue = "124509")]
-    pub use crate::macros::builtin::autodiff;
+    pub use crate::macros::builtin::{autodiff_forward, autodiff_reverse};
 }
 
 #[unstable(feature = "contracts", issue = "128044")]
diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs
index 7dc8c060cd5..dc50ad6a090 100644
--- a/library/core/src/macros/mod.rs
+++ b/library/core/src/macros/mod.rs
@@ -1536,6 +1536,20 @@ pub(crate) mod builtin {
         /* compiler built-in */
     }
 
+    #[unstable(feature = "autodiff", issue = "124509")]
+    #[allow_internal_unstable(rustc_attrs)]
+    #[rustc_builtin_macro]
+    pub macro autodiff_forward($item:item) {
+        /* compiler built-in */
+    }
+
+    #[unstable(feature = "autodiff", issue = "124509")]
+    #[allow_internal_unstable(rustc_attrs)]
+    #[rustc_builtin_macro]
+    pub macro autodiff_reverse($item:item) {
+        /* compiler built-in */
+    }
+
     /// Asserts that a boolean expression is `true` at runtime.
     ///
     /// This will invoke the [`panic!`] macro if the provided expression cannot be