about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs67
1 files changed, 26 insertions, 41 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 287f0fdc516..351413dea49 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -17,7 +17,7 @@ mod llvm_enzyme {
     use rustc_ast::visit::AssocCtxt::*;
     use rustc_ast::{
         self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
-        MetaItemInner, PatKind, QSelf, TyKind,
+        MetaItemInner, PatKind, QSelf, TyKind, Visibility,
     };
     use rustc_expand::base::{Annotatable, ExtCtxt};
     use rustc_span::{Ident, Span, Symbol, kw, sym};
@@ -72,6 +72,16 @@ mod llvm_enzyme {
         }
     }
 
+    // Get information about the function the macro is applied to
+    fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
+        match &iitem.kind {
+            ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
+                Some((iitem.vis.clone(), sig.clone(), ident.clone()))
+            }
+            _ => None,
+        }
+    }
+
     pub(crate) fn from_ast(
         ecx: &mut ExtCtxt<'_>,
         meta_item: &ThinVec<MetaItemInner>,
@@ -201,49 +211,24 @@ mod llvm_enzyme {
         let dcx = ecx.sess.dcx();
 
         // first get information about the annotable item:
-        let (sig, vis, primal) = match &item {
-            Annotatable::Item(iitem) => {
-                let (sig, ident) = match &iitem.kind {
-                    ItemKind::Fn(box ast::Fn { sig, ident, .. }) => (sig, ident),
-                    _ => {
-                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
-                        return vec![item];
-                    }
-                };
-                (sig.clone(), iitem.vis.clone(), ident.clone())
-            }
+        let Some((vis, sig, primal)) = (match &item {
+            Annotatable::Item(iitem) => extract_item_info(iitem),
+            Annotatable::Stmt(stmt) => match &stmt.kind {
+                ast::StmtKind::Item(iitem) => extract_item_info(iitem),
+                _ => None,
+            },
             Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
-                let (sig, ident) = match &assoc_item.kind {
-                    ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => (sig, ident),
-                    _ => {
-                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
-                        return vec![item];
-                    }
-                };
-                (sig.clone(), assoc_item.vis.clone(), ident.clone())
-            }
-            Annotatable::Stmt(stmt) => {
-                let (sig, vis, ident) = match &stmt.kind {
-                    ast::StmtKind::Item(iitem) => match &iitem.kind {
-                        ast::ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
-                            (sig.clone(), iitem.vis.clone(), ident.clone())
-                        }
-                        _ => {
-                            dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
-                            return vec![item];
-                        }
-                    },
-                    _ => {
-                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
-                        return vec![item];
+                match &assoc_item.kind {
+                    ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
+                        Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
                     }
-                };
-                (sig, vis, ident)
-            }
-            _ => {
-                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
-                return vec![item];
+                    _ => None,
+                }
             }
+            _ => None,
+        }) else {
+            dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
+            return vec![item];
         };
 
         let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {