about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_builtin_macros/messages.ftl1
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs67
-rw-r--r--compiler/rustc_builtin_macros/src/errors.rs8
-rw-r--r--compiler/rustc_builtin_macros/src/lib.rs5
-rw-r--r--compiler/rustc_passes/src/check_attr.rs2
-rw-r--r--compiler/rustc_span/src/symbol.rs5
6 files changed, 55 insertions, 33 deletions
diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl
index 628bdee1129..d32e6f1558e 100644
--- a/compiler/rustc_builtin_macros/messages.ftl
+++ b/compiler/rustc_builtin_macros/messages.ftl
@@ -56,7 +56,6 @@ builtin_macros_assert_requires_expression = macro requires an expression as an a
 
 builtin_macros_autodiff = autodiff must be applied to function
 builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
-builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
 builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
 builtin_macros_autodiff_not_build = this rustc version does not support autodiff
 builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 1ff4fc6aaab..dc3bb8ab52a 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -86,27 +86,23 @@ mod llvm_enzyme {
         ecx: &mut ExtCtxt<'_>,
         meta_item: &ThinVec<MetaItemInner>,
         has_ret: bool,
+        mode: DiffMode,
     ) -> AutoDiffAttrs {
         let dcx = ecx.sess.dcx();
-        let mode = name(&meta_item[1]);
-        let Ok(mode) = DiffMode::from_str(&mode) else {
-            dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
-            return AutoDiffAttrs::error();
-        };
 
         // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
         // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
-        let mut first_activity = 2;
+        let mut first_activity = 1;
 
-        let width = if let [_, _, x, ..] = &meta_item[..]
+        let width = if let [_, x, ..] = &meta_item[..]
             && let Some(x) = width(x)
         {
-            first_activity = 3;
+            first_activity = 2;
             match x.try_into() {
                 Ok(x) => x,
                 Err(_) => {
                     dcx.emit_err(errors::AutoDiffInvalidWidth {
-                        span: meta_item[2].span(),
+                        span: meta_item[1].span(),
                         width: x,
                     });
                     return AutoDiffAttrs::error();
@@ -165,6 +161,24 @@ mod llvm_enzyme {
         ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
     }
 
+    pub(crate) fn expand_forward(
+        ecx: &mut ExtCtxt<'_>,
+        expand_span: Span,
+        meta_item: &ast::MetaItem,
+        item: Annotatable,
+    ) -> Vec<Annotatable> {
+        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
+    }
+
+    pub(crate) fn expand_reverse(
+        ecx: &mut ExtCtxt<'_>,
+        expand_span: Span,
+        meta_item: &ast::MetaItem,
+        item: Annotatable,
+    ) -> Vec<Annotatable> {
+        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
+    }
+
     /// We expand the autodiff macro to generate a new placeholder function which passes
     /// type-checking and can be called by users. The function body of the placeholder function will
     /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +212,12 @@ mod llvm_enzyme {
     /// ```
     /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
     /// in CI.
-    pub(crate) fn expand(
+    pub(crate) fn expand_with_mode(
         ecx: &mut ExtCtxt<'_>,
         expand_span: Span,
         meta_item: &ast::MetaItem,
         mut item: Annotatable,
+        mode: DiffMode,
     ) -> Vec<Annotatable> {
         if cfg!(not(llvm_enzyme)) {
             ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
@@ -245,29 +260,41 @@ mod llvm_enzyme {
         // create TokenStream from vec elemtents:
         // meta_item doesn't have a .tokens field
         let mut ts: Vec<TokenTree> = vec![];
-        if meta_item_vec.len() < 2 {
-            // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
-            // input and output args.
+        if meta_item_vec.len() < 1 {
+            // At the bare minimum, we need a fnc name.
             dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
             return vec![item];
         }
 
-        meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
+        let mode_symbol = match mode {
+            DiffMode::Forward => sym::Forward,
+            DiffMode::Reverse => sym::Reverse,
+            _ => unreachable!("Unsupported mode: {:?}", mode),
+        };
+
+        // Insert mode token
+        let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
+        ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
+        ts.insert(
+            1,
+            TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
+        );
 
         // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
         // If it is not given, we default to 1 (scalar mode).
         let start_position;
         let kind: LitKind = LitKind::Integer;
         let symbol;
-        if meta_item_vec.len() >= 3
-            && let Some(width) = width(&meta_item_vec[2])
+        if meta_item_vec.len() >= 2
+            && let Some(width) = width(&meta_item_vec[1])
         {
-            start_position = 3;
+            start_position = 2;
             symbol = Symbol::intern(&width.to_string());
         } else {
-            start_position = 2;
+            start_position = 1;
             symbol = sym::integer(1);
         }
+
         let l: Lit = Lit { kind, symbol, suffix: None };
         let t = Token::new(TokenKind::Literal(l), Span::default());
         let comma = Token::new(TokenKind::Comma, Span::default());
@@ -289,7 +316,7 @@ mod llvm_enzyme {
         ts.pop();
         let ts: TokenStream = TokenStream::from_iter(ts);
 
-        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
+        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
         if !x.is_active() {
             // We encountered an error, so we return the original item.
             // This allows us to potentially parse other attributes.
@@ -1017,4 +1044,4 @@ mod llvm_enzyme {
     }
 }
 
-pub(crate) use llvm_enzyme::expand;
+pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};
diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs
index 75d06a8df14..3a2e96a5e5a 100644
--- a/compiler/rustc_builtin_macros/src/errors.rs
+++ b/compiler/rustc_builtin_macros/src/errors.rs
@@ -181,14 +181,6 @@ mod autodiff {
     }
 
     #[derive(Diagnostic)]
-    #[diag(builtin_macros_autodiff_mode)]
-    pub(crate) struct AutoDiffInvalidMode {
-        #[primary_span]
-        pub(crate) span: Span,
-        pub(crate) mode: String,
-    }
-
-    #[derive(Diagnostic)]
     #[diag(builtin_macros_autodiff_width)]
     pub(crate) struct AutoDiffInvalidWidth {
         #[primary_span]
diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs
index 9cd4d17059a..b16f3cff5cf 100644
--- a/compiler/rustc_builtin_macros/src/lib.rs
+++ b/compiler/rustc_builtin_macros/src/lib.rs
@@ -5,10 +5,10 @@
 #![allow(internal_features)]
 #![allow(rustc::diagnostic_outside_of_impl)]
 #![allow(rustc::untranslatable_diagnostic)]
+#![cfg_attr(not(bootstrap), feature(autodiff))]
 #![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
 #![doc(rust_logo)]
 #![feature(assert_matches)]
-#![feature(autodiff)]
 #![feature(box_patterns)]
 #![feature(decl_macro)]
 #![feature(if_let_guard)]
@@ -112,7 +112,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
 
     register_attr! {
         alloc_error_handler: alloc_error_handler::expand,
-        autodiff: autodiff::expand,
+        autodiff_forward: autodiff::expand_forward,
+        autodiff_reverse: autodiff::expand_reverse,
         bench: test::expand_bench,
         cfg_accessible: cfg_accessible::Expander,
         cfg_eval: cfg_eval::expand,
diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs
index 1d024694049..df8b5a6b181 100644
--- a/compiler/rustc_passes/src/check_attr.rs
+++ b/compiler/rustc_passes/src/check_attr.rs
@@ -255,7 +255,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
                             self.check_generic_attr(hir_id, attr, target, Target::Fn);
                             self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
                         }
-                        [sym::autodiff, ..] => {
+                        [sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
                             self.check_autodiff(hir_id, attr, span, target)
                         }
                         [sym::coroutine, ..] => {
diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs
index 0447713fb05..8f7e72f0ae1 100644
--- a/compiler/rustc_span/src/symbol.rs
+++ b/compiler/rustc_span/src/symbol.rs
@@ -244,6 +244,7 @@ symbols! {
         FnMut,
         FnOnce,
         Formatter,
+        Forward,
         From,
         FromIterator,
         FromResidual,
@@ -339,6 +340,7 @@ symbols! {
         Result,
         ResumeTy,
         Return,
+        Reverse,
         Right,
         Rust,
         RustaceansAreAwesome,
@@ -522,7 +524,8 @@ symbols! {
         audit_that,
         augmented_assignments,
         auto_traits,
-        autodiff,
+        autodiff_forward,
+        autodiff_reverse,
         automatically_derived,
         avx,
         avx10_target_feature,