diff options
| author | Manuel Drehwald <git@manuel.drehwald.info> | 2025-03-17 16:54:41 -0400 |
|---|---|---|
| committer | Manuel Drehwald <git@manuel.drehwald.info> | 2025-03-17 16:54:41 -0400 |
| commit | f5c37c3732360b21dc6d049003838e99b5ec5263 (patch) | |
| tree | 71af0969b753418f3bdc4a2e49592afd1b1befd0 /compiler/rustc_builtin_macros/src | |
| parent | 03ece26b79c8b04a916bcd6ee0ab26c8c20e7b66 (diff) | |
| download | rust-f5c37c3732360b21dc6d049003838e99b5ec5263.tar.gz rust-f5c37c3732360b21dc6d049003838e99b5ec5263.zip | |
[NFC] split up gen_body_helper
Diffstat (limited to 'compiler/rustc_builtin_macros/src')
| -rw-r--r-- | compiler/rustc_builtin_macros/src/autodiff.rs | 89 |
1 files changed, 67 insertions, 22 deletions
diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 7b5c4a159b0..06cc4d3d8ba 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -359,30 +359,27 @@ mod llvm_enzyme { ty } - /// We only want this function to type-check, since we will replace the body - /// later on llvm level. Using `loop {}` does not cover all return types anymore, - /// so instead we build something that should pass. We also add a inline_asm - /// line, as one more barrier for rustc to prevent inlining of this function. - /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see - /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient. - /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate - /// this function (which should never happen, since it is only a placeholder). - /// Finally, we also add back_box usages of all input arguments, to prevent rustc - /// from optimizing any arguments away. - fn gen_enzyme_body( + // Will generate a body of the type: + // ``` + // { + // unsafe { + // asm!("NOP"); + // } + // ::core::hint::black_box(primal(args)); + // ::core::hint::black_box((args, ret)); + // <This part remains to be done by following function> + // } + // ``` + fn init_body_helper( ecx: &ExtCtxt<'_>, - x: &AutoDiffAttrs, - n_active: u32, - sig: &ast::FnSig, - d_sig: &ast::FnSig, + span: Span, primal: Ident, new_names: &[String], - span: Span, sig_span: Span, new_decl_span: Span, - idents: Vec<Ident>, + idents: &[Ident], errored: bool, - ) -> P<ast::Block> { + ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) { let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); let noop = ast::InlineAsm { asm_macro: ast::AsmMacro::Asm, @@ -431,6 +428,54 @@ mod llvm_enzyme { } body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); + (body, primal_call, black_box_primal_call, blackbox_call_expr) + } + + /// We only want this function to type-check, since we will replace the body + /// later on llvm level. Using `loop {}` does not cover all return types anymore, + /// so instead we build something that should pass. We also add a inline_asm + /// line, as one more barrier for rustc to prevent inlining of this function. + /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see + /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient. + /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate + /// this function (which should never happen, since it is only a placeholder). + /// Finally, we also add back_box usages of all input arguments, to prevent rustc + /// from optimizing any arguments away. + fn gen_enzyme_body( + ecx: &ExtCtxt<'_>, + x: &AutoDiffAttrs, + n_active: u32, + sig: &ast::FnSig, + d_sig: &ast::FnSig, + primal: Ident, + new_names: &[String], + span: Span, + sig_span: Span, + _new_decl_span: Span, + idents: Vec<Ident>, + errored: bool, + ) -> P<ast::Block> { + let new_decl_span = d_sig.span; + + // Just adding some default inline-asm and black_box usages to prevent early inlining + // and optimizations which alter the function signature. + // + // The bb_primal_call is the black_box call of the primal function. We keep it around, + // since it has the convenient property of returning the type of the primal function, + // Remember, we only care to match types here. + // No matter which return we pick, we always wrap it into a std::hint::black_box call, + // to prevent rustc from propagating it into the caller. + let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper( + ecx, + span, + primal, + new_names, + sig_span, + new_decl_span, + &idents, + errored, + ); + if !has_ret(&d_sig.decl.output) { // there is no return type that we have to match, () works fine. return body; @@ -442,7 +487,7 @@ mod llvm_enzyme { if primal_ret && n_active == 0 && x.mode.is_rev() { // We only have the primal ret. - body.stmts.push(ecx.stmt_expr(black_box_primal_call)); + body.stmts.push(ecx.stmt_expr(bb_primal_call)); return body; } @@ -534,11 +579,11 @@ mod llvm_enzyme { return body; } [arg] => { - ret = ecx.expr_call(new_decl_span, blackbox_call_expr, thin_vec![arg.clone()]); + ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]); } args => { let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into()); - ret = ecx.expr_call(new_decl_span, blackbox_call_expr, thin_vec![ret_tuple]); + ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]); } } assert!(has_ret(&d_sig.decl.output)); @@ -551,7 +596,7 @@ mod llvm_enzyme { ecx: &ExtCtxt<'_>, span: Span, primal: Ident, - idents: Vec<Ident>, + idents: &[Ident], ) -> P<ast::Expr> { let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; if has_self { |
