about summary refs log tree commit diff
path: root/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp')
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp207
1 files changed, 124 insertions, 83 deletions
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index 2e135fbe2bd..9850f395a0f 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -6,6 +6,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/Object/Archive.h"
+#include "llvm/Object/COFFImportFile.h"
 #include "llvm/Object/ObjectFile.h"
 #include "llvm/Bitcode/BitcodeWriterPass.h"
 #include "llvm/Support/Signals.h"
@@ -70,17 +71,6 @@ extern "C" void LLVMRustInstallFatalErrorHandler() {
   install_fatal_error_handler(FatalErrorHandler);
 }
 
-extern "C" LLVMMemoryBufferRef
-LLVMRustCreateMemoryBufferWithContentsOfFile(const char *Path) {
-  ErrorOr<std::unique_ptr<MemoryBuffer>> BufOr =
-      MemoryBuffer::getFile(Path, -1, false);
-  if (!BufOr) {
-    LLVMRustSetLastError(BufOr.getError().message().c_str());
-    return nullptr;
-  }
-  return wrap(BufOr.get().release());
-}
-
 extern "C" char *LLVMRustGetLastError(void) {
   char *Ret = LastError;
   LastError = nullptr;
@@ -213,56 +203,57 @@ static Attribute::AttrKind fromRust(LLVMRustAttribute Kind) {
   report_fatal_error("bad AttributeKind");
 }
 
+template<typename T> static inline void AddAttribute(T *t, unsigned Index, Attribute Attr) {
+#if LLVM_VERSION_LT(14, 0)
+  t->addAttribute(Index, Attr);
+#else
+  t->addAttributeAtIndex(Index, Attr);
+#endif
+}
+
 extern "C" void LLVMRustAddCallSiteAttribute(LLVMValueRef Instr, unsigned Index,
                                              LLVMRustAttribute RustAttr) {
   CallBase *Call = unwrap<CallBase>(Instr);
   Attribute Attr = Attribute::get(Call->getContext(), fromRust(RustAttr));
-  Call->addAttribute(Index, Attr);
+  AddAttribute(Call, Index, Attr);
 }
 
 extern "C" void LLVMRustAddCallSiteAttrString(LLVMValueRef Instr, unsigned Index,
                                               const char *Name) {
   CallBase *Call = unwrap<CallBase>(Instr);
   Attribute Attr = Attribute::get(Call->getContext(), Name);
-  Call->addAttribute(Index, Attr);
+  AddAttribute(Call, Index, Attr);
 }
 
-
 extern "C" void LLVMRustAddAlignmentCallSiteAttr(LLVMValueRef Instr,
                                                  unsigned Index,
                                                  uint32_t Bytes) {
   CallBase *Call = unwrap<CallBase>(Instr);
-  AttrBuilder B;
-  B.addAlignmentAttr(Bytes);
-  Call->setAttributes(Call->getAttributes().addAttributes(
-      Call->getContext(), Index, B));
+  Attribute Attr = Attribute::getWithAlignment(Call->getContext(), Align(Bytes));
+  AddAttribute(Call, Index, Attr);
 }
 
 extern "C" void LLVMRustAddDereferenceableCallSiteAttr(LLVMValueRef Instr,
                                                        unsigned Index,
                                                        uint64_t Bytes) {
   CallBase *Call = unwrap<CallBase>(Instr);
-  AttrBuilder B;
-  B.addDereferenceableAttr(Bytes);
-  Call->setAttributes(Call->getAttributes().addAttributes(
-      Call->getContext(), Index, B));
+  Attribute Attr = Attribute::getWithDereferenceableBytes(Call->getContext(), Bytes);
+  AddAttribute(Call, Index, Attr);
 }
 
 extern "C" void LLVMRustAddDereferenceableOrNullCallSiteAttr(LLVMValueRef Instr,
                                                              unsigned Index,
                                                              uint64_t Bytes) {
   CallBase *Call = unwrap<CallBase>(Instr);
-  AttrBuilder B;
-  B.addDereferenceableOrNullAttr(Bytes);
-  Call->setAttributes(Call->getAttributes().addAttributes(
-      Call->getContext(), Index, B));
+  Attribute Attr = Attribute::getWithDereferenceableOrNullBytes(Call->getContext(), Bytes);
+  AddAttribute(Call, Index, Attr);
 }
 
 extern "C" void LLVMRustAddByValCallSiteAttr(LLVMValueRef Instr, unsigned Index,
                                              LLVMTypeRef Ty) {
   CallBase *Call = unwrap<CallBase>(Instr);
   Attribute Attr = Attribute::getWithByValType(Call->getContext(), unwrap(Ty));
-  Call->addAttribute(Index, Attr);
+  AddAttribute(Call, Index, Attr);
 }
 
 extern "C" void LLVMRustAddStructRetCallSiteAttr(LLVMValueRef Instr, unsigned Index,
@@ -273,48 +264,44 @@ extern "C" void LLVMRustAddStructRetCallSiteAttr(LLVMValueRef Instr, unsigned In
 #else
   Attribute Attr = Attribute::get(Call->getContext(), Attribute::StructRet);
 #endif
-  Call->addAttribute(Index, Attr);
+  AddAttribute(Call, Index, Attr);
 }
 
 extern "C" void LLVMRustAddFunctionAttribute(LLVMValueRef Fn, unsigned Index,
                                              LLVMRustAttribute RustAttr) {
   Function *A = unwrap<Function>(Fn);
   Attribute Attr = Attribute::get(A->getContext(), fromRust(RustAttr));
-  AttrBuilder B(Attr);
-  A->addAttributes(Index, B);
+  AddAttribute(A, Index, Attr);
 }
 
 extern "C" void LLVMRustAddAlignmentAttr(LLVMValueRef Fn,
                                          unsigned Index,
                                          uint32_t Bytes) {
   Function *A = unwrap<Function>(Fn);
-  AttrBuilder B;
-  B.addAlignmentAttr(Bytes);
-  A->addAttributes(Index, B);
+  AddAttribute(A, Index, Attribute::getWithAlignment(
+      A->getContext(), llvm::Align(Bytes)));
 }
 
 extern "C" void LLVMRustAddDereferenceableAttr(LLVMValueRef Fn, unsigned Index,
                                                uint64_t Bytes) {
   Function *A = unwrap<Function>(Fn);
-  AttrBuilder B;
-  B.addDereferenceableAttr(Bytes);
-  A->addAttributes(Index, B);
+  AddAttribute(A, Index, Attribute::getWithDereferenceableBytes(A->getContext(),
+                                                                Bytes));
 }
 
 extern "C" void LLVMRustAddDereferenceableOrNullAttr(LLVMValueRef Fn,
                                                      unsigned Index,
                                                      uint64_t Bytes) {
   Function *A = unwrap<Function>(Fn);
-  AttrBuilder B;
-  B.addDereferenceableOrNullAttr(Bytes);
-  A->addAttributes(Index, B);
+  AddAttribute(A, Index, Attribute::getWithDereferenceableOrNullBytes(
+      A->getContext(), Bytes));
 }
 
 extern "C" void LLVMRustAddByValAttr(LLVMValueRef Fn, unsigned Index,
                                      LLVMTypeRef Ty) {
   Function *F = unwrap<Function>(Fn);
   Attribute Attr = Attribute::getWithByValType(F->getContext(), unwrap(Ty));
-  F->addAttribute(Index, Attr);
+  AddAttribute(F, Index, Attr);
 }
 
 extern "C" void LLVMRustAddStructRetAttr(LLVMValueRef Fn, unsigned Index,
@@ -325,7 +312,7 @@ extern "C" void LLVMRustAddStructRetAttr(LLVMValueRef Fn, unsigned Index,
 #else
   Attribute Attr = Attribute::get(F->getContext(), Attribute::StructRet);
 #endif
-  F->addAttribute(Index, Attr);
+  AddAttribute(F, Index, Attr);
 }
 
 extern "C" void LLVMRustAddFunctionAttrStringValue(LLVMValueRef Fn,
@@ -333,9 +320,8 @@ extern "C" void LLVMRustAddFunctionAttrStringValue(LLVMValueRef Fn,
                                                    const char *Name,
                                                    const char *Value) {
   Function *F = unwrap<Function>(Fn);
-  AttrBuilder B;
-  B.addAttribute(Name, Value);
-  F->addAttributes(Index, B);
+  AddAttribute(F, Index, Attribute::get(
+      F->getContext(), StringRef(Name), StringRef(Value)));
 }
 
 extern "C" void LLVMRustRemoveFunctionAttributes(LLVMValueRef Fn,
@@ -345,7 +331,12 @@ extern "C" void LLVMRustRemoveFunctionAttributes(LLVMValueRef Fn,
   Attribute Attr = Attribute::get(F->getContext(), fromRust(RustAttr));
   AttrBuilder B(Attr);
   auto PAL = F->getAttributes();
-  auto PALNew = PAL.removeAttributes(F->getContext(), Index, B);
+  AttributeList PALNew;
+#if LLVM_VERSION_LT(14, 0)
+  PALNew = PAL.removeAttributes(F->getContext(), Index, B);
+#else
+  PALNew = PAL.removeAttributesAtIndex(F->getContext(), Index, B);
+#endif
   F->setAttributes(PALNew);
 }
 
@@ -359,11 +350,10 @@ extern "C" void LLVMRustSetFastMath(LLVMValueRef V) {
 }
 
 extern "C" LLVMValueRef
-LLVMRustBuildAtomicLoad(LLVMBuilderRef B, LLVMValueRef Source, const char *Name,
-                        LLVMAtomicOrdering Order) {
+LLVMRustBuildAtomicLoad(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Source,
+                        const char *Name, LLVMAtomicOrdering Order) {
   Value *Ptr = unwrap(Source);
-  Type *Ty = Ptr->getType()->getPointerElementType();
-  LoadInst *LI = unwrap(B)->CreateLoad(Ty, Ptr, Name);
+  LoadInst *LI = unwrap(B)->CreateLoad(unwrap(Ty), Ptr, Name);
   LI->setAtomic(fromRust(Order));
   return wrap(LI);
 }
@@ -1077,30 +1067,6 @@ extern "C" void LLVMRustWriteValueToString(LLVMValueRef V,
   }
 }
 
-// Note that the two following functions look quite similar to the
-// LLVMGetSectionName function. Sadly, it appears that this function only
-// returns a char* pointer, which isn't guaranteed to be null-terminated. The
-// function provided by LLVM doesn't return the length, so we've created our own
-// function which returns the length as well as the data pointer.
-//
-// For an example of this not returning a null terminated string, see
-// lib/Object/COFFObjectFile.cpp in the getSectionName function. One of the
-// branches explicitly creates a StringRef without a null terminator, and then
-// that's returned.
-
-inline section_iterator *unwrap(LLVMSectionIteratorRef SI) {
-  return reinterpret_cast<section_iterator *>(SI);
-}
-
-extern "C" size_t LLVMRustGetSectionName(LLVMSectionIteratorRef SI,
-                                         const char **Ptr) {
-  auto NameOrErr = (*unwrap(SI))->getName();
-  if (!NameOrErr)
-    report_fatal_error(NameOrErr.takeError());
-  *Ptr = NameOrErr->data();
-  return NameOrErr->size();
-}
-
 // LLVMArrayType function does not support 64-bit ElementCount
 extern "C" LLVMTypeRef LLVMRustArrayType(LLVMTypeRef ElementTy,
                                          uint64_t ElementCount) {
@@ -1149,15 +1115,13 @@ extern "C" void
 LLVMRustUnpackInlineAsmDiagnostic(LLVMDiagnosticInfoRef DI,
                                   LLVMRustDiagnosticLevel *LevelOut,
                                   unsigned *CookieOut,
-                                  LLVMTwineRef *MessageOut,
-                                  LLVMValueRef *InstructionOut) {
+                                  LLVMTwineRef *MessageOut) {
   // Undefined to call this not on an inline assembly diagnostic!
   llvm::DiagnosticInfoInlineAsm *IA =
       static_cast<llvm::DiagnosticInfoInlineAsm *>(unwrap(DI));
 
   *CookieOut = IA->getLocCookie();
   *MessageOut = wrap(&IA->getMsgStr());
-  *InstructionOut = wrap(IA->getInstruction());
 
   switch (IA->getSeverity()) {
     case DS_Error:
@@ -1200,6 +1164,7 @@ enum class LLVMRustDiagnosticKind {
   PGOProfile,
   Linker,
   Unsupported,
+  SrcMgr,
 };
 
 static LLVMRustDiagnosticKind toRust(DiagnosticKind Kind) {
@@ -1228,6 +1193,10 @@ static LLVMRustDiagnosticKind toRust(DiagnosticKind Kind) {
     return LLVMRustDiagnosticKind::Linker;
   case DK_Unsupported:
     return LLVMRustDiagnosticKind::Unsupported;
+#if LLVM_VERSION_GE(13, 0)
+  case DK_SrcMgr:
+    return LLVMRustDiagnosticKind::SrcMgr;
+#endif
   default:
     return (Kind >= DK_FirstRemark && Kind <= DK_LastRemark)
                ? LLVMRustDiagnosticKind::OptimizationRemarkOther
@@ -1315,6 +1284,17 @@ extern "C" void LLVMRustSetInlineAsmDiagnosticHandler(
 #endif
 }
 
+extern "C" LLVMSMDiagnosticRef LLVMRustGetSMDiagnostic(
+    LLVMDiagnosticInfoRef DI, unsigned *Cookie) {
+#if LLVM_VERSION_GE(13, 0)
+  llvm::DiagnosticInfoSrcMgr *SM = static_cast<llvm::DiagnosticInfoSrcMgr *>(unwrap(DI));
+  *Cookie = SM->getLocCookie();
+  return wrap(&SM->getSMDiag());
+#else
+  report_fatal_error("Shouldn't get called on older versions");
+#endif
+}
+
 extern "C" bool LLVMRustUnpackSMDiagnostic(LLVMSMDiagnosticRef DRef,
                                            RustStringRef MessageOut,
                                            RustStringRef BufferOut,
@@ -1427,11 +1407,11 @@ extern "C" void LLVMRustFreeOperandBundleDef(OperandBundleDef *Bundle) {
   delete Bundle;
 }
 
-extern "C" LLVMValueRef LLVMRustBuildCall(LLVMBuilderRef B, LLVMValueRef Fn,
+extern "C" LLVMValueRef LLVMRustBuildCall(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Fn,
                                           LLVMValueRef *Args, unsigned NumArgs,
                                           OperandBundleDef *Bundle) {
   Value *Callee = unwrap(Fn);
-  FunctionType *FTy = cast<FunctionType>(Callee->getType()->getPointerElementType());
+  FunctionType *FTy = unwrap<FunctionType>(Ty);
   unsigned Len = Bundle ? 1 : 0;
   ArrayRef<OperandBundleDef> Bundles = makeArrayRef(Bundle, Len);
   return wrap(unwrap(B)->CreateCall(
@@ -1472,12 +1452,12 @@ extern "C" LLVMValueRef LLVMRustBuildMemSet(LLVMBuilderRef B,
 }
 
 extern "C" LLVMValueRef
-LLVMRustBuildInvoke(LLVMBuilderRef B, LLVMValueRef Fn, LLVMValueRef *Args,
-                    unsigned NumArgs, LLVMBasicBlockRef Then,
-                    LLVMBasicBlockRef Catch, OperandBundleDef *Bundle,
-                    const char *Name) {
+LLVMRustBuildInvoke(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Fn,
+                    LLVMValueRef *Args, unsigned NumArgs,
+                    LLVMBasicBlockRef Then, LLVMBasicBlockRef Catch,
+                    OperandBundleDef *Bundle, const char *Name) {
   Value *Callee = unwrap(Fn);
-  FunctionType *FTy = cast<FunctionType>(Callee->getType()->getPointerElementType());
+  FunctionType *FTy = unwrap<FunctionType>(Ty);
   unsigned Len = Bundle ? 1 : 0;
   ArrayRef<OperandBundleDef> Bundles = makeArrayRef(Bundle, Len);
   return wrap(unwrap(B)->CreateInvoke(FTy, Callee, unwrap(Then), unwrap(Catch),
@@ -1586,6 +1566,16 @@ extern "C" void LLVMRustSetLinkage(LLVMValueRef V,
   LLVMSetLinkage(V, fromRust(RustLinkage));
 }
 
+extern "C" LLVMValueRef LLVMRustConstInBoundsGEP2(LLVMTypeRef Ty,
+                                                  LLVMValueRef ConstantVal,
+                                                  LLVMValueRef *ConstantIndices,
+                                                  unsigned NumIndices) {
+  ArrayRef<Constant *> IdxList(unwrap<Constant>(ConstantIndices, NumIndices),
+                               NumIndices);
+  Constant *Val = unwrap<Constant>(ConstantVal);
+  return wrap(ConstantExpr::getInBoundsGetElementPtr(unwrap(Ty), Val, IdxList));
+}
+
 // Returns true if both high and low were successfully set. Fails in case constant wasn’t any of
 // the common sizes (1, 8, 16, 32, 64, 128 bits)
 extern "C" bool LLVMRustConstInt128Get(LLVMValueRef CV, bool sext, uint64_t *high, uint64_t *low)
@@ -1757,3 +1747,54 @@ extern "C" LLVMValueRef
 LLVMRustBuildMaxNum(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS) {
     return wrap(unwrap(B)->CreateMaxNum(unwrap(LHS),unwrap(RHS)));
 }
+
+// This struct contains all necessary info about a symbol exported from a DLL.
+// At the moment, it's just the symbol's name, but we use a separate struct to
+// make it easier to add other information like ordinal later.
+struct LLVMRustCOFFShortExport {
+  const char* name;
+};
+
+// Machine must be a COFF machine type, as defined in PE specs.
+extern "C" LLVMRustResult LLVMRustWriteImportLibrary(
+  const char* ImportName,
+  const char* Path,
+  const LLVMRustCOFFShortExport* Exports,
+  size_t NumExports,
+  uint16_t Machine,
+  bool MinGW)
+{
+  std::vector<llvm::object::COFFShortExport> ConvertedExports;
+  ConvertedExports.reserve(NumExports);
+
+  for (size_t i = 0; i < NumExports; ++i) {
+    ConvertedExports.push_back(llvm::object::COFFShortExport{
+      Exports[i].name,  // Name
+      std::string{},    // ExtName
+      std::string{},    // SymbolName
+      std::string{},    // AliasTarget
+      0,                // Ordinal
+      false,            // Noname
+      false,            // Data
+      false,            // Private
+      false             // Constant
+    });
+  }
+
+  auto Error = llvm::object::writeImportLibrary(
+    ImportName,
+    Path,
+    ConvertedExports,
+    static_cast<llvm::COFF::MachineTypes>(Machine),
+    MinGW);
+  if (Error) {
+    std::string errorString;
+    llvm::raw_string_ostream stream(errorString);
+    stream << Error;
+    stream.flush();
+    LLVMRustSetLastError(errorString.c_str());
+    return LLVMRustResult::Failure;
+  } else {
+    return LLVMRustResult::Success;
+  }
+}