about summary refs log tree commit diff
path: root/compiler/rustc_builtin_macros/src
diff options
context:
space:
mode:
authorManuel Drehwald <git@manuel.drehwald.info>2025-03-17 16:54:41 -0400
committerManuel Drehwald <git@manuel.drehwald.info>2025-03-17 16:54:41 -0400
commitf5c37c3732360b21dc6d049003838e99b5ec5263 (patch)
tree71af0969b753418f3bdc4a2e49592afd1b1befd0 /compiler/rustc_builtin_macros/src
parent03ece26b79c8b04a916bcd6ee0ab26c8c20e7b66 (diff)
downloadrust-f5c37c3732360b21dc6d049003838e99b5ec5263.tar.gz
rust-f5c37c3732360b21dc6d049003838e99b5ec5263.zip
[NFC] split up gen_body_helper
Diffstat (limited to 'compiler/rustc_builtin_macros/src')
-rw-r--r--compiler/rustc_builtin_macros/src/autodiff.rs89
1 files changed, 67 insertions, 22 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs
index 7b5c4a159b0..06cc4d3d8ba 100644
--- a/compiler/rustc_builtin_macros/src/autodiff.rs
+++ b/compiler/rustc_builtin_macros/src/autodiff.rs
@@ -359,30 +359,27 @@ mod llvm_enzyme {
         ty
     }
 
-    /// We only want this function to type-check, since we will replace the body
-    /// later on llvm level. Using `loop {}` does not cover all return types anymore,
-    /// so instead we build something that should pass. We also add a inline_asm
-    /// line, as one more barrier for rustc to prevent inlining of this function.
-    /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
-    /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
-    /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
-    /// this function (which should never happen, since it is only a placeholder).
-    /// Finally, we also add back_box usages of all input arguments, to prevent rustc
-    /// from optimizing any arguments away.
-    fn gen_enzyme_body(
+    // Will generate a body of the type:
+    // ```
+    // {
+    //   unsafe {
+    //   asm!("NOP");
+    //   }
+    //   ::core::hint::black_box(primal(args));
+    //   ::core::hint::black_box((args, ret));
+    //   <This part remains to be done by following function>
+    // }
+    // ```
+    fn init_body_helper(
         ecx: &ExtCtxt<'_>,
-        x: &AutoDiffAttrs,
-        n_active: u32,
-        sig: &ast::FnSig,
-        d_sig: &ast::FnSig,
+        span: Span,
         primal: Ident,
         new_names: &[String],
-        span: Span,
         sig_span: Span,
         new_decl_span: Span,
-        idents: Vec<Ident>,
+        idents: &[Ident],
         errored: bool,
-    ) -> P<ast::Block> {
+    ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
         let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
         let noop = ast::InlineAsm {
             asm_macro: ast::AsmMacro::Asm,
@@ -431,6 +428,54 @@ mod llvm_enzyme {
         }
         body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
 
+        (body, primal_call, black_box_primal_call, blackbox_call_expr)
+    }
+
+    /// We only want this function to type-check, since we will replace the body
+    /// later on llvm level. Using `loop {}` does not cover all return types anymore,
+    /// so instead we build something that should pass. We also add a inline_asm
+    /// line, as one more barrier for rustc to prevent inlining of this function.
+    /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
+    /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
+    /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
+    /// this function (which should never happen, since it is only a placeholder).
+    /// Finally, we also add back_box usages of all input arguments, to prevent rustc
+    /// from optimizing any arguments away.
+    fn gen_enzyme_body(
+        ecx: &ExtCtxt<'_>,
+        x: &AutoDiffAttrs,
+        n_active: u32,
+        sig: &ast::FnSig,
+        d_sig: &ast::FnSig,
+        primal: Ident,
+        new_names: &[String],
+        span: Span,
+        sig_span: Span,
+        _new_decl_span: Span,
+        idents: Vec<Ident>,
+        errored: bool,
+    ) -> P<ast::Block> {
+        let new_decl_span = d_sig.span;
+
+        // Just adding some default inline-asm and black_box usages to prevent early inlining
+        // and optimizations which alter the function signature.
+        //
+        // The bb_primal_call is the black_box call of the primal function. We keep it around,
+        // since it has the convenient property of returning the type of the primal function,
+        // Remember, we only care to match types here.
+        // No matter which return we pick, we always wrap it into a std::hint::black_box call,
+        // to prevent rustc from propagating it into the caller.
+        let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
+            ecx,
+            span,
+            primal,
+            new_names,
+            sig_span,
+            new_decl_span,
+            &idents,
+            errored,
+        );
+
         if !has_ret(&d_sig.decl.output) {
             // there is no return type that we have to match, () works fine.
             return body;
@@ -442,7 +487,7 @@ mod llvm_enzyme {
 
         if primal_ret && n_active == 0 && x.mode.is_rev() {
             // We only have the primal ret.
-            body.stmts.push(ecx.stmt_expr(black_box_primal_call));
+            body.stmts.push(ecx.stmt_expr(bb_primal_call));
             return body;
         }
 
@@ -534,11 +579,11 @@ mod llvm_enzyme {
                 return body;
             }
             [arg] => {
-                ret = ecx.expr_call(new_decl_span, blackbox_call_expr, thin_vec![arg.clone()]);
+                ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
             }
             args => {
                 let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
-                ret = ecx.expr_call(new_decl_span, blackbox_call_expr, thin_vec![ret_tuple]);
+                ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
             }
         }
         assert!(has_ret(&d_sig.decl.output));
@@ -551,7 +596,7 @@ mod llvm_enzyme {
         ecx: &ExtCtxt<'_>,
         span: Span,
         primal: Ident,
-        idents: Vec<Ident>,
+        idents: &[Ident],
     ) -> P<ast::Expr> {
         let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
         if has_self {