about summary refs log tree commit diff
path: root/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp')
-rw-r--r--compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp63
1 files changed, 63 insertions, 0 deletions
diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
index 90aa9188c83..588d867bbbf 100644
--- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
+++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
@@ -1591,12 +1591,49 @@ extern "C" LLVMValueRef LLVMRustBuildMemSet(LLVMBuilderRef B, LLVMValueRef Dst,
                                       MaybeAlign(DstAlign), IsVolatile));
 }
 
+extern "C" void LLVMRustPositionBuilderPastAllocas(LLVMBuilderRef B,
+                                                   LLVMValueRef Fn) {
+  Function *F = unwrap<Function>(Fn);
+  unwrap(B)->SetInsertPointPastAllocas(F);
+}
 extern "C" void LLVMRustPositionBuilderAtStart(LLVMBuilderRef B,
                                                LLVMBasicBlockRef BB) {
   auto Point = unwrap(BB)->getFirstInsertionPt();
   unwrap(B)->SetInsertPoint(unwrap(BB), Point);
 }
 
+extern "C" void LLVMRustPositionBefore(LLVMBuilderRef B, LLVMValueRef Instr) {
+  if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
+    unwrap(B)->SetInsertPoint(I);
+  }
+}
+
+extern "C" void LLVMRustPositionAfter(LLVMBuilderRef B, LLVMValueRef Instr) {
+  if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
+    auto J = I->getNextNode();
+    unwrap(B)->SetInsertPoint(J);
+  }
+}
+
+extern "C" LLVMValueRef
+LLVMRustGetFunctionCall(LLVMValueRef Fn, const char *Name, size_t NameLen) {
+  auto targetName = StringRef(Name, NameLen);
+  Function *F = unwrap<Function>(Fn);
+  for (auto &BB : *F) {
+    for (auto &I : BB) {
+      if (auto *callInst = llvm::dyn_cast<llvm::CallBase>(&I)) {
+        const llvm::Function *calledFunc = callInst->getCalledFunction();
+        if (calledFunc && calledFunc->getName() == targetName) {
+          // Found a call to the target function
+          return wrap(callInst);
+        }
+      }
+    }
+  }
+
+  return nullptr;
+}
+
 extern "C" bool LLVMRustConstIntGetZExtValue(LLVMValueRef CV, uint64_t *value) {
   auto C = unwrap<llvm::ConstantInt>(CV);
   if (C->getBitWidth() > 64)
@@ -1949,3 +1986,29 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
   MD.NoHWAddress = true;
   GV.setSanitizerMetadata(MD);
 }
+
+enum class LLVMRustTailCallKind {
+  None = 0,
+  Tail = 1,
+  MustTail = 2,
+  NoTail = 3
+};
+
+extern "C" void LLVMRustSetTailCallKind(LLVMValueRef Call,
+                                        LLVMRustTailCallKind Kind) {
+  CallInst *CI = unwrap<CallInst>(Call);
+  switch (Kind) {
+  case LLVMRustTailCallKind::None:
+    CI->setTailCallKind(CallInst::TCK_None);
+    break;
+  case LLVMRustTailCallKind::Tail:
+    CI->setTailCallKind(CallInst::TCK_Tail);
+    break;
+  case LLVMRustTailCallKind::MustTail:
+    CI->setTailCallKind(CallInst::TCK_MustTail);
+    break;
+  case LLVMRustTailCallKind::NoTail:
+    CI->setTailCallKind(CallInst::TCK_NoTail);
+    break;
+  }
+}