diff options
| author | Marcelo DomÃnguez <69964857+Sa4dUs@users.noreply.github.com> | 2025-05-06 09:19:33 +0200 |
|---|---|---|
| committer | Marcelo DomÃnguez <dmmarcelo27@gmail.com> | 2025-05-20 11:58:26 +0000 |
| commit | b21c9e7bfb0180b67b486013a7137fb200cb1076 (patch) | |
| tree | 769762e6206c9f8618026eb6d213dbdbcd35af2f | |
| parent | f8e9e7636aabcbc29345d9614432d15b3c0c4ec7 (diff) | |
| download | rust-b21c9e7bfb0180b67b486013a7137fb200cb1076.tar.gz rust-b21c9e7bfb0180b67b486013a7137fb200cb1076.zip | |
Split `autodiff` into `autodiff_forward` and `autodiff_reverse`
Pending fix.
```
error: cannot find a built-in macro with name `autodiff_forward`
--> library\core\src\macros\mod.rs:1542:5
|
1542 | / pub macro autodiff_forward($item:item) {
1543 | | /* compiler built-in */
1544 | | }
| |_____^
error: cannot find a built-in macro with name `autodiff_reverse`
--> library\core\src\macros\mod.rs:1549:5
|
1549 | / pub macro autodiff_reverse($item:item) {
1550 | | /* compiler built-in */
1551 | | }
| |_____^
error: could not compile `core` (lib) due to 2 previous errors
```
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 41 | ||||
| -rw-r--r-- | compiler/rustc_builtin_macros/src/lib.rs | 3 | ||||
| -rw-r--r-- | compiler/rustc_passes/src/check_attr.rs | 2 | ||||
| -rw-r--r-- | compiler/rustc_span/src/symbol.rs | 3 | ||||
| -rw-r--r-- | library/core/src/lib.rs | 2 | ||||
| -rw-r--r-- | library/core/src/macros/mod.rs | 14 |
6 files changed, 48 insertions, 17 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 1ff4fc6aaab..8073de15925 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -88,25 +88,20 @@ mod llvm_enzyme { has_ret: bool, ) -> AutoDiffAttrs { let dcx = ecx.sess.dcx(); - let mode = name(&meta_item[1]); - let Ok(mode) = DiffMode::from_str(&mode) else { - dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode }); - return AutoDiffAttrs::error(); - }; // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode. // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1. - let mut first_activity = 2; + let mut first_activity = 1; - let width = if let [_, _, x, ..] = &meta_item[..] + let width = if let [_, x, ..] = &meta_item[..] && let Some(x) = width(x) { - first_activity = 3; + first_activity = 2; match x.try_into() { Ok(x) => x, Err(_) => { dcx.emit_err(errors::AutoDiffInvalidWidth { - span: meta_item[2].span(), + span: meta_item[1].span(), width: x, }); return AutoDiffAttrs::error(); @@ -150,7 +145,7 @@ mod llvm_enzyme { }; AutoDiffAttrs { - mode, + mode: DiffMode::Error, width, ret_activity: *ret_activity, input_activity: input_activity.to_vec(), @@ -165,6 +160,24 @@ mod llvm_enzyme { ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); } + pub(crate) fn expand_forward( + ecx: &mut ExtCtxt<'_>, + expand_span: Span, + meta_item: &ast::MetaItem, + item: Annotatable, + ) -> Vec<Annotatable> { + expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward) + } + + pub(crate) fn expand_reverse( + ecx: &mut ExtCtxt<'_>, + expand_span: Span, + meta_item: &ast::MetaItem, + item: Annotatable, + ) -> Vec<Annotatable> { + expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse) + } + /// We expand the autodiff macro to generate a new placeholder function which passes /// type-checking and can be called by users. The function body of the placeholder function will /// later be replaced on LLVM-IR level, so the design of the body is less important and for now @@ -198,11 +211,12 @@ mod llvm_enzyme { /// ``` /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked /// in CI. - pub(crate) fn expand( + pub(crate) fn expand_with_mode( ecx: &mut ExtCtxt<'_>, expand_span: Span, meta_item: &ast::MetaItem, mut item: Annotatable, + mode: DiffMode, ) -> Vec<Annotatable> { if cfg!(not(llvm_enzyme)) { ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span }); @@ -289,7 +303,8 @@ mod llvm_enzyme { ts.pop(); let ts: TokenStream = TokenStream::from_iter(ts); - let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret); + let mut x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret); + x.mode = mode; if !x.is_active() { // We encountered an error, so we return the original item. // This allows us to potentially parse other attributes. @@ -1017,4 +1032,4 @@ mod llvm_enzyme { } } -pub(crate) use llvm_enzyme::expand; +pub(crate) use llvm_enzyme::{expand_forward, expand_reverse}; diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs index 9cd4d17059a..a89b3642f7e 100644 --- a/compiler/rustc_builtin_macros/src/lib.rs +++ b/compiler/rustc_builtin_macros/src/lib.rs @@ -112,7 +112,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) { register_attr! { alloc_error_handler: alloc_error_handler::expand, - autodiff: autodiff::expand, + autodiff_forward: autodiff::expand_forward, + autodiff_reverse: autodiff::expand_reverse, bench: test::expand_bench, cfg_accessible: cfg_accessible::Expander, cfg_eval: cfg_eval::expand, diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index 5c0d0cf4796..5aff24f7aa0 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ b/compiler/rustc_passes/src/check_attr.rs @@ -255,7 +255,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> { self.check_generic_attr(hir_id, attr, target, Target::Fn); self.check_proc_macro(hir_id, target, ProcMacroKind::Derive) } - [sym::autodiff, ..] => { + [sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => { self.check_autodiff(hir_id, attr, span, target) } [sym::coroutine, ..] => { diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index efae6250b07..bb19b5761bb 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -531,7 +531,8 @@ symbols! { audit_that, augmented_assignments, auto_traits, - autodiff, + autodiff_forward, + autodiff_reverse, automatically_derived, avx, avx10_target_feature, diff --git a/library/core/src/lib.rs b/library/core/src/lib.rs index e605d7e0d78..aaa8c872f98 100644 --- a/library/core/src/lib.rs +++ b/library/core/src/lib.rs @@ -229,7 +229,7 @@ pub mod assert_matches { /// Unstable module containing the unstable `autodiff` macro. pub mod autodiff { #[unstable(feature = "autodiff", issue = "124509")] - pub use crate::macros::builtin::autodiff; + pub use crate::macros::builtin::{autodiff_forward, autodiff_reverse}; } #[unstable(feature = "contracts", issue = "128044")] diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index 7dc8c060cd5..dc50ad6a090 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1536,6 +1536,20 @@ pub(crate) mod builtin { /* compiler built-in */ } + #[unstable(feature = "autodiff", issue = "124509")] + #[allow_internal_unstable(rustc_attrs)] + #[rustc_builtin_macro] + pub macro autodiff_forward($item:item) { + /* compiler built-in */ + } + + #[unstable(feature = "autodiff", issue = "124509")] + #[allow_internal_unstable(rustc_attrs)] + #[rustc_builtin_macro] + pub macro autodiff_reverse($item:item) { + /* compiler built-in */ + } + /// Asserts that a boolean expression is `true` at runtime. /// /// This will invoke the [`panic!`] macro if the provided expression cannot be |
