about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJakub Beránek <berykubik@gmail.com>2025-04-07 08:23:34 +0200
committerGitHub <noreply@github.com>2025-04-07 08:23:34 +0200
commit5a43d92382397ae1968a6f7f79a31e09e1e7ad5f (patch)
tree24a784e8d75dbcbdc644ef998379b58d224aae93
parent25a615bf829b9f6d6f22da537e3851043f92e5f2 (diff)
parentbf69443a9f0fa9b44aaec36c1c470ad22a325c2a (diff)
downloadrust-5a43d92382397ae1968a6f7f79a31e09e1e7ad5f.tar.gz
rust-5a43d92382397ae1968a6f7f79a31e09e1e7ad5f.zip
Rollup merge of #138314 - haenoe:autodiff-inner-function, r=ZuseZ4
fix usage of `autodiff` macro with inner functions

This PR adds additional handling into the expansion step of the `std::autodiff` macro (in `compiler/rustc_builtin_macros/src/autodiff.rs`), which allows the macro to be applied to inner functions.

```rust
#![feature(autodiff)]
use std::autodiff::autodiff;

fn main() {
    #[autodiff(d_inner, Forward, Dual, DualOnly)]
    fn inner(x: f32) -> f32 {
        x * x
    }
}
```

Previously, the compiler didn't allow this due to only handling `Annotatable::Item` and `Annotatable::AssocItem` and missing the handling of `Annotatable::Stmt`. This resulted in the rather generic error

```
error: autodiff must be applied to function
 --> src/main.rs:6:5
  |
6 | /     fn inner(x: f32) -> f32 {
7 | |         x * x
8 | |     }
  | |_____^

error: could not compile `enzyme-test` (bin "enzyme-test") due to 1 previous error
```

This issue was originally reported [here](https://github.com/EnzymeAD/rust/issues/184).

Quick question: would it make sense to add a ui test to ensure there is no regression on this?
This is my first contribution, so I'm extra grateful for any piece of feedback!! :D

r? `@oli-obk`

Tracking issue for autodiff: #124509
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs125
-rw-r--r--tests/pretty/autodiff_forward.pp23
-rw-r--r--tests/pretty/autodiff_forward.rs9
3 files changed, 109 insertions, 48 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 7f99f75b2b9..351413dea49 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -17,7 +17,7 @@ mod llvm_enzyme {
     use rustc_ast::visit::AssocCtxt::*;
     use rustc_ast::{
         self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
-        MetaItemInner, PatKind, QSelf, TyKind,
+        MetaItemInner, PatKind, QSelf, TyKind, Visibility,
     };
     use rustc_expand::base::{Annotatable, ExtCtxt};
     use rustc_span::{Ident, Span, Symbol, kw, sym};
@@ -72,6 +72,16 @@ mod llvm_enzyme {
         }
     }
 
+    // Get information about the function the macro is applied to
+    fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
+        match &iitem.kind {
+            ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
+                Some((iitem.vis.clone(), sig.clone(), ident.clone()))
+            }
+            _ => None,
+        }
+    }
+
     pub(crate) fn from_ast(
         ecx: &mut ExtCtxt<'_>,
         meta_item: &ThinVec<MetaItemInner>,
@@ -199,32 +209,26 @@ mod llvm_enzyme {
             return vec![item];
         }
         let dcx = ecx.sess.dcx();
-        // first get the annotable item:
-        let (primal, sig, is_impl): (Ident, FnSig, bool) = match &item {
-            Annotatable::Item(iitem) => {
-                let (ident, sig) = match &iitem.kind {
-                    ItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
-                    _ => {
-                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
-                        return vec![item];
-                    }
-                };
-                (*ident, sig.clone(), false)
-            }
+
+        // first get information about the annotable item:
+        let Some((vis, sig, primal)) = (match &item {
+            Annotatable::Item(iitem) => extract_item_info(iitem),
+            Annotatable::Stmt(stmt) => match &stmt.kind {
+                ast::StmtKind::Item(iitem) => extract_item_info(iitem),
+                _ => None,
+            },
             Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
-                let (ident, sig) = match &assoc_item.kind {
-                    ast::AssocItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
-                    _ => {
-                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
-                        return vec![item];
+                match &assoc_item.kind {
+                    ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
+                        Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
                     }
-                };
-                (*ident, sig.clone(), true)
-            }
-            _ => {
-                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
-                return vec![item];
+                    _ => None,
+                }
             }
+            _ => None,
+        }) else {
+            dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
+            return vec![item];
         };
 
         let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
@@ -238,15 +242,6 @@ mod llvm_enzyme {
         let has_ret = has_ret(&sig.decl.output);
         let sig_span = ecx.with_call_site_ctxt(sig.span);
 
-        let vis = match &item {
-            Annotatable::Item(iitem) => iitem.vis.clone(),
-            Annotatable::AssocItem(assoc_item, _) => assoc_item.vis.clone(),
-            _ => {
-                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
-                return vec![item];
-            }
-        };
-
         // create TokenStream from vec elemtents:
         // meta_item doesn't have a .tokens field
         let mut ts: Vec<TokenTree> = vec![];
@@ -379,6 +374,22 @@ mod llvm_enzyme {
                 }
                 Annotatable::AssocItem(assoc_item.clone(), i)
             }
+            Annotatable::Stmt(ref mut stmt) => {
+                match stmt.kind {
+                    ast::StmtKind::Item(ref mut iitem) => {
+                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
+                            iitem.attrs.push(attr);
+                        }
+                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
+                        {
+                            iitem.attrs.push(inline_never.clone());
+                        }
+                    }
+                    _ => unreachable!("stmt kind checked previously"),
+                };
+
+                Annotatable::Stmt(stmt.clone())
+            }
             _ => {
                 unreachable!("annotatable kind checked previously")
             }
