about summary refs log tree commit diff
path: root/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp')
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp175
1 files changed, 138 insertions, 37 deletions
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index 7326f2e8e2a..37c2da4c23a 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -10,6 +10,7 @@
 #include "llvm/IR/IntrinsicsARM.h"
 #include "llvm/IR/LLVMRemarkStreamer.h"
 #include "llvm/IR/Mangler.h"
+#include "llvm/IR/Value.h"
 #include "llvm/Remarks/RemarkStreamer.h"
 #include "llvm/Remarks/RemarkSerializer.h"
 #include "llvm/Remarks/RemarkFormat.h"
@@ -24,6 +25,13 @@
 
 #include <iostream>
 
+// for raw `write` in the bad-alloc handler
+#ifdef _MSC_VER
+#include <io.h>
+#else
+#include <unistd.h>
+#endif
+
 //===----------------------------------------------------------------------===
 //
 // This file defines alternate interfaces to core functions that are more
@@ -66,19 +74,45 @@ static LLVM_THREAD_LOCAL char *LastError;
 static void FatalErrorHandler(void *UserData,
                               const char* Reason,
                               bool GenCrashDiag) {
-  // Do the same thing that the default error handler does.
-  std::cerr << "LLVM ERROR: " << Reason << std::endl;
+  // Once upon a time we emitted "LLVM ERROR:" specifically to mimic LLVM. Then,
+  // we developed crater and other tools which only expose logs, not error codes.
+  // Use a more greppable prefix that will still match the "LLVM ERROR:" prefix.
+  std::cerr << "rustc-LLVM ERROR: " << Reason << std::endl;
 
   // Since this error handler exits the process, we have to run any cleanup that
   // LLVM would run after handling the error. This might change with an LLVM
   // upgrade.
+  //
+  // In practice, this will do nothing, because the only cleanup LLVM does is
+  // to remove all files that were registered with it via a frontend calling
+  // one of the `createOutputFile` family of functions in LLVM and passing true
+  // to RemoveFileOnSignal, something that rustc does not do. However, it would
+  // be... inadvisable to suddenly stop running these handlers, if LLVM gets
+  // "interesting" ideas in the future about what cleanup should be done.
+  // We might even find it useful for generating less artifacts.
   sys::RunInterruptHandlers();
 
   exit(101);
 }
 
