about summary refs log tree commit diff
path: root/compiler
diff options
context:
space:
mode:
authorbit-aloo <sshourya17@gmail.com>2025-04-20 15:54:12 +0530
committerbit-aloo <sshourya17@gmail.com>2025-04-28 21:10:32 +0530
commit7018392337a938e25c9ed9190c4f0966737fffdb (patch)
tree94fbc240f3d868fe229a4848381f76a8e52d80f5 /compiler
parent9bc04016e6dffd6398ea62f05b9320e0198ab0be (diff)
downloadrust-7018392337a938e25c9ed9190c4f0966737fffdb.tar.gz
rust-7018392337a938e25c9ed9190c4f0966737fffdb.zip
remove noinline attribute and add alwaysinline after AD pass
Diffstat (limited to 'compiler')
-rw-r--r--compiler/rustc_codegen_llvm/src/attributes.rs5
-rw-r--r--compiler/rustc_codegen_llvm/src/back/lto.rs28
-rw-r--r--compiler/rustc_codegen_llvm/src/context.rs10
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs8
-rw-r--r--compiler/rustc_codegen_llvm/src/llvm/mod.rs26
-rw-r--r--compiler/rustc_codegen_llvm/src/type_.rs4
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp10
7 files changed, 81 insertions, 10 deletions
diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs
index 545708384d9..176fb72dfdc 100644
--- a/compiler/rustc_codegen_llvm/src/attributes.rs
+++ b/compiler/rustc_codegen_llvm/src/attributes.rs
@@ -1,5 +1,4 @@
 //! Set and unset common attributes on LLVM values.
-
 use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
 use rustc_codegen_ssa::traits::*;
 use rustc_hir::def_id::DefId;
@@ -32,7 +31,7 @@ pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) -
     llvm::HasAttributeAtIndex(llfn, idx, attr)
 }
 
-pub(crate) fn has_string_attr(llfn: &Value, name: *const i8) -> bool {
+pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool {
     llvm::HasStringAttribute(llfn, name)
 }
 
@@ -40,7 +39,7 @@ pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: Attrib
     llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind);
 }
 
-pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: *const i8) {
+pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) {
     llvm::RemoveStringAttrFromFn(llfn, name);
 }
 
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index 925898d8173..39b3a23e0b1 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -28,8 +28,9 @@ use crate::back::write::{
 use crate::errors::{
     DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro,
 };
+use crate::llvm::AttributePlace::Function;
 use crate::llvm::{self, build_string};
-use crate::{LlvmCodegenBackend, ModuleLlvm};
+use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};
 
 /// We keep track of the computed LTO cache keys from the previous
 /// session to determine which CGUs we can reuse.
@@ -666,6 +667,31 @@ pub(crate) fn run_pass_manager(
     }
 
     if cfg!(llvm_enzyme) && enable_ad && !thin {
+        let cx =
+            SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
+
+        for function in cx.get_functions() {
+            let enzyme_marker = "enzyme_marker";
+            if attributes::has_string_attr(function, enzyme_marker) {
+                // Sanity check: Ensure 'noinline' is present before replacing it.
+                assert!(
+                    !attributes::has_attr(function, Function, llvm::AttributeKind::NoInline),
+                    "Expected __enzyme function to have 'noinline' before adding 'alwaysinline'"
+                );
+
+                attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
+                attributes::remove_string_attr_from_llfn(function, enzyme_marker);
+
+                assert!(
+                    !attributes::has_string_attr(function, enzyme_marker),
+                    "Expected function to not have 'enzyme_marker'"
+                );
+
+                let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx);
+                attributes::apply_to_llfn(function, Function, &[always_inline]);
+            }
+        }
+
         let opt_stage = llvm::OptStage::FatLTO;
         let stage = write::AutodiffStage::PostAD;
         if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs
index 4ec69995518..ed50515b707 100644
--- a/compiler/rustc_codegen_llvm/src/context.rs
+++ b/compiler/rustc_codegen_llvm/src/context.rs
@@ -698,6 +698,16 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
             llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len())
         })
     }
+
+    pub(crate) fn get_functions(&self) -> Vec<&'ll Value> {
+        let mut functions = vec![];
+        let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) };
+        while let Some(f) = func {
+            functions.push(f);
+            func = unsafe { llvm::LLVMGetNextFunction(f) }
+        }
+        functions
+    }
 }
 
 impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index 08987a3ad14..2ad39fc8538 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -19,8 +19,12 @@ unsafe extern "C" {
     pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
     pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
     pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
-    pub(crate) fn LLVMRustHasFnAttribute(F: &Value, Name: *const c_char) -> bool;
-    pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char);
+    pub(crate) fn LLVMRustHasFnAttribute(
+        F: &Value,
+        Name: *const c_char,
+        NameLen: libc::size_t,
+    ) -> bool;
+    pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t);
     pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
     pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>;
     pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex(
diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs
index 6ca81c651ed..d14aab06073 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs
@@ -41,6 +41,32 @@ pub(crate) fn AddFunctionAttributes<'ll>(
     }
 }
 
+pub(crate) fn HasAttributeAtIndex<'ll>(
+    llfn: &'ll Value,
+    idx: AttributePlace,
+    kind: AttributeKind,
+) -> bool {
+    unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) }
+}
+
+pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool {
+    unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
+}
+
+pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) {
+    unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
+}
+
+pub(crate) fn RemoveRustEnumAttributeAtIndex(
+    llfn: &Value,
+    place: AttributePlace,
+    kind: AttributeKind,
+) {
+    unsafe {
+        LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind);
+    }
+}
+
 pub(crate) fn AddCallSiteAttributes<'ll>(
     callsite: &'ll Value,
     idx: AttributePlace,
diff --git a/compiler/rustc_codegen_llvm/src/type_.rs b/compiler/rustc_codegen_llvm/src/type_.rs
index b89ce90d1a1..169036f5152 100644
--- a/compiler/rustc_codegen_llvm/src/type_.rs
+++ b/compiler/rustc_codegen_llvm/src/type_.rs
@@ -128,6 +128,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
         (**self).borrow().llcx
     }
 
+    pub(crate) fn llmod(&self) -> &'ll llvm::Module {
+        (**self).borrow().llmod
+    }
+
     pub(crate) fn isize_ty(&self) -> &'ll Type {
         (**self).borrow().isize_ty
     }
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index 2871b3c0293..72369ab7b69 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -979,16 +979,18 @@ LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index,
   LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr));
 }
 
-extern "C" bool LLVMRustHasFnAttribute(LLVMValueRef F, const char *Name) {
+extern "C" bool LLVMRustHasFnAttribute(LLVMValueRef F, const char *Name,
+                                       size_t NameLen) {
   if (auto *Fn = dyn_cast<Function>(unwrap<Value>(F))) {
-    return Fn->hasFnAttribute(Name);
+    return Fn->hasFnAttribute(StringRef(Name, NameLen));
   }
   return false;
 }
 
-extern "C" void LLVMRustRemoveFnAttribute(LLVMValueRef Fn, const char *Name) {
+extern "C" void LLVMRustRemoveFnAttribute(LLVMValueRef Fn, const char *Name,
+                                          size_t NameLen) {
   if (auto *F = dyn_cast<Function>(unwrap<Value>(Fn))) {
-    F->removeFnAttr(Name);
+    F->removeFnAttr(StringRef(Name, NameLen));
   }
 }