about summary refs log tree commit diff
path: root/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp')
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp346
1 files changed, 197 insertions, 149 deletions
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index 4f07a0c67c1..919fe7cac5c 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -1,5 +1,6 @@
 #include "LLVMWrapper.h"
 #include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DiagnosticHandler.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/GlobalVariable.h"
@@ -54,7 +55,11 @@ static LLVM_THREAD_LOCAL char *LastError;
 //
 // Notably it exits the process with code 101, unlike LLVM's default of 1.
 static void FatalErrorHandler(void *UserData,
+#if LLVM_VERSION_LT(14, 0)
                               const std::string& Reason,
+#else
+                              const char* Reason,
+#endif
                               bool GenCrashDiag) {
   // Do the same thing that the default error handler does.
   std::cerr << "LLVM ERROR: " << Reason << std::endl;
@@ -71,6 +76,10 @@ extern "C" void LLVMRustInstallFatalErrorHandler() {
   install_fatal_error_handler(FatalErrorHandler);
 }
 
+extern "C" void LLVMRustDisableSystemDialogsOnCrash() {
+  sys::DisableSystemDialogsOnCrash();
+}
+
 extern "C" char *LLVMRustGetLastError(void) {
   char *Ret = LastError;
   LastError = nullptr;
@@ -120,8 +129,18 @@ extern "C" LLVMValueRef LLVMRustGetOrInsertFunction(LLVMModuleRef M,
 
 extern "C" LLVMValueRef
 LLVMRustGetOrInsertGlobal(LLVMModuleRef M, const char *Name, size_t NameLen, LLVMTypeRef Ty) {
+  Module *Mod = unwrap(M);
   StringRef NameRef(Name, NameLen);
-  return wrap(unwrap(M)->getOrInsertGlobal(NameRef, unwrap(Ty)));
+
+  // We don't use Module::getOrInsertGlobal because that returns a Constant*,
+  // which may either be the real GlobalVariable*, or a constant bitcast of it
+  // if our type doesn't match the original declaration. We always want the
+  // GlobalVariable* so we can access linkage, visibility, etc.
+  GlobalVariable *GV = Mod->getGlobalVariable(NameRef, true);
+  if (!GV)
+    GV = new GlobalVariable(*Mod, unwrap(Ty), false,
+                            GlobalValue::ExternalLinkage, nullptr, NameRef);
+  return wrap(GV);
 }
 
 extern "C" LLVMValueRef
@@ -199,139 +218,91 @@ static Attribute::AttrKind fromRust(LLVMRustAttribute Kind) {
     return Attribute::SanitizeHWAddress;
   case WillReturn:
     return Attribute::WillReturn;
+  case StackProtectReq:
+    return Attribute::StackProtectReq;
+  case StackProtectStrong:
+    return Attribute::StackProtectStrong;
+  case StackProtect:
+    return Attribute::StackProtect;
+  case NoUndef:
+    return Attribute::NoUndef;
+  case SanitizeMemTag:
+    return Attribute::SanitizeMemTag;
   }
   report_fatal_error("bad AttributeKind");
 }
 
-extern "C" void LLVMRustAddCallSiteAttribute(LLVMValueRef Instr, unsigned Index,
-                                             LLVMRustAttribute RustAttr) {
-  CallBase *Call = unwrap<CallBase>(Instr);
-  Attribute Attr = Attribute::get(Call->getContext(), fromRust(RustAttr));
-  Call->addAttribute(Index, Attr);
-}
-
-extern "C" void LLVMRustAddCallSiteAttrString(LLVMValueRef Instr, unsigned Index,
-                                              const char *Name) {
-  CallBase *Call = unwrap<CallBase>(Instr);
-  Attribute Attr = Attribute::get(Call->getContext(), Name);
-  Call->addAttribute(Index, Attr);
-}
-
-
-extern "C" void LLVMRustAddAlignmentCallSiteAttr(LLVMValueRef Instr,
-                                                 unsigned Index,
-                                                 uint32_t Bytes) {
-  CallBase *Call = unwrap<CallBase>(Instr);
+template<typename T> static inline void AddAttributes(T *t, unsigned Index,
+                                                      LLVMAttributeRef *Attrs, size_t AttrsLen) {
+  AttributeList PAL = t->getAttributes();
+  AttributeList PALNew;
+#if LLVM_VERSION_LT(14, 0)
   AttrBuilder B;
-  B.addAlignmentAttr(Bytes);
-  Call->setAttributes(Call->getAttributes().addAttributes(
-      Call->getContext(), Index, B));
+  for (LLVMAttributeRef Attr : makeArrayRef(Attrs, AttrsLen))
+    B.addAttribute(unwrap(Attr));
+  PALNew = PAL.addAttributes(t->getContext(), Index, B);
+#else
+  AttrBuilder B(t->getContext());
+  for (LLVMAttributeRef Attr : makeArrayRef(Attrs, AttrsLen))
+    B.addAttribute(unwrap(Attr));
+  PALNew = PAL.addAttributesAtIndex(t->getContext(), Index, B);
+#endif
+  t->setAttributes(PALNew);
 }
 
-extern "C" void LLVMRustAddDereferenceableCallSiteAttr(LLVMValueRef Instr,
-                                                       unsigned Index,
-                                                       uint64_t Bytes) {
-  CallBase *Call = unwrap<CallBase>(Instr);
-  AttrBuilder B;
-  B.addDereferenceableAttr(Bytes);
-  Call->setAttributes(Call->getAttributes().addAttributes(
-      Call->getContext(), Index, B));
+extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index,
+                                              LLVMAttributeRef *Attrs, size_t AttrsLen) {
+  Function *F = unwrap<Function>(Fn);
+  AddAttributes(F, Index, Attrs, AttrsLen);
 }
 
-extern "C" void LLVMRustAddDereferenceableOrNullCallSiteAttr(LLVMValueRef Instr,
-                                                             unsigned Index,
-                                                             uint64_t Bytes) {
+extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, unsigned Index,
+                                              LLVMAttributeRef *Attrs, size_t AttrsLen) {
   CallBase *Call = unwrap<CallBase>(Instr);
-  AttrBuilder B;
-  B.addDereferenceableOrNullAttr(Bytes);
-  Call->setAttributes(Call->getAttributes().addAttributes(
-      Call->getContext(), Index, B));
+  AddAttributes(Call, Index, Attrs, AttrsLen);
 }
 
-extern "C" void LLVMRustAddByValCallSiteAttr(LLVMValueRef Instr, unsigned Index,
-                                             LLVMTypeRef Ty) {
-  CallBase *Call = unwrap<CallBase>(Instr);
-  Attribute Attr = Attribute::getWithByValType(Call->getContext(), unwrap(Ty));
-  Call->addAttribute(Index, Attr);
+extern "C" LLVMAttributeRef LLVMRustCreateAttrNoValue(LLVMContextRef C,
+                                                      LLVMRustAttribute RustAttr) {
+  return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr)));
 }
 
-extern "C" void LLVMRustAddStructRetCallSiteAttr(LLVMValueRef Instr, unsigned Index,
-                                                 LLVMTypeRef Ty) {
-  CallBase *Call = unwrap<CallBase>(Instr);
-#if LLVM_VERSION_GE(12, 0)
-  Attribute Attr = Attribute::getWithStructRetType(Call->getContext(), unwrap(Ty));
-#else
-  Attribute Attr = Attribute::get(Call->getContext(), Attribute::StructRet);
-#endif
-  Call->addAttribute(Index, Attr);
+extern "C" LLVMAttributeRef LLVMRustCreateAttrStringValue(LLVMContextRef C,
+                                                          const char *Name,
+                                                          const char *Value) {
+  return wrap(Attribute::get(*unwrap(C), StringRef(Name), StringRef(Value)));
 }
 
-extern "C" void LLVMRustAddFunctionAttribute(LLVMValueRef Fn, unsigned Index,
-                                             LLVMRustAttribute RustAttr) {
-  Function *A = unwrap<Function>(Fn);
-  Attribute Attr = Attribute::get(A->getContext(), fromRust(RustAttr));
-  A->addAttribute(Index, Attr);
+extern "C" LLVMAttributeRef LLVMRustCreateAlignmentAttr(LLVMContextRef C,
+                                                        uint64_t Bytes) {
+  return wrap(Attribute::getWithAlignment(*unwrap(C), llvm::Align(Bytes)));
 }
 
-extern "C" void LLVMRustAddAlignmentAttr(LLVMValueRef Fn,
-                                         unsigned Index,
-                                         uint32_t Bytes) {
-  Function *A = unwrap<Function>(Fn);
-  A->addAttribute(Index, Attribute::getWithAlignment(
-      A->getContext(), llvm::Align(Bytes)));
+extern "C" LLVMAttributeRef LLVMRustCreateDereferenceableAttr(LLVMContextRef C,
+                                                              uint64_t Bytes) {
+  return wrap(Attribute::getWithDereferenceableBytes(*unwrap(C), Bytes));
 }
 
-extern "C" void LLVMRustAddDereferenceableAttr(LLVMValueRef Fn, unsigned Index,
-                                               uint64_t Bytes) {
-  Function *A = unwrap<Function>(Fn);
-  A->addAttribute(Index, Attribute::getWithDereferenceableBytes(A->getContext(),
-                                                                Bytes));
+extern "C" LLVMAttributeRef LLVMRustCreateDereferenceableOrNullAttr(LLVMContextRef C,
+                                                                    uint64_t Bytes) {
+  return wrap(Attribute::getWithDereferenceableOrNullBytes(*unwrap(C), Bytes));
 }
 
-extern "C" void LLVMRustAddDereferenceableOrNullAttr(LLVMValueRef Fn,
-                                                     unsigned Index,
-                                                     uint64_t Bytes) {
-  Function *A = unwrap<Function>(Fn);
-  A->addAttribute(Index, Attribute::getWithDereferenceableOrNullBytes(
-      A->getContext(), Bytes));
+extern "C" LLVMAttributeRef LLVMRustCreateByValAttr(LLVMContextRef C, LLVMTypeRef Ty) {
+  return wrap(Attribute::getWithByValType(*unwrap(C), unwrap(Ty)));
 }
 
-extern "C" void LLVMRustAddByValAttr(LLVMValueRef Fn, unsigned Index,
-                                     LLVMTypeRef Ty) {
-  Function *F = unwrap<Function>(Fn);
-  Attribute Attr = Attribute::getWithByValType(F->getContext(), unwrap(Ty));
-  F->addAttribute(Index, Attr);
+extern "C" LLVMAttributeRef LLVMRustCreateStructRetAttr(LLVMContextRef C, LLVMTypeRef Ty) {
+  return wrap(Attribute::getWithStructRetType(*unwrap(C), unwrap(Ty)));
 }
 
-extern "C" void LLVMRustAddStructRetAttr(LLVMValueRef Fn, unsigned Index,
-                                         LLVMTypeRef Ty) {
-  Function *F = unwrap<Function>(Fn);
-#if LLVM_VERSION_GE(12, 0)
-  Attribute Attr = Attribute::getWithStructRetType(F->getContext(), unwrap(Ty));
+extern "C" LLVMAttributeRef LLVMRustCreateUWTableAttr(LLVMContextRef C, bool Async) {
+#if LLVM_VERSION_LT(15, 0)
+  return wrap(Attribute::get(*unwrap(C), Attribute::UWTable));
 #else
-  Attribute Attr = Attribute::get(F->getContext(), Attribute::StructRet);
+  return wrap(Attribute::getWithUWTableKind(
+      *unwrap(C), Async ? UWTableKind::Async : UWTableKind::Sync));
 #endif
-  F->addAttribute(Index, Attr);
-}
-
-extern "C" void LLVMRustAddFunctionAttrStringValue(LLVMValueRef Fn,
-                                                   unsigned Index,
-                                                   const char *Name,
-                                                   const char *Value) {
-  Function *F = unwrap<Function>(Fn);
-  F->addAttribute(Index, Attribute::get(
-      F->getContext(), StringRef(Name), StringRef(Value)));
-}
-
-extern "C" void LLVMRustRemoveFunctionAttributes(LLVMValueRef Fn,
-                                                 unsigned Index,
-                                                 LLVMRustAttribute RustAttr) {
-  Function *F = unwrap<Function>(Fn);
-  Attribute Attr = Attribute::get(F->getContext(), fromRust(RustAttr));
-  AttrBuilder B(Attr);
-  auto PAL = F->getAttributes();
-  auto PALNew = PAL.removeAttributes(F->getContext(), Index, B);
-  F->setAttributes(PALNew);
 }
 
 // Enable a fast-math flag
@@ -426,11 +397,20 @@ extern "C" LLVMValueRef
 LLVMRustInlineAsm(LLVMTypeRef Ty, char *AsmString, size_t AsmStringLen,
                   char *Constraints, size_t ConstraintsLen,
                   LLVMBool HasSideEffects, LLVMBool IsAlignStack,
-                  LLVMRustAsmDialect Dialect) {
+                  LLVMRustAsmDialect Dialect, LLVMBool CanThrow) {
+#if LLVM_VERSION_GE(13, 0)
   return wrap(InlineAsm::get(unwrap<FunctionType>(Ty),
                              StringRef(AsmString, AsmStringLen),
                              StringRef(Constraints, ConstraintsLen),
-                             HasSideEffects, IsAlignStack, fromRust(Dialect)));
+                             HasSideEffects, IsAlignStack,
+                             fromRust(Dialect), CanThrow));
+#else
+  return wrap(InlineAsm::get(unwrap<FunctionType>(Ty),
+                             StringRef(AsmString, AsmStringLen),
+                             StringRef(Constraints, ConstraintsLen),
+                             HasSideEffects, IsAlignStack,
+                             fromRust(Dialect)));
+#endif
 }
 
 extern "C" bool LLVMRustInlineAsmVerify(LLVMTypeRef Ty, char *Constraints,
@@ -671,10 +651,8 @@ static Optional<DIFile::ChecksumKind> fromRust(LLVMRustChecksumKind Kind) {
     return DIFile::ChecksumKind::CSK_MD5;
   case LLVMRustChecksumKind::SHA1:
     return DIFile::ChecksumKind::CSK_SHA1;
-#if (LLVM_VERSION_MAJOR >= 11)
   case LLVMRustChecksumKind::SHA256:
     return DIFile::ChecksumKind::CSK_SHA256;
-#endif
   default:
     report_fatal_error("bad ChecksumKind.");
   }
@@ -690,9 +668,12 @@ extern "C" uint32_t LLVMRustVersionMinor() { return LLVM_VERSION_MINOR; }
 
 extern "C" uint32_t LLVMRustVersionMajor() { return LLVM_VERSION_MAJOR; }
 
-extern "C" void LLVMRustAddModuleFlag(LLVMModuleRef M, const char *Name,
-                                      uint32_t Value) {
-  unwrap(M)->addModuleFlag(Module::Warning, Name, Value);
+extern "C" void LLVMRustAddModuleFlag(
+    LLVMModuleRef M,
+    Module::ModFlagBehavior MergeBehavior,
+    const char *Name,
+    uint32_t Value) {
+  unwrap(M)->addModuleFlag(MergeBehavior, Name, Value);
 }
 
 extern "C" LLVMValueRef LLVMRustMetadataAsValue(LLVMContextRef C, LLVMMetadataRef MD) {
@@ -945,11 +926,11 @@ LLVMRustDIBuilderGetOrCreateArray(LLVMRustDIBuilderRef Builder,
 
 extern "C" LLVMValueRef LLVMRustDIBuilderInsertDeclareAtEnd(
     LLVMRustDIBuilderRef Builder, LLVMValueRef V, LLVMMetadataRef VarInfo,
-    int64_t *AddrOps, unsigned AddrOpsCount, LLVMMetadataRef DL,
+    uint64_t *AddrOps, unsigned AddrOpsCount, LLVMMetadataRef DL,
     LLVMBasicBlockRef InsertAtEnd) {
   return wrap(Builder->insertDeclare(
       unwrap(V), unwrap<DILocalVariable>(VarInfo),
-      Builder->createExpression(llvm::ArrayRef<int64_t>(AddrOps, AddrOpsCount)),
+      Builder->createExpression(llvm::ArrayRef<uint64_t>(AddrOps, AddrOpsCount)),
       DebugLoc(cast<MDNode>(unwrap(DL))),
       unwrap(InsertAtEnd)));
 }
@@ -989,14 +970,9 @@ extern "C" LLVMMetadataRef LLVMRustDIBuilderCreateUnionType(
 extern "C" LLVMMetadataRef LLVMRustDIBuilderCreateTemplateTypeParameter(
     LLVMRustDIBuilderRef Builder, LLVMMetadataRef Scope,
     const char *Name, size_t NameLen, LLVMMetadataRef Ty) {
-#if LLVM_VERSION_GE(11, 0)
   bool IsDefault = false; // FIXME: should we ever set this true?
   return wrap(Builder->createTemplateTypeParameter(
       unwrapDI<DIDescriptor>(Scope), StringRef(Name, NameLen), unwrapDI<DIType>(Ty), IsDefault));
-#else
-  return wrap(Builder->createTemplateTypeParameter(
-      unwrapDI<DIDescriptor>(Scope), StringRef(Name, NameLen), unwrapDI<DIType>(Ty)));
-#endif
 }
 
 extern "C" LLVMMetadataRef LLVMRustDIBuilderCreateNameSpace(
@@ -1021,24 +997,18 @@ extern "C" LLVMMetadataRef
 LLVMRustDIBuilderCreateDebugLocation(unsigned Line, unsigned Column,
                                      LLVMMetadataRef ScopeRef,
                                      LLVMMetadataRef InlinedAt) {
-#if LLVM_VERSION_GE(12, 0)
   MDNode *Scope = unwrapDIPtr<MDNode>(ScopeRef);
   DILocation *Loc = DILocation::get(
       Scope->getContext(), Line, Column, Scope,
       unwrapDIPtr<MDNode>(InlinedAt));
   return wrap(Loc);
-#else
-  DebugLoc debug_loc = DebugLoc::get(Line, Column, unwrapDIPtr<MDNode>(ScopeRef),
-                                     unwrapDIPtr<MDNode>(InlinedAt));
-  return wrap(debug_loc.getAsMDNode());
-#endif
 }
 
-extern "C" int64_t LLVMRustDIBuilderCreateOpDeref() {
+extern "C" uint64_t LLVMRustDIBuilderCreateOpDeref() {
   return dwarf::DW_OP_deref;
 }
 
-extern "C" int64_t LLVMRustDIBuilderCreateOpPlusUconst() {
+extern "C" uint64_t LLVMRustDIBuilderCreateOpPlusUconst() {
   return dwarf::DW_OP_plus_uconst;
 }
 
@@ -1172,10 +1142,13 @@ static LLVMRustDiagnosticKind toRust(DiagnosticKind Kind) {
   case DK_SampleProfile:
     return LLVMRustDiagnosticKind::SampleProfile;
   case DK_OptimizationRemark:
+  case DK_MachineOptimizationRemark:
     return LLVMRustDiagnosticKind::OptimizationRemark;
   case DK_OptimizationRemarkMissed:
+  case DK_MachineOptimizationRemarkMissed:
     return LLVMRustDiagnosticKind::OptimizationRemarkMissed;
   case DK_OptimizationRemarkAnalysis:
+  case DK_MachineOptimizationRemarkAnalysis:
     return LLVMRustDiagnosticKind::OptimizationRemarkAnalysis;
   case DK_OptimizationRemarkAnalysisFPCommute:
     return LLVMRustDiagnosticKind::OptimizationRemarkAnalysisFPCommute;
@@ -1236,27 +1209,18 @@ extern "C" LLVMTypeKind LLVMRustGetTypeKind(LLVMTypeRef Ty) {
     return LLVMArrayTypeKind;
   case Type::PointerTyID:
     return LLVMPointerTypeKind;
-#if LLVM_VERSION_GE(11, 0)
   case Type::FixedVectorTyID:
     return LLVMVectorTypeKind;
-#else
-  case Type::VectorTyID:
-    return LLVMVectorTypeKind;
-#endif
   case Type::X86_MMXTyID:
     return LLVMX86_MMXTypeKind;
   case Type::TokenTyID:
     return LLVMTokenTypeKind;
-#if LLVM_VERSION_GE(11, 0)
   case Type::ScalableVectorTyID:
     return LLVMScalableVectorTypeKind;
   case Type::BFloatTyID:
     return LLVMBFloatTypeKind;
-#endif
-#if LLVM_VERSION_GE(12, 0)
   case Type::X86_AMXTyID:
     return LLVMX86_AMXTypeKind;
-#endif
   }
   report_fatal_error("Unhandled TypeID.");
 }
@@ -1714,23 +1678,15 @@ LLVMRustBuildVectorReduceMax(LLVMBuilderRef B, LLVMValueRef Src, bool IsSigned)
 }
 extern "C" LLVMValueRef
 LLVMRustBuildVectorReduceFMin(LLVMBuilderRef B, LLVMValueRef Src, bool NoNaN) {
-#if LLVM_VERSION_GE(12, 0)
   Instruction *I = unwrap(B)->CreateFPMinReduce(unwrap(Src));
   I->setHasNoNaNs(NoNaN);
   return wrap(I);
-#else
-  return wrap(unwrap(B)->CreateFPMinReduce(unwrap(Src), NoNaN));
-#endif
 }
 extern "C" LLVMValueRef
 LLVMRustBuildVectorReduceFMax(LLVMBuilderRef B, LLVMValueRef Src, bool NoNaN) {
-#if LLVM_VERSION_GE(12, 0)
   Instruction *I = unwrap(B)->CreateFPMaxReduce(unwrap(Src));
   I->setHasNoNaNs(NoNaN);
   return wrap(I);
-#else
-  return wrap(unwrap(B)->CreateFPMaxReduce(unwrap(Src), NoNaN));
-#endif
 }
 
 extern "C" LLVMValueRef
@@ -1743,10 +1699,11 @@ LLVMRustBuildMaxNum(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS) {
 }
 
 // This struct contains all necessary info about a symbol exported from a DLL.
-// At the moment, it's just the symbol's name, but we use a separate struct to
-// make it easier to add other information like ordinal later.
 struct LLVMRustCOFFShortExport {
   const char* name;
+  bool ordinal_present;
+  // The value of `ordinal` is only meaningful if `ordinal_present` is true.
+  uint16_t ordinal;
 };
 
 // Machine must be a COFF machine type, as defined in PE specs.
@@ -1762,13 +1719,15 @@ extern "C" LLVMRustResult LLVMRustWriteImportLibrary(
   ConvertedExports.reserve(NumExports);
 
   for (size_t i = 0; i < NumExports; ++i) {
+    bool ordinal_present = Exports[i].ordinal_present;
+    uint16_t ordinal = ordinal_present ? Exports[i].ordinal : 0;
     ConvertedExports.push_back(llvm::object::COFFShortExport{
       Exports[i].name,  // Name
       std::string{},    // ExtName
       std::string{},    // SymbolName
       std::string{},    // AliasTarget
-      0,                // Ordinal
-      false,            // Noname
+      ordinal,          // Ordinal
+      ordinal_present,  // Noname
       false,            // Data
       false,            // Private
       false             // Constant
@@ -1792,3 +1751,92 @@ extern "C" LLVMRustResult LLVMRustWriteImportLibrary(
     return LLVMRustResult::Success;
   }
 }
+
+// Transfers ownership of DiagnosticHandler unique_ptr to the caller.
+extern "C" DiagnosticHandler *
+LLVMRustContextGetDiagnosticHandler(LLVMContextRef C) {
+  std::unique_ptr<DiagnosticHandler> DH = unwrap(C)->getDiagnosticHandler();
+  return DH.release();
+}
+
+// Sets unique_ptr to object of DiagnosticHandler to provide custom diagnostic
+// handling. Ownership of the handler is moved to the LLVMContext.
+extern "C" void LLVMRustContextSetDiagnosticHandler(LLVMContextRef C,
+                                                    DiagnosticHandler *DH) {
+  unwrap(C)->setDiagnosticHandler(std::unique_ptr<DiagnosticHandler>(DH));
+}
+
+using LLVMDiagnosticHandlerTy = DiagnosticHandler::DiagnosticHandlerTy;
+
+// Configures a diagnostic handler that invokes provided callback when a
+// backend needs to emit a diagnostic.
+//
+// When RemarkAllPasses is true, remarks are enabled for all passes. Otherwise
+// the RemarkPasses array specifies individual passes for which remarks will be
+// enabled.
+extern "C" void LLVMRustContextConfigureDiagnosticHandler(
+    LLVMContextRef C, LLVMDiagnosticHandlerTy DiagnosticHandlerCallback,
+    void *DiagnosticHandlerContext, bool RemarkAllPasses,
+    const char * const * RemarkPasses, size_t RemarkPassesLen) {
+
+  class RustDiagnosticHandler final : public DiagnosticHandler {
+  public:
+    RustDiagnosticHandler(LLVMDiagnosticHandlerTy DiagnosticHandlerCallback,
+                          void *DiagnosticHandlerContext,
+                          bool RemarkAllPasses,
+                          std::vector<std::string> RemarkPasses)
+        : DiagnosticHandlerCallback(DiagnosticHandlerCallback),
+          DiagnosticHandlerContext(DiagnosticHandlerContext),
+          RemarkAllPasses(RemarkAllPasses),
+          RemarkPasses(RemarkPasses) {}
+
+    virtual bool handleDiagnostics(const DiagnosticInfo &DI) override {
+      if (DiagnosticHandlerCallback) {
+        DiagnosticHandlerCallback(DI, DiagnosticHandlerContext);
+        return true;
+      }
+      return false;
+    }
+
+    bool isAnalysisRemarkEnabled(StringRef PassName) const override {
+      return isRemarkEnabled(PassName);
+    }
+
+    bool isMissedOptRemarkEnabled(StringRef PassName) const override {
+      return isRemarkEnabled(PassName);
+    }
+
+    bool isPassedOptRemarkEnabled(StringRef PassName) const override {
+      return isRemarkEnabled(PassName);
+    }
+
+    bool isAnyRemarkEnabled() const override {
+      return RemarkAllPasses || !RemarkPasses.empty();
+    }
+
+  private:
+    bool isRemarkEnabled(StringRef PassName) const {
+      if (RemarkAllPasses)
+        return true;
+
+      for (auto &Pass : RemarkPasses)
+        if (Pass == PassName)
+          return true;
+
+      return false;
+    }
+
+    LLVMDiagnosticHandlerTy DiagnosticHandlerCallback = nullptr;
+    void *DiagnosticHandlerContext = nullptr;
+
+    bool RemarkAllPasses = false;
+    std::vector<std::string> RemarkPasses;
+  };
+
+  std::vector<std::string> Passes;
+  for (size_t I = 0; I != RemarkPassesLen; ++I)
+    Passes.push_back(RemarkPasses[I]);
+
+  unwrap(C)->setDiagnosticHandler(std::make_unique<RustDiagnosticHandler>(
+      DiagnosticHandlerCallback, DiagnosticHandlerContext, RemarkAllPasses, Passes));
+}