diff options
| author | Chris Denton <chris@chrisdenton.dev> | 2025-04-28 23:29:14 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-28 23:29:14 +0000 |
| commit | d4845e1b0b745b4183096c363e3f4e71a45c14c9 (patch) | |
| tree | d7249a0e31be055f454f8338d5608fffe5f53239 /compiler/rustc_codegen_llvm/src | |
| parent | 25cdf1f67463c9365d8d83778c933ec7480e940b (diff) | |
| parent | 7018392337a938e25c9ed9190c4f0966737fffdb (diff) | |
| download | rust-d4845e1b0b745b4183096c363e3f4e71a45c14c9.tar.gz rust-d4845e1b0b745b4183096c363e3f4e71a45c14c9.zip | |
Rollup merge of #139308 - Shourya742:2025-03-29-add-autodiff-inline, r=ZuseZ4
add autodiff inline closes: #138920 r? ```@ZuseZ4``` try-job: dist-aarch64-linux
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/attributes.rs | 17 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/lto.rs | 28 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/builder/autodiff.rs | 5 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/context.rs | 10 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 13 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/mod.rs | 26 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/type_.rs | 4 |
7 files changed, 101 insertions, 2 deletions
diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index e8c42d16733..176fb72dfdc 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -1,5 +1,4 @@ //! Set and unset common attributes on LLVM values. - use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_codegen_ssa::traits::*; use rustc_hir::def_id::DefId; @@ -28,6 +27,22 @@ pub(crate) fn apply_to_callsite(callsite: &Value, idx: AttributePlace, attrs: &[ } } +pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) -> bool { + llvm::HasAttributeAtIndex(llfn, idx, attr) +} + +pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool { + llvm::HasStringAttribute(llfn, name) +} + +pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: AttributeKind) { + llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind); +} + +pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) { + llvm::RemoveStringAttrFromFn(llfn, name); +} + /// Get LLVM attribute for the provided inline heuristic. #[inline] fn inline_attr<'ll>(cx: &CodegenCx<'ll, '_>, inline: InlineAttr) -> Option<&'ll Attribute> { diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 925898d8173..39b3a23e0b1 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -28,8 +28,9 @@ use crate::back::write::{ use crate::errors::{ DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro, }; +use crate::llvm::AttributePlace::Function; use crate::llvm::{self, build_string}; -use crate::{LlvmCodegenBackend, ModuleLlvm}; +use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes}; /// We keep track of the computed LTO cache keys from the previous /// session to determine which CGUs we can reuse. @@ -666,6 +667,31 @@ pub(crate) fn run_pass_manager( } if cfg!(llvm_enzyme) && enable_ad && !thin { + let cx = + SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size); + + for function in cx.get_functions() { + let enzyme_marker = "enzyme_marker"; + if attributes::has_string_attr(function, enzyme_marker) { + // Sanity check: Ensure 'noinline' is present before replacing it. + assert!( + !attributes::has_attr(function, Function, llvm::AttributeKind::NoInline), + "Expected __enzyme function to have 'noinline' before adding 'alwaysinline'" + ); + + attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline); + attributes::remove_string_attr_from_llfn(function, enzyme_marker); + + assert!( + !attributes::has_string_attr(function, enzyme_marker), + "Expected function to not have 'enzyme_marker'" + ); + + let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx); + attributes::apply_to_llfn(function, Function, &[always_inline]); + } + } + let opt_stage = llvm::OptStage::FatLTO; let stage = write::AutodiffStage::PostAD; if !config.autodiff.contains(&config::AutoDiff::NoPostopt) { diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 0147bd5a665..c5c13ac097a 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -361,6 +361,11 @@ fn generate_enzyme_call<'ll>( let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); attributes::apply_to_llfn(ad_fn, Function, &[attr]); + // We add a made-up attribute just such that we can recognize it after AD to update + // (no)-inline attributes. We'll then also remove this attribute. + let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker"); + attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]); + // first, remove all calls from fnc let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); let br = llvm::LLVMRustGetTerminator(entry); diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 4ec69995518..ed50515b707 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -698,6 +698,16 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> { llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len()) }) } + + pub(crate) fn get_functions(&self) -> Vec<&'ll Value> { + let mut functions = vec![]; + let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) }; + while let Some(f) = func { + functions.push(f); + func = unsafe { llvm::LLVMGetNextFunction(f) } + } + functions + } } impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> { diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index a9b3bdf7344..2ad39fc8538 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -19,6 +19,19 @@ unsafe extern "C" { pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool; pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool; pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64; + pub(crate) fn LLVMRustHasFnAttribute( + F: &Value, + Name: *const c_char, + NameLen: libc::size_t, + ) -> bool; + pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t); + pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>; + pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex( + Fn: &Value, + index: c_uint, + kind: AttributeKind, + ); } unsafe extern "C" { diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index 6ca81c651ed..d14aab06073 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -41,6 +41,32 @@ pub(crate) fn AddFunctionAttributes<'ll>( } } +pub(crate) fn HasAttributeAtIndex<'ll>( + llfn: &'ll Value, + idx: AttributePlace, + kind: AttributeKind, +) -> bool { + unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) } +} + +pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool { + unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) } +} + +pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) { + unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) } +} + +pub(crate) fn RemoveRustEnumAttributeAtIndex( + llfn: &Value, + place: AttributePlace, + kind: AttributeKind, +) { + unsafe { + LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind); + } +} + pub(crate) fn AddCallSiteAttributes<'ll>( callsite: &'ll Value, idx: AttributePlace, diff --git a/compiler/rustc_codegen_llvm/src/type_.rs b/compiler/rustc_codegen_llvm/src/type_.rs index b89ce90d1a1..169036f5152 100644 --- a/compiler/rustc_codegen_llvm/src/type_.rs +++ b/compiler/rustc_codegen_llvm/src/type_.rs @@ -128,6 +128,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> { (**self).borrow().llcx } + pub(crate) fn llmod(&self) -> &'ll llvm::Module { + (**self).borrow().llmod + } + pub(crate) fn isize_ty(&self) -> &'ll Type { (**self).borrow().isize_ty } |
