about summary refs log tree commit diff
path: root/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp')
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp22
1 files changed, 20 insertions, 2 deletions
diff --git a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
index 6447a9362b3..a6b2384f2d7 100644
--- a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp
@@ -688,14 +688,20 @@ struct LLVMRustSanitizerOptions {
   bool SanitizeKernelAddressRecover;
 };
 
+// This symbol won't be available or used when Enzyme is not enabled
+#ifdef ENZYME
+extern "C" void registerEnzyme(llvm::PassBuilder &PB);
+#endif
+
 extern "C" LLVMRustResult LLVMRustOptimize(
     LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef,
     LLVMRustPassBuilderOptLevel OptLevelRust, LLVMRustOptStage OptStage,
     bool IsLinkerPluginLTO, bool NoPrepopulatePasses, bool VerifyIR,
     bool LintIR, bool UseThinLTOBuffers, bool MergeFunctions, bool UnrollLoops,
     bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
-    bool EmitLifetimeMarkers, LLVMRustSanitizerOptions *SanitizerOptions,
-    const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage,
+    bool EmitLifetimeMarkers, bool RunEnzyme,
+    LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
+    const char *PGOUsePath, bool InstrumentCoverage,
     const char *InstrProfileOutput, const char *PGOSampleUsePath,
     bool DebugInfoForProfiling, void *LlvmSelfProfiler,
     LLVMRustSelfProfileBeforePassCallback BeforePassCallback,
@@ -1010,6 +1016,18 @@ extern "C" LLVMRustResult LLVMRustOptimize(
     MPM.addPass(NameAnonGlobalPass());
   }
 
+  // now load "-enzyme" pass:
+#ifdef ENZYME
+  if (RunEnzyme) {
+    registerEnzyme(PB);
+    if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
+      std::string ErrMsg = toString(std::move(Err));
+      LLVMRustSetLastError(ErrMsg.c_str());
+      return LLVMRustResult::Failure;
+    }
+  }
+#endif
+
   // Upgrade all calls to old intrinsics first.
   for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)
     UpgradeCallsToIntrinsic(&*I++); // must be post-increment, as we remove