@@ -389,22 +400,40 @@ mod llvm_enzyme {
             delim: rustc_ast::token::Delimiter::Parenthesis,
             tokens: ts,
         });
+
         let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
-        let d_annotatable = if is_impl {
-            let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
-            let d_fn = P(ast::AssocItem {
-                attrs: thin_vec![d_attr, inline_never],
-                id: ast::DUMMY_NODE_ID,
-                span,
-                vis,
-                kind: assoc_item,
-                tokens: None,
-            });
-            Annotatable::AssocItem(d_fn, Impl { of_trait: false })
-        } else {
-            let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
-            d_fn.vis = vis;
-            Annotatable::Item(d_fn)
+        let d_annotatable = match &item {
+            Annotatable::AssocItem(_, _) => {
+                let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
+                let d_fn = P(ast::AssocItem {
+                    attrs: thin_vec![d_attr, inline_never],
+                    id: ast::DUMMY_NODE_ID,
+                    span,
+                    vis,
+                    kind: assoc_item,
+                    tokens: None,
+                });
+                Annotatable::AssocItem(d_fn, Impl { of_trait: false })
+            }
+            Annotatable::Item(_) => {
+                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
+                d_fn.vis = vis;
+
+                Annotatable::Item(d_fn)
+            }
+            Annotatable::Stmt(_) => {
+                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
+                d_fn.vis = vis;
+
+                Annotatable::Stmt(P(ast::Stmt {
+                    id: ast::DUMMY_NODE_ID,
+                    kind: ast::StmtKind::Item(d_fn),
+                    span,
+                }))
+            }
+            _ => {
+                unreachable!("item kind checked previously")
+            }
         };
 
         return vec![orig_annotatable, d_annotatable];
diff --git a/tests/pretty/autodiff_forward.pp b/tests/pretty/autodiff_forward.pp
index 4b2fb6166ff..713b8f541ae 100644
--- a/tests/pretty/autodiff_forward.pp
+++ b/tests/pretty/autodiff_forward.pp
@@ -29,6 +29,8 @@ pub fn f1(x: &[f64], y: f64) -> f64 {
     // Make sure, that we add the None for the default return.
 
 
+    // We want to make sure that we can use the macro for functions defined inside of functions
+
     ::core::panicking::panic("not implemented")
 }
 #[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
@@ -158,4 +160,25 @@ fn f8_1(x: &f32, bx_0: &f32) -> f32 {
     ::core::hint::black_box((bx_0,));
     ::core::hint::black_box(<f32>::default())
 }
+pub fn f9() {
+    #[rustc_autodiff]
+    #[inline(never)]
+    fn inner(x: f32) -> f32 { x * x }
+    #[rustc_autodiff(Forward, 1, Dual, Dual)]
+    #[inline(never)]
+    fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
+        unsafe { asm!("NOP", options(pure, nomem)); };
+        ::core::hint::black_box(inner(x));
+        ::core::hint::black_box((bx_0,));
+        ::core::hint::black_box(<(f32, f32)>::default())
+    }
+    #[rustc_autodiff(Forward, 1, Dual, DualOnly)]
+    #[inline(never)]
+    fn d_inner_1(x: f32, bx_0: f32) -> f32 {
+        unsafe { asm!("NOP", options(pure, nomem)); };
+        ::core::hint::black_box(inner(x));
+        ::core::hint::black_box((bx_0,));
+        ::core::hint::black_box(<f32>::default())
+    }
+}
 fn main() {}
diff --git a/tests/pretty/autodiff_forward.rs b/tests/pretty/autodiff_forward.rs
index a765738c2a8..5a0660a08e5 100644
--- a/tests/pretty/autodiff_forward.rs
+++ b/tests/pretty/autodiff_forward.rs
@@ -54,4 +54,13 @@ fn f8(x: &f32) -> f32 {
     unimplemented!()
 }
 
+// We want to make sure that we can use the macro for functions defined inside of functions
+pub fn f9() {
+    #[autodiff(d_inner_1, Forward, Dual, DualOnly)]
+    #[autodiff(d_inner_2, Forward, Dual, Dual)]
+    fn inner(x: f32) -> f32 {
+        x * x
+    }
+}
+
 fn main() {}