about summary refs log tree commit diff
path: root/compiler/rustc_llvm/llvm-wrapper
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_llvm/llvm-wrapper')
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp22
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp67
2 files changed, 20 insertions, 69 deletions
diff --git a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
index 6447a9362b3..a6b2384f2d7 100644
--- a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
@@ -688,14 +688,20 @@ struct LLVMRustSanitizerOptions {
   bool SanitizeKernelAddressRecover;
 };
 
+// This symbol won't be available or used when Enzyme is not enabled
+#ifdef ENZYME
+extern "C" void registerEnzyme(llvm::PassBuilder &PB);
+#endif
+
 extern "C" LLVMRustResult LLVMRustOptimize(
     LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef,
     LLVMRustPassBuilderOptLevel OptLevelRust, LLVMRustOptStage OptStage,
     bool IsLinkerPluginLTO, bool NoPrepopulatePasses, bool VerifyIR,
     bool LintIR, bool UseThinLTOBuffers, bool MergeFunctions, bool UnrollLoops,
     bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
-    bool EmitLifetimeMarkers, LLVMRustSanitizerOptions *SanitizerOptions,
-    const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage,
+    bool EmitLifetimeMarkers, bool RunEnzyme,
+    LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
+    const char *PGOUsePath, bool InstrumentCoverage,
     const char *InstrProfileOutput, const char *PGOSampleUsePath,
     bool DebugInfoForProfiling, void *LlvmSelfProfiler,
     LLVMRustSelfProfileBeforePassCallback BeforePassCallback,
@@ -1010,6 +1016,18 @@ extern "C" LLVMRustResult LLVMRustOptimize(
     MPM.addPass(NameAnonGlobalPass());
   }
 
+  // now load "-enzyme" pass:
+#ifdef ENZYME
+  if (RunEnzyme) {
+    registerEnzyme(PB);
+    if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
+      std::string ErrMsg = toString(std::move(Err));
+      LLVMRustSetLastError(ErrMsg.c_str());
+      return LLVMRustResult::Failure;
+    }
+  }
+#endif
+
   // Upgrade all calls to old intrinsics first.
   for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)
     UpgradeCallsToIntrinsic(&*I++); // must be post-increment, as we remove
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index 7ff316ba83a..b8cef6a7e25 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -195,33 +195,6 @@ LLVMRustVerifyFunction(LLVMValueRef Fn, LLVMRustVerifierFailureAction Action) {
   return LLVMVerifyFunction(Fn, fromRust(Action));
 }
 
