about summary refs log tree commit diff
path: root/compiler/rustc_codegen_llvm/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_codegen_llvm/src')
-rw-r--r-- compiler/rustc_codegen_llvm/src/attributes.rs 17
-rw-r--r-- compiler/rustc_codegen_llvm/src/back/lto.rs 28
-rw-r--r-- compiler/rustc_codegen_llvm/src/builder/autodiff.rs 5
-rw-r--r-- compiler/rustc_codegen_llvm/src/context.rs 10
-rw-r--r-- compiler/rustc_codegen_llvm/src/lib.rs 8
-rw-r--r-- compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs 13
-rw-r--r-- compiler/rustc_codegen_llvm/src/llvm/mod.rs 26
-rw-r--r-- compiler/rustc_codegen_llvm/src/llvm_util.rs 85
-rw-r--r-- compiler/rustc_codegen_llvm/src/type_.rs 4
9 files changed, 187 insertions, 9 deletions
diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs
index e8c42d16733..176fb72dfdc 100644
--- a/compiler/rustc_codegen_llvm/src/attributes.rs
+++ b/compiler/rustc_codegen_llvm/src/attributes.rs
@@ -1,5 +1,4 @@
 //! Set and unset common attributes on LLVM values.
-
 use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
 use rustc_codegen_ssa::traits::*;
 use rustc_hir::def_id::DefId;
@@ -28,6 +27,22 @@ pub(crate) fn apply_to_callsite(callsite: &Value, idx: AttributePlace, attrs: &[
     }
 }
 
+pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) -> bool {
+    llvm::HasAttributeAtIndex(llfn, idx, attr)
+}
+
+pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool {
+    llvm::HasStringAttribute(llfn, name)
+}
+
+pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: AttributeKind) {
+    llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind);
+}
+
+pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) {
+    llvm::RemoveStringAttrFromFn(llfn, name);
+}
+
 /// Get LLVM attribute for the provided inline heuristic.
 #[inline]
 fn inline_attr<'ll>(cx: &CodegenCx<'ll, '_>, inline: InlineAttr) -> Option<&'ll Attribute> {
diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs
index 925898d8173..39b3a23e0b1 100644
--- a/compiler/rustc_codegen_llvm/src/back/lto.rs
+++ b/compiler/rustc_codegen_llvm/src/back/lto.rs
@@ -28,8 +28,9 @@ use crate::back::write::{
 use crate::errors::{
     DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro,
 };
+use crate::llvm::AttributePlace::Function;
 use crate::llvm::{self, build_string};
-use crate::{LlvmCodegenBackend, ModuleLlvm};
+use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};
 
 /// We keep track of the computed LTO cache keys from the previous
 /// session to determine which CGUs we can reuse.
@@ -666,6 +667,31 @@ pub(crate) fn run_pass_manager(
     }
 
     if cfg!(llvm_enzyme) && enable_ad && !thin {
+        let cx =
+            SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
+
+        for function in cx.get_functions() {
+            let enzyme_marker = "enzyme_marker";
+            if attributes::has_string_attr(function, enzyme_marker) {
+                // Sanity check 'noinline' before replacing it. NOTE(review): assert checks absence.
+                assert!(
+                    !attributes::has_attr(function, Function, llvm::AttributeKind::NoInline),
+                    "Expected __enzyme function to have 'noinline' before adding 'alwaysinline'"
+                );
+
+                attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
+                attributes::remove_string_attr_from_llfn(function, enzyme_marker);
+
+                assert!(
+                    !attributes::has_string_attr(function, enzyme_marker),
+                    "Expected function to not have 'enzyme_marker'"
+                );
+
+                let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx);
+                attributes::apply_to_llfn(function, Function, &[always_inline]);
+            }
+        }
+
         let opt_stage = llvm::OptStage::FatLTO;
         let stage = write::AutodiffStage::PostAD;
         if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
index 0147bd5a665..c5c13ac097a 100644
--- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
+++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs
@@ -361,6 +361,11 @@ fn generate_enzyme_call<'ll>(
         let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
         attributes::apply_to_llfn(ad_fn, Function, &[attr]);
 
+        // We add a made-up attribute solely so that we can recognize this function after AD
+        // and update its (no)-inline attributes. We'll then also remove this attribute.
+        let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
+        attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);
+
         // first, remove all calls from fnc
         let entry = llvm::LLVMGetFirstBasicBlock(outer_fn);
         let br = llvm::LLVMRustGetTerminator(entry);
diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs
index 4ec69995518..ed50515b707 100644
--- a/compiler/rustc_codegen_llvm/src/context.rs
+++ b/compiler/rustc_codegen_llvm/src/context.rs
@@ -698,6 +698,16 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
             llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len())
         })
     }