-extern "C" void LLVMRustInstallFatalErrorHandler() {
+// Custom error handler for bad-alloc LLVM errors.
+//
+// It aborts the process without any further allocations, similar to LLVM's
+// default except that may be configured to `throw std::bad_alloc()` instead.
+static void BadAllocErrorHandler(void *UserData,
+                                 const char* Reason,
+                                 bool GenCrashDiag) {
+  const char *OOM = "rustc-LLVM ERROR: out of memory\n";
+  (void)!::write(2, OOM, strlen(OOM));
+  (void)!::write(2, Reason, strlen(Reason));
+  (void)!::write(2, "\n", 1);
+  abort();
+}
+
+extern "C" void LLVMRustInstallErrorHandlers() {
+  install_bad_alloc_error_handler(BadAllocErrorHandler);
   install_fatal_error_handler(FatalErrorHandler);
+  install_out_of_memory_new_handler();
 }
 
 extern "C" void LLVMRustDisableSystemDialogsOnCrash() {
@@ -109,7 +143,7 @@ extern "C" void LLVMRustSetNormalizedTarget(LLVMModuleRef M,
 
 extern "C" const char *LLVMRustPrintPassTimings(size_t *Len) {
   std::string buf;
-  raw_string_ostream SS(buf);
+  auto SS = raw_string_ostream(buf);
   TimerGroup::printAll(SS);
   SS.flush();
   *Len = buf.length();
@@ -120,7 +154,7 @@ extern "C" const char *LLVMRustPrintPassTimings(size_t *Len) {
 
 extern "C" const char *LLVMRustPrintStatistics(size_t *Len) {
   std::string buf;
-  raw_string_ostream SS(buf);
+  auto SS = raw_string_ostream(buf);
   llvm::PrintStatistics(SS);
   SS.flush();
   *Len = buf.length();
@@ -174,7 +208,7 @@ extern "C" LLVMValueRef LLVMRustGetOrInsertFunction(LLVMModuleRef M,
 extern "C" LLVMValueRef
 LLVMRustGetOrInsertGlobal(LLVMModuleRef M, const char *Name, size_t NameLen, LLVMTypeRef Ty) {
   Module *Mod = unwrap(M);
-  StringRef NameRef(Name, NameLen);
+  auto NameRef = StringRef(Name, NameLen);
 
   // We don't use Module::getOrInsertGlobal because that returns a Constant*,
   // which may either be the real GlobalVariable*, or a constant bitcast of it
@@ -285,7 +319,7 @@ static Attribute::AttrKind fromRust(LLVMRustAttribute Kind) {
 template<typename T> static inline void AddAttributes(T *t, unsigned Index,
                                                       LLVMAttributeRef *Attrs, size_t AttrsLen) {
   AttributeList PAL = t->getAttributes();
-  AttrBuilder B(t->getContext());
+  auto B = AttrBuilder(t->getContext());
   for (LLVMAttributeRef Attr : ArrayRef<LLVMAttributeRef>(Attrs, AttrsLen))
     B.addAttribute(unwrap(Attr));
   AttributeList PALNew = PAL.addAttributesAtIndex(t->getContext(), Index, B);
@@ -450,6 +484,20 @@ extern "C" void LLVMRustSetAlgebraicMath(LLVMValueRef V) {
   }
 }
 
+// Enable the reassoc fast-math flag, allowing transformations that pretend
+// floating-point addition and multiplication are associative.
+//
+// Note that this does NOT enable any flags which can cause a floating-point operation on
+// well-defined inputs to return poison, and therefore this function can be used to build
+// safe Rust intrinsics (such as fadd_algebraic).
+//
+// https://llvm.org/docs/LangRef.html#fast-math-flags
+extern "C" void LLVMRustSetAllowReassoc(LLVMValueRef V) {
+  if (auto I = dyn_cast<Instruction>(unwrap<Value>(V))) {
+    I->setHasAllowReassoc(true);
+  }
+}
+
 extern "C" LLVMValueRef
 LLVMRustBuildAtomicLoad(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Source,
                         const char *Name, LLVMAtomicOrdering Order) {
@@ -769,7 +817,7 @@ extern "C" uint32_t LLVMRustVersionMinor() { return LLVM_VERSION_MINOR; }
 
 extern "C" uint32_t LLVMRustVersionMajor() { return LLVM_VERSION_MAJOR; }
 
-extern "C" void LLVMRustAddModuleFlag(
+extern "C" void LLVMRustAddModuleFlagU32(
     LLVMModuleRef M,
     Module::ModFlagBehavior MergeBehavior,
     const char *Name,
@@ -777,6 +825,16 @@ extern "C" void LLVMRustAddModuleFlag(
   unwrap(M)->addModuleFlag(MergeBehavior, Name, Value);
 }
 
+extern "C" void LLVMRustAddModuleFlagString(
+    LLVMModuleRef M,
+    Module::ModFlagBehavior MergeBehavior,
+    const char *Name,
+    const char *Value,
+    size_t ValueLen) {
+  unwrap(M)->addModuleFlag(MergeBehavior, Name,
+    MDString::get(unwrap(M)->getContext(), StringRef(Value, ValueLen)));
+}
+
 extern "C" bool LLVMRustHasModuleFlag(LLVMModuleRef M, const char *Name,
                                       size_t Len) {
   return unwrap(M)->getModuleFlag(StringRef(Name, Len)) != nullptr;
@@ -1087,11 +1145,16 @@ extern "C" LLVMValueRef LLVMRustDIBuilderInsertDeclareAtEnd(
     LLVMRustDIBuilderRef Builder, LLVMValueRef V, LLVMMetadataRef VarInfo,
     uint64_t *AddrOps, unsigned AddrOpsCount, LLVMMetadataRef DL,
     LLVMBasicBlockRef InsertAtEnd) {
-  return wrap(Builder->insertDeclare(
+  auto Result = Builder->insertDeclare(
       unwrap(V), unwrap<DILocalVariable>(VarInfo),
       Builder->createExpression(llvm::ArrayRef<uint64_t>(AddrOps, AddrOpsCount)),
       DebugLoc(cast<MDNode>(unwrap(DL))),
-      unwrap(InsertAtEnd)));
+      unwrap(InsertAtEnd));
+#if LLVM_VERSION_GE(19, 0)
+  return wrap(Result.get<llvm::Instruction *>());
+#else
+  return wrap(Result);
+#endif
 }
 
 extern "C" LLVMMetadataRef LLVMRustDIBuilderCreateEnumerator(
@@ -1181,13 +1244,13 @@ extern "C" int64_t LLVMRustDIBuilderCreateOpLLVMFragment() {
 }
 
 extern "C" void LLVMRustWriteTypeToString(LLVMTypeRef Ty, RustStringRef Str) {
-  RawRustStringOstream OS(Str);
+  auto OS = RawRustStringOstream(Str);
   unwrap<llvm::Type>(Ty)->print(OS);
 }
 
 extern "C" void LLVMRustWriteValueToString(LLVMValueRef V,
                                            RustStringRef Str) {
-  RawRustStringOstream OS(Str);
+  auto OS = RawRustStringOstream(Str);
   if (!V) {
     OS << "(null)";
   } else {
@@ -1199,18 +1262,10 @@ extern "C" void LLVMRustWriteValueToString(LLVMValueRef V,
   }
 }
 
-// LLVMArrayType function does not support 64-bit ElementCount
-// FIXME: replace with LLVMArrayType2 when bumped minimal version to llvm-17
-// https://github.com/llvm/llvm-project/commit/35276f16e5a2cae0dfb49c0fbf874d4d2f177acc
-extern "C" LLVMTypeRef LLVMRustArrayType(LLVMTypeRef ElementTy,
-                                         uint64_t ElementCount) {
-  return wrap(ArrayType::get(unwrap(ElementTy), ElementCount));
-}
-
 DEFINE_SIMPLE_CONVERSION_FUNCTIONS(Twine, LLVMTwineRef)
 
 extern "C" void LLVMRustWriteTwineToString(LLVMTwineRef T, RustStringRef Str) {
-  RawRustStringOstream OS(Str);
+  auto OS = RawRustStringOstream(Str);
   unwrap(T)->print(OS);
 }
 
@@ -1222,11 +1277,11 @@ extern "C" void LLVMRustUnpackOptimizationDiagnostic(
   llvm::DiagnosticInfoOptimizationBase *Opt =
       static_cast<llvm::DiagnosticInfoOptimizationBase *>(unwrap(DI));
 
-  RawRustStringOstream PassNameOS(PassNameOut);
+  auto PassNameOS = RawRustStringOstream(PassNameOut);
   PassNameOS << Opt->getPassName();
   *FunctionOut = wrap(&Opt->getFunction());
 
-  RawRustStringOstream FilenameOS(FilenameOut);
+  auto FilenameOS = RawRustStringOstream(FilenameOut);
   DiagnosticLocation loc = Opt->getLocation();
   if (loc.isValid()) {
     *Line = loc.getLine();
@@ -1234,7 +1289,7 @@ extern "C" void LLVMRustUnpackOptimizationDiagnostic(
     FilenameOS << loc.getAbsolutePath();
   }
 
-  RawRustStringOstream MessageOS(MessageOut);
+  auto MessageOS = RawRustStringOstream(MessageOut);
   MessageOS << Opt->getMsg();
 }
 
@@ -1248,7 +1303,7 @@ enum class LLVMRustDiagnosticLevel {
 extern "C" void
 LLVMRustUnpackInlineAsmDiagnostic(LLVMDiagnosticInfoRef DI,
                                   LLVMRustDiagnosticLevel *LevelOut,
-                                  unsigned *CookieOut,
+                                  uint64_t *CookieOut,
                                   LLVMTwineRef *MessageOut) {
   // Undefined to call this not on an inline assembly diagnostic!
   llvm::DiagnosticInfoInlineAsm *IA =
@@ -1277,8 +1332,8 @@ LLVMRustUnpackInlineAsmDiagnostic(LLVMDiagnosticInfoRef DI,
 
 extern "C" void LLVMRustWriteDiagnosticInfoToString(LLVMDiagnosticInfoRef DI,
                                                     RustStringRef Str) {
-  RawRustStringOstream OS(Str);
-  DiagnosticPrinterRawOStream DP(OS);
+  auto OS = RawRustStringOstream(Str);
+  auto DP = DiagnosticPrinterRawOStream(OS);
   unwrap(DI)->print(DP);
 }
 
@@ -1392,7 +1447,7 @@ extern "C" LLVMTypeKind LLVMRustGetTypeKind(LLVMTypeRef Ty) {
   default:
     {
       std::string error;
-      llvm::raw_string_ostream stream(error);
+      auto stream = llvm::raw_string_ostream(error);
       stream << "Rust does not support the TypeID: " << unwrap(Ty)->getTypeID()
              << " for the type: " << *unwrap(Ty);
       stream.flush();
@@ -1418,7 +1473,7 @@ extern "C" bool LLVMRustUnpackSMDiagnostic(LLVMSMDiagnosticRef DRef,
                                            unsigned* RangesOut,
                                            size_t* NumRanges) {
   SMDiagnostic& D = *unwrap(DRef);
-  RawRustStringOstream MessageOS(MessageOut);
+  auto MessageOS = RawRustStringOstream(MessageOut);
   MessageOS << D.getMessage();
 
   switch (D.getKind()) {
@@ -1479,8 +1534,8 @@ extern "C" LLVMValueRef LLVMRustBuildCall(LLVMBuilderRef B, LLVMTypeRef Ty, LLVM
 }
 
 extern "C" LLVMValueRef LLVMRustGetInstrProfIncrementIntrinsic(LLVMModuleRef M) {
-  return wrap(llvm::Intrinsic::getDeclaration(unwrap(M),
-              (llvm::Intrinsic::ID)llvm::Intrinsic::instrprof_increment));
+  return wrap(llvm::Intrinsic::getDeclaration(
+      unwrap(M), llvm::Intrinsic::instrprof_increment));
 }
 
 extern "C" LLVMValueRef LLVMRustBuildMemCpy(LLVMBuilderRef B,
@@ -1525,6 +1580,31 @@ LLVMRustBuildInvoke(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Fn,
                                       Name));
 }
 
+extern "C" LLVMValueRef
+LLVMRustBuildCallBr(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Fn,
+                    LLVMBasicBlockRef DefaultDest,
+                    LLVMBasicBlockRef *IndirectDests, unsigned NumIndirectDests,
+                    LLVMValueRef *Args, unsigned NumArgs,
+                    OperandBundleDef **OpBundles, unsigned NumOpBundles,
+                    const char *Name) {
+  Value *Callee = unwrap(Fn);
+  FunctionType *FTy = unwrap<FunctionType>(Ty);
+
+  // FIXME: Is there a way around this?
+  std::vector<BasicBlock*> IndirectDestsUnwrapped;
+  IndirectDestsUnwrapped.reserve(NumIndirectDests);
+  for (unsigned i = 0; i < NumIndirectDests; ++i) {
+    IndirectDestsUnwrapped.push_back(unwrap(IndirectDests[i]));
+  }
+
+  return wrap(unwrap(B)->CreateCallBr(
+      FTy, Callee, unwrap(DefaultDest),
+      ArrayRef<BasicBlock*>(IndirectDestsUnwrapped),
+      ArrayRef<Value*>(unwrap(Args), NumArgs),
+      ArrayRef<OperandBundleDef>(*OpBundles, NumOpBundles),
+      Name));
+}
+
 extern "C" void LLVMRustPositionBuilderAtStart(LLVMBuilderRef B,
                                                LLVMBasicBlockRef BB) {
   auto Point = unwrap(BB)->getFirstInsertionPt();
@@ -1533,7 +1613,7 @@ extern "C" void LLVMRustPositionBuilderAtStart(LLVMBuilderRef B,
 
 extern "C" void LLVMRustSetComdat(LLVMModuleRef M, LLVMValueRef V,
                                   const char *Name, size_t NameLen) {
-  Triple TargetTriple(unwrap(M)->getTargetTriple());
+  Triple TargetTriple = Triple(unwrap(M)->getTargetTriple());
   GlobalObject *GV = unwrap<GlobalObject>(V);
   if (TargetTriple.supportsCOMDAT()) {
     StringRef NameRef(Name, NameLen);
@@ -1697,7 +1777,7 @@ extern "C" LLVMRustModuleBuffer*
 LLVMRustModuleBufferCreate(LLVMModuleRef M) {
   auto Ret = std::make_unique<LLVMRustModuleBuffer>();
   {
-    raw_string_ostream OS(Ret->data);
+    auto OS = raw_string_ostream(Ret->data);
     WriteBitcodeToFile(*unwrap(M), OS);
   }
   return Ret.release();
@@ -1727,8 +1807,8 @@ LLVMRustModuleCost(LLVMModuleRef M) {
 extern "C" void
 LLVMRustModuleInstructionStats(LLVMModuleRef M, RustStringRef Str)
 {
-  RawRustStringOstream OS(Str);
-  llvm::json::OStream JOS(OS);
+  auto OS = RawRustStringOstream(Str);
+  auto JOS = llvm::json::OStream(OS);
   auto Module = unwrap(M);
 
   JOS.object([&] {
@@ -1843,7 +1923,7 @@ extern "C" LLVMRustResult LLVMRustWriteImportLibrary(
     MinGW);
   if (Error) {
     std::string errorString;
-    llvm::raw_string_ostream stream(errorString);
+    auto stream = llvm::raw_string_ostream(errorString);
     stream << Error;
     stream.flush();
     LLVMRustSetLastError(errorString.c_str());
@@ -1922,7 +2002,11 @@ extern "C" void LLVMRustContextConfigureDiagnosticHandler(
         }
       }
       if (DiagnosticHandlerCallback) {
+#if LLVM_VERSION_GE(19, 0)
+        DiagnosticHandlerCallback(&DI, DiagnosticHandlerContext);
+#else
         DiagnosticHandlerCallback(DI, DiagnosticHandlerContext);
+#endif
         return true;
       }
       return false;
@@ -2027,7 +2111,7 @@ extern "C" void LLVMRustContextConfigureDiagnosticHandler(
 }
 
 extern "C" void LLVMRustGetMangledName(LLVMValueRef V, RustStringRef Str) {
-  RawRustStringOstream OS(Str);
+  auto OS = RawRustStringOstream(Str);
   GlobalValue *GV = unwrap<GlobalValue>(V);
   Mangler().getNameWithPrefix(OS, GV, true);
 }
@@ -2065,3 +2149,20 @@ extern "C" bool LLVMRustLLVMHasZlibCompressionForDebugSymbols() {
 extern "C" bool LLVMRustLLVMHasZstdCompressionForDebugSymbols() {
   return llvm::compression::zstd::isAvailable();
 }
+
+// Operations on composite constants.
+// These are clones of LLVM api functions that will become available in future releases.
+// They can be removed once Rust's minimum supported LLVM version supports them.
+// See https://github.com/rust-lang/rust/issues/121868
+// See https://llvm.org/doxygen/group__LLVMCCoreValueConstantComposite.html
+
+// FIXME: Remove when Rust's minimum supported LLVM version reaches 19.
+// https://github.com/llvm/llvm-project/commit/e1405e4f71c899420ebf8262d5e9745598419df8
+#if LLVM_VERSION_LT(19, 0)
+extern "C" LLVMValueRef LLVMConstStringInContext2(LLVMContextRef C,
+                                                  const char *Str,
+                                                  size_t Length,
+                                                  bool DontNullTerminate) {
+  return wrap(ConstantDataArray::getString(*unwrap(C), StringRef(Str, Length), !DontNullTerminate));
+}
+#endif