-enum class LLVMRustTailCallKind {
-  None,
-  Tail,
-  MustTail,
-  NoTail,
-};
-
-static CallInst::TailCallKind fromRust(LLVMRustTailCallKind Kind) {
-  switch (Kind) {
-  case LLVMRustTailCallKind::None:
-    return CallInst::TailCallKind::TCK_None;
-  case LLVMRustTailCallKind::Tail:
-    return CallInst::TailCallKind::TCK_Tail;
-  case LLVMRustTailCallKind::MustTail:
-    return CallInst::TailCallKind::TCK_MustTail;
-  case LLVMRustTailCallKind::NoTail:
-    return CallInst::TailCallKind::TCK_NoTail;
-  default:
-    report_fatal_error("bad CallInst::TailCallKind.");
-  }
-}
-
-extern "C" void LLVMRustSetTailCallKind(LLVMValueRef Call,
-                                        LLVMRustTailCallKind TCK) {
-  unwrap<CallInst>(Call)->setTailCallKind(fromRust(TCK));
-}
-
 extern "C" LLVMValueRef LLVMRustGetOrInsertFunction(LLVMModuleRef M,
                                                     const char *Name,
                                                     size_t NameLen,
@@ -1003,10 +976,6 @@ extern "C" void LLVMRustDIBuilderDispose(LLVMDIBuilderRef Builder) {
   delete unwrap(Builder);
 }
 
-extern "C" void LLVMRustDIBuilderFinalize(LLVMDIBuilderRef Builder) {
-  unwrap(Builder)->finalize();
-}
-
 extern "C" LLVMMetadataRef LLVMRustDIBuilderCreateCompileUnit(
     LLVMDIBuilderRef Builder, unsigned Lang, LLVMMetadataRef FileRef,
     const char *Producer, size_t ProducerLen, bool isOptimized,
@@ -1183,20 +1152,6 @@ LLVMRustDIBuilderCreateQualifiedType(LLVMDIBuilderRef Builder, unsigned Tag,
       unwrap(Builder)->createQualifiedType(Tag, unwrapDI<DIType>(Type)));
 }
 
-extern "C" LLVMMetadataRef
-LLVMRustDIBuilderCreateLexicalBlock(LLVMDIBuilderRef Builder,
-                                    LLVMMetadataRef Scope, LLVMMetadataRef File,
-                                    unsigned Line, unsigned Col) {
-  return wrap(unwrap(Builder)->createLexicalBlock(
-      unwrapDI<DIDescriptor>(Scope), unwrapDI<DIFile>(File), Line, Col));
-}
-
-extern "C" LLVMMetadataRef LLVMRustDIBuilderCreateLexicalBlockFile(
-    LLVMDIBuilderRef Builder, LLVMMetadataRef Scope, LLVMMetadataRef File) {
-  return wrap(unwrap(Builder)->createLexicalBlockFile(
-      unwrapDI<DIDescriptor>(Scope), unwrapDI<DIFile>(File)));
-}
-
 extern "C" LLVMMetadataRef LLVMRustDIBuilderCreateStaticVariable(
     LLVMDIBuilderRef Builder, LLVMMetadataRef Context, const char *Name,
     size_t NameLen, const char *LinkageName, size_t LinkageNameLen,
@@ -1325,14 +1280,6 @@ extern "C" LLVMMetadataRef LLVMRustDIBuilderCreateTemplateTypeParameter(
       unwrapDI<DIType>(Ty), IsDefault));
 }
 
-extern "C" LLVMMetadataRef
-LLVMRustDIBuilderCreateNameSpace(LLVMDIBuilderRef Builder,
-                                 LLVMMetadataRef Scope, const char *Name,
-                                 size_t NameLen, bool ExportSymbols) {
-  return wrap(unwrap(Builder)->createNameSpace(
-      unwrapDI<DIDescriptor>(Scope), StringRef(Name, NameLen), ExportSymbols));
-}
-
 extern "C" void LLVMRustDICompositeTypeReplaceArrays(
     LLVMDIBuilderRef Builder, LLVMMetadataRef CompositeTy,
     LLVMMetadataRef Elements, LLVMMetadataRef Params) {
@@ -1342,16 +1289,6 @@ extern "C" void LLVMRustDICompositeTypeReplaceArrays(
 }
 
 extern "C" LLVMMetadataRef
-LLVMRustDIBuilderCreateDebugLocation(unsigned Line, unsigned Column,
-                                     LLVMMetadataRef ScopeRef,
-                                     LLVMMetadataRef InlinedAt) {
-  MDNode *Scope = unwrapDIPtr<MDNode>(ScopeRef);
-  DILocation *Loc = DILocation::get(Scope->getContext(), Line, Column, Scope,
-                                    unwrapDIPtr<MDNode>(InlinedAt));
-  return wrap(Loc);
-}
-
-extern "C" LLVMMetadataRef
 LLVMRustDILocationCloneWithBaseDiscriminator(LLVMMetadataRef Location,
                                              unsigned BD) {
   DILocation *Loc = unwrapDIPtr<DILocation>(Location);
@@ -2012,10 +1949,6 @@ extern "C" int32_t LLVMRustGetElementTypeArgIndex(LLVMValueRef CallSite) {
   return -1;
 }
 
-extern "C" bool LLVMRustIsBitcode(char *ptr, size_t len) {
-  return identify_magic(StringRef(ptr, len)) == file_magic::bitcode;
-}
-
 extern "C" bool LLVMRustIsNonGVFunctionPointerTy(LLVMValueRef V) {
   if (unwrap<Value>(V)->getType()->isPointerTy()) {
     if (auto *GV = dyn_cast<GlobalValue>(unwrap<Value>(V))) {