+
+    pub(crate) fn get_functions(&self) -> Vec<&'ll Value> {
+        let mut functions = vec![];
+        let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) };
+        while let Some(f) = func {
+            functions.push(f);
+            func = unsafe { llvm::LLVMGetNextFunction(f) }
+        }
+        functions
+    }
 }
 
 impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {
diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs
index b2feeacdb46..e8010ec9fc4 100644
--- a/compiler/rustc_codegen_llvm/src/lib.rs
+++ b/compiler/rustc_codegen_llvm/src/lib.rs
@@ -29,7 +29,7 @@ use back::owned_target_machine::OwnedTargetMachine;
 use back::write::{create_informational_target_machine, create_target_machine};
 use context::SimpleCx;
 use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig};
-use llvm_util::target_features_cfg;
+use llvm_util::target_config;
 use rustc_ast::expand::allocator::AllocatorKind;
 use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
 use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
@@ -37,7 +37,7 @@ use rustc_codegen_ssa::back::write::{
     CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn,
 };
 use rustc_codegen_ssa::traits::*;
-use rustc_codegen_ssa::{CodegenResults, CompiledModule, ModuleCodegen};
+use rustc_codegen_ssa::{CodegenResults, CompiledModule, ModuleCodegen, TargetConfig};
 use rustc_data_structures::fx::FxIndexMap;
 use rustc_errors::{DiagCtxtHandle, FatalError};
 use rustc_metadata::EncodedMetadata;
@@ -338,8 +338,8 @@ impl CodegenBackend for LlvmCodegenBackend {
         llvm_util::print_version();
     }
 
-    fn target_features_cfg(&self, sess: &Session) -> (Vec<Symbol>, Vec<Symbol>) {
-        target_features_cfg(sess)
+    fn target_config(&self, sess: &Session) -> TargetConfig {
+        target_config(sess)
     }
 
     fn codegen_crate<'tcx>(
diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
index a9b3bdf7344..2ad39fc8538 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
@@ -19,6 +19,19 @@ unsafe extern "C" {
     pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
     pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
     pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
+    pub(crate) fn LLVMRustHasFnAttribute(
+        F: &Value,
+        Name: *const c_char,
+        NameLen: libc::size_t,
+    ) -> bool;
+    pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t);
+    pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
+    pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>;
+    pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex(
+        Fn: &Value,
+        index: c_uint,
+        kind: AttributeKind,
+    );
 }
 
 unsafe extern "C" {
diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs
index 6ca81c651ed..d14aab06073 100644
--- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs
@@ -41,6 +41,32 @@ pub(crate) fn AddFunctionAttributes<'ll>(
     }
 }
 
+pub(crate) fn HasAttributeAtIndex<'ll>(
+    llfn: &'ll Value,
+    idx: AttributePlace,
+    kind: AttributeKind,
+) -> bool {
+    unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) }
+}
+
+pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool {
+    unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
+}
+
+pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) {
+    unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
+}
+
+pub(crate) fn RemoveRustEnumAttributeAtIndex(
+    llfn: &Value,
+    place: AttributePlace,
+    kind: AttributeKind,
+) {
+    unsafe {
+        LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind);
+    }
+}
+
 pub(crate) fn AddCallSiteAttributes<'ll>(
     callsite: &'ll Value,
     idx: AttributePlace,
diff --git a/compiler/rustc_codegen_llvm/src/llvm_util.rs b/compiler/rustc_codegen_llvm/src/llvm_util.rs
index 36e35f81392..6412a537a79 100644
--- a/compiler/rustc_codegen_llvm/src/llvm_util.rs
+++ b/compiler/rustc_codegen_llvm/src/llvm_util.rs
@@ -6,6 +6,7 @@ use std::sync::Once;
 use std::{ptr, slice, str};
 
 use libc::c_int;
+use rustc_codegen_ssa::TargetConfig;
 use rustc_codegen_ssa::base::wants_wasm_eh;
 use rustc_codegen_ssa::codegen_attrs::check_tied_features;
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
@@ -302,7 +303,7 @@ pub(crate) fn to_llvm_features<'a>(sess: &Session, s: &'a str) -> Option<LLVMFea
 /// Must express features in the way Rust understands them.
 ///
 /// We do not have to worry about RUSTC_SPECIFIC_FEATURES here, those are handled outside codegen.
