diff options
| author | bit-aloo <sshourya17@gmail.com> | 2025-04-20 15:54:12 +0530 |
|---|---|---|
| committer | bit-aloo <sshourya17@gmail.com> | 2025-04-28 21:10:32 +0530 |
| commit | 7018392337a938e25c9ed9190c4f0966737fffdb (patch) | |
| tree | 94fbc240f3d868fe229a4848381f76a8e52d80f5 | |
| parent | 9bc04016e6dffd6398ea62f05b9320e0198ab0be (diff) | |
| download | rust-7018392337a938e25c9ed9190c4f0966737fffdb.tar.gz rust-7018392337a938e25c9ed9190c4f0966737fffdb.zip | |
remove noinline attribute and add alwaysinline after AD pass
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/attributes.rs | 5 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/back/lto.rs | 28 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/context.rs | 10 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 8 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/llvm/mod.rs | 26 | ||||
| -rw-r--r-- | compiler/rustc_codegen_llvm/src/type_.rs | 4 | ||||
| -rw-r--r-- | compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp | 10 | ||||
| -rw-r--r-- | tests/codegen/autodiff/inline.rs | 23 |
8 files changed, 104 insertions, 10 deletions
diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index 545708384d9..176fb72dfdc 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -1,5 +1,4 @@ //! Set and unset common attributes on LLVM values. - use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_codegen_ssa::traits::*; use rustc_hir::def_id::DefId; @@ -32,7 +31,7 @@ pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) - llvm::HasAttributeAtIndex(llfn, idx, attr) } -pub(crate) fn has_string_attr(llfn: &Value, name: *const i8) -> bool { +pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool { llvm::HasStringAttribute(llfn, name) } @@ -40,7 +39,7 @@ pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: Attrib llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind); } -pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: *const i8) { +pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) { llvm::RemoveStringAttrFromFn(llfn, name); } diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 925898d8173..39b3a23e0b1 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -28,8 +28,9 @@ use crate::back::write::{ use crate::errors::{ DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro, }; +use crate::llvm::AttributePlace::Function; use crate::llvm::{self, build_string}; -use crate::{LlvmCodegenBackend, ModuleLlvm}; +use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes}; /// We keep track of the computed LTO cache keys from the previous /// session to determine which CGUs we can reuse. @@ -666,6 +667,31 @@ pub(crate) fn run_pass_manager( } if cfg!(llvm_enzyme) && enable_ad && !thin { + let cx = + SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size); + + for function in cx.get_functions() { + let enzyme_marker = "enzyme_marker"; + if attributes::has_string_attr(function, enzyme_marker) { + // Sanity check: Ensure 'noinline' is present before replacing it. + assert!( + !attributes::has_attr(function, Function, llvm::AttributeKind::NoInline), + "Expected __enzyme function to have 'noinline' before adding 'alwaysinline'" + ); + + attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline); + attributes::remove_string_attr_from_llfn(function, enzyme_marker); + + assert!( + !attributes::has_string_attr(function, enzyme_marker), + "Expected function to not have 'enzyme_marker'" + ); + + let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx); + attributes::apply_to_llfn(function, Function, &[always_inline]); + } + } + let opt_stage = llvm::OptStage::FatLTO; let stage = write::AutodiffStage::PostAD; if !config.autodiff.contains(&config::AutoDiff::NoPostopt) { diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 4ec69995518..ed50515b707 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -698,6 +698,16 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> { llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len()) }) } + + pub(crate) fn get_functions(&self) -> Vec<&'ll Value> { + let mut functions = vec![]; + let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) }; + while let Some(f) = func { + functions.push(f); + func = unsafe { llvm::LLVMGetNextFunction(f) } + } + functions + } } impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> { diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 08987a3ad14..2ad39fc8538 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -19,8 +19,12 @@ unsafe extern "C" { pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool; pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool; pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64; - pub(crate) fn LLVMRustHasFnAttribute(F: &Value, Name: *const c_char) -> bool; - pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char); + pub(crate) fn LLVMRustHasFnAttribute( + F: &Value, + Name: *const c_char, + NameLen: libc::size_t, + ) -> bool; + pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t); pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>; pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex( diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index 6ca81c651ed..d14aab06073 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -41,6 +41,32 @@ pub(crate) fn AddFunctionAttributes<'ll>( } } +pub(crate) fn HasAttributeAtIndex<'ll>( + llfn: &'ll Value, + idx: AttributePlace, + kind: AttributeKind, +) -> bool { + unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) } +} + +pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool { + unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) } +} + +pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) { + unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) } +} + +pub(crate) fn RemoveRustEnumAttributeAtIndex( + llfn: &Value, + place: AttributePlace, + kind: AttributeKind, +) { + unsafe { + LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind); + } +} + pub(crate) fn AddCallSiteAttributes<'ll>( callsite: &'ll Value, idx: AttributePlace, diff --git a/compiler/rustc_codegen_llvm/src/type_.rs b/compiler/rustc_codegen_llvm/src/type_.rs index b89ce90d1a1..169036f5152 100644 --- a/compiler/rustc_codegen_llvm/src/type_.rs +++ b/compiler/rustc_codegen_llvm/src/type_.rs @@ -128,6 +128,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> { (**self).borrow().llcx } + pub(crate) fn llmod(&self) -> &'ll llvm::Module { + (**self).borrow().llmod + } + pub(crate) fn isize_ty(&self) -> &'ll Type { (**self).borrow().isize_ty } diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 2871b3c0293..72369ab7b69 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -979,16 +979,18 @@ LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index, LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr)); } -extern "C" bool LLVMRustHasFnAttribute(LLVMValueRef F, const char *Name) { +extern "C" bool LLVMRustHasFnAttribute(LLVMValueRef F, const char *Name, + size_t NameLen) { if (auto *Fn = dyn_cast<Function>(unwrap<Value>(F))) { - return Fn->hasFnAttribute(Name); + return Fn->hasFnAttribute(StringRef(Name, NameLen)); } return false; } -extern "C" void LLVMRustRemoveFnAttribute(LLVMValueRef Fn, const char *Name) { +extern "C" void LLVMRustRemoveFnAttribute(LLVMValueRef Fn, const char *Name, + size_t NameLen) { if (auto *F = dyn_cast<Function>(unwrap<Value>(Fn))) { - F->removeFnAttr(Name); + F->removeFnAttr(StringRef(Name, NameLen)); } } diff --git a/tests/codegen/autodiff/inline.rs b/tests/codegen/autodiff/inline.rs new file mode 100644 index 00000000000..e90faa4aa38 --- /dev/null +++ b/tests/codegen/autodiff/inline.rs @@ -0,0 +1,23 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt +//@ no-prefer-dynamic +//@ needs-enzyme + +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[autodiff(d_square, Reverse, Duplicated, Active)] +fn square(x: &f64) -> f64 { + x * x +} + +// CHECK: ; inline::d_square +// CHECK-NEXT: ; Function Attrs: alwaysinline +// CHECK-NOT: noinline +// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE +fn main() { + let x = std::hint::black_box(3.0); + let mut dx1 = std::hint::black_box(1.0); + let _ = d_square(&x, &mut dx1, 1.0); + assert_eq!(dx1, 6.0); +} |
