summary refs log tree commit diff
path: root/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
diff options
context:
space:
mode:
authorManuel Drehwald <git@manuel.drehwald.info>2025-01-01 21:42:45 +0100
committerManuel Drehwald <git@manuel.drehwald.info>2025-01-01 21:42:45 +0100
commitd753cbf7793f20229ed7a151d060456a9e1396e9 (patch)
tree2a885dd222de7440293c7159de4d4c5f9e82bd45 /compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
parent372442fe5ff1a2d06f4119f2b2e7d1e42388a0d3 (diff)
downloadrust-d753cbf7793f20229ed7a151d060456a9e1396e9.tar.gz
rust-d753cbf7793f20229ed7a151d060456a9e1396e9.zip
upstream rustc_codegen_llvm changes for enzyme/autodiff
Diffstat (limited to 'compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp')
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp77
1 files changed, 77 insertions, 0 deletions
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index 36441a95adb..87b1c944aa8 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -1,5 +1,6 @@
 #include "LLVMWrapper.h"
 
+#include "llvm-c/Analysis.h"
 #include "llvm-c/Core.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
@@ -165,6 +166,30 @@ extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,
   return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen)));
 }
 
+enum class LLVMRustVerifierFailureAction {
+  AbortProcessAction = 0,
+  PrintMessageAction = 1,
+  ReturnStatusAction = 2,
+};
+
+static LLVMVerifierFailureAction
+fromRust(LLVMRustVerifierFailureAction Action) {
+  switch (Action) {
+  case LLVMRustVerifierFailureAction::AbortProcessAction:
+    return LLVMAbortProcessAction;
+  case LLVMRustVerifierFailureAction::PrintMessageAction:
+    return LLVMPrintMessageAction;
+  case LLVMRustVerifierFailureAction::ReturnStatusAction:
+    return LLVMReturnStatusAction;
+  }
+  report_fatal_error("Invalid LLVMVerifierFailureAction value!");
+}
+
+extern "C" LLVMBool
+LLVMRustVerifyFunction(LLVMValueRef Fn, LLVMRustVerifierFailureAction Action) {
+  return LLVMVerifyFunction(Fn, fromRust(Action));
+}
+
 enum class LLVMRustTailCallKind {
   None,
   Tail,
@@ -388,6 +413,17 @@ extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr,
   AddAttributes(Call, Index, Attrs, AttrsLen);
 }
 
+extern "C" LLVMValueRef LLVMRustGetTerminator(LLVMBasicBlockRef BB) {
+  Instruction *ret = unwrap(BB)->getTerminator();
+  return wrap(ret);
+}
+
+extern "C" void LLVMRustEraseInstFromParent(LLVMValueRef Instr) {
+  if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
+    I->eraseFromParent();
+  }
+}
+
 extern "C" LLVMAttributeRef
 LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttributeKind RustAttr) {
   return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr)));
@@ -954,6 +990,47 @@ extern "C" void LLVMRustAddModuleFlagString(
       MDString::get(unwrap(M)->getContext(), StringRef(Value, ValueLen)));
 }
 
+extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) {
+  auto Point = unwrap(BB)->rbegin();
+  if (Point != unwrap(BB)->rend())
+    return wrap(&*Point);
+  return nullptr;
+}
+
+extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) {
+  auto &BB = *unwrap(bb);
+  auto &Inst = *unwrap<Instruction>(I);
+  auto It = BB.begin();
+  while (&*It != &Inst)
+    ++It;
+  // Make sure we found the Instruction.
+  assert(It != BB.end());
+  // We don't want to erase the instruction itself.
+  It--;
+  // Delete in rev order to ensure no dangling references.
+  while (It != BB.begin()) {
+    auto Prev = std::prev(It);
+    It->eraseFromParent();
+    It = Prev;
+  }
+  It->eraseFromParent();
+}
+
+extern "C" bool LLVMRustHasMetadata(LLVMValueRef inst, unsigned kindID) {
+  if (auto *I = dyn_cast<Instruction>(unwrap<Value>(inst))) {
+    return I->hasMetadata(kindID);
+  }
+  return false;
+}
+
+extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) {
+  if (auto *I = dyn_cast<Instruction>(unwrap<Value>(x))) {
+    auto *MD = I->getDebugLoc().getAsMDNode();
+    return wrap(MD);
+  }
+  return nullptr;
+}
+
 extern "C" void LLVMRustGlobalAddMetadata(LLVMValueRef Global, unsigned Kind,
                                           LLVMMetadataRef MD) {
   unwrap<GlobalObject>(Global)->addMetadata(Kind, *unwrap<MDNode>(MD));