-pub(crate) fn target_features_cfg(sess: &Session) -> (Vec<Symbol>, Vec<Symbol>) {
+pub(crate) fn target_config(sess: &Session) -> TargetConfig {
     // Add base features for the target.
     // We do *not* add the -Ctarget-features there, and instead duplicate the logic for that below.
     // The reason is that if LLVM considers a feature implied but we do not, we don't want that to
@@ -402,7 +403,85 @@ pub(crate) fn target_features_cfg(sess: &Session) -> (Vec<Symbol>, Vec<Symbol>)
 
     let target_features = f(false);
     let unstable_target_features = f(true);
-    (target_features, unstable_target_features)
+    let mut cfg = TargetConfig {
+        target_features,
+        unstable_target_features,
+        has_reliable_f16: true,
+        has_reliable_f16_math: true,
+        has_reliable_f128: true,
+        has_reliable_f128_math: true,
+    };
+
+    update_target_reliable_float_cfg(sess, &mut cfg);
+    cfg
+}
+
+/// Determine whether or not experimental float types are reliable based on known bugs.
+fn update_target_reliable_float_cfg(sess: &Session, cfg: &mut TargetConfig) {
+    let target_arch = sess.target.arch.as_ref();
+    let target_os = sess.target.options.os.as_ref();
+    let target_env = sess.target.options.env.as_ref();
+    let target_abi = sess.target.options.abi.as_ref();
+    let target_pointer_width = sess.target.pointer_width;
+
+    cfg.has_reliable_f16 = match (target_arch, target_os) {
+        // Selection failure <https://github.com/llvm/llvm-project/issues/50374>
+        ("s390x", _) => false,
+        // Unsupported <https://github.com/llvm/llvm-project/issues/94434>
+        ("arm64ec", _) => false,
+        // MinGW ABI bugs <https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115054>
+        ("x86_64", "windows") if target_env == "gnu" && target_abi != "llvm" => false,
+        // Infinite recursion <https://github.com/llvm/llvm-project/issues/97981>
+        ("csky", _) => false,
+        ("hexagon", _) => false,
+        ("powerpc" | "powerpc64", _) => false,
+        ("sparc" | "sparc64", _) => false,
+        ("wasm32" | "wasm64", _) => false,
+        // `f16` support only requires that symbols converting to and from `f32` are available. We
+        // provide these in `compiler-builtins`, so `f16` should be available on all platforms that
+        // do not have other ABI issues or LLVM crashes.
+        _ => true,
+    };
+
+    cfg.has_reliable_f128 = match (target_arch, target_os) {
+        // Unsupported <https://github.com/llvm/llvm-project/issues/94434>
+        ("arm64ec", _) => false,
+        // Selection bug <https://github.com/llvm/llvm-project/issues/96432>
+        ("mips64" | "mips64r6", _) => false,
+        // Selection bug <https://github.com/llvm/llvm-project/issues/95471>
+        ("nvptx64", _) => false,
+        // ABI bugs <https://github.com/rust-lang/rust/issues/125109> et al. (full
+        // list at <https://github.com/rust-lang/rust/issues/116909>)
+        ("powerpc" | "powerpc64", _) => false,
+        // ABI unsupported <https://github.com/llvm/llvm-project/issues/41838>
+        ("sparc", _) => false,
+        // Stack alignment bug <https://github.com/llvm/llvm-project/issues/77401>. NB: tests may
+        // not fail if our compiler-builtins is linked.
+        ("x86", _) => false,
+        // MinGW ABI bugs <https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115054>
+        ("x86_64", "windows") if target_env == "gnu" && target_abi != "llvm" => false,
+        // There are no known problems on other platforms, so the only requirement is that symbols
+        // are available. `compiler-builtins` provides all symbols required for core `f128`
+        // support, so this should work for everything else.
+        _ => true,
+    };
+
+    // Assume that working `f16` means working `f16` math for most platforms, since
+    // operations just go through `f32`.
+    cfg.has_reliable_f16_math = cfg.has_reliable_f16;
+
+    cfg.has_reliable_f128_math = match (target_arch, target_os) {
+        // LLVM lowers `fp128` math to `long double` symbols even on platforms where
+        // `long double` is not IEEE binary128. See
+        // <https://github.com/llvm/llvm-project/issues/44744>.
+        //
+        // This rules out anything that doesn't have `long double` = `binary128`: targets of
+        // <= 32 bits (ld is `f64`), anything other than Linux (Windows and macOS use `f64`),
+        // and `x86` (ld is 80-bit extended precision).
+        ("x86_64", _) => false,
+        (_, "linux") if target_pointer_width == 64 => true,
+        _ => false,
+    } && cfg.has_reliable_f128;
 }
 
 pub(crate) fn print_version() {
@@ -686,7 +765,7 @@ pub(crate) fn global_llvm_features(
                 )
             } else if let Some(feature) = feature.strip_prefix('-') {
                 // FIXME: Why do we not remove implied features on "-" here?
-                // We do the equivalent above in `target_features_cfg`.
+                // We do the equivalent above in `target_config`.
                 // See <https://github.com/rust-lang/rust/issues/134792>.
                 all_rust_features.push((false, feature));
             } else if !feature.is_empty() {
diff --git a/compiler/rustc_codegen_llvm/src/type_.rs b/compiler/rustc_codegen_llvm/src/type_.rs
index b89ce90d1a1..169036f5152 100644
--- a/compiler/rustc_codegen_llvm/src/type_.rs
+++ b/compiler/rustc_codegen_llvm/src/type_.rs
@@ -128,6 +128,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
         (**self).borrow().llcx
     }
 
+    pub(crate) fn llmod(&self) -> &'ll llvm::Module {
+        (**self).borrow().llmod
+    }
+
     pub(crate) fn isize_ty(&self) -> &'ll Type {
         (**self).borrow().isize_ty
     }