about summary refs log tree commit diff
path: root/src/rustllvm/PassWrapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/rustllvm/PassWrapper.cpp')
-rw-r--r--src/rustllvm/PassWrapper.cpp128
1 files changed, 126 insertions, 2 deletions
diff --git a/src/rustllvm/PassWrapper.cpp b/src/rustllvm/PassWrapper.cpp
index fdbe4e5f7ad..7fb1eafb30d 100644
--- a/src/rustllvm/PassWrapper.cpp
+++ b/src/rustllvm/PassWrapper.cpp
@@ -10,11 +10,14 @@
 
 #include <stdio.h>
 
+#include <vector>
+
 #include "rustllvm.h"
 
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/AutoUpgrade.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
 #include "llvm/Support/CBindingWrapping.h"
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Support/Host.h"
@@ -503,8 +506,129 @@ LLVMRustWriteOutputFile(LLVMTargetMachineRef Target, LLVMPassManagerRef PMR,
   return LLVMRustResult::Success;
 }
 
+
+// Callback to demangle function name
+// Parameters:
+// * name to be demangled
+// * name len
+// * output buffer
+// * output buffer len
+// Returns len of demangled string, or 0 if demangle failed.
+typedef size_t (*DemangleFn)(const char*, size_t, char*, size_t);
+
+
+namespace {
+
+class RustAssemblyAnnotationWriter : public AssemblyAnnotationWriter {
+  DemangleFn Demangle;
+  std::vector<char> Buf;
+
+public:
+  RustAssemblyAnnotationWriter(DemangleFn Demangle) : Demangle(Demangle) {}
+
+  // Return empty string if demangle failed
+  // or if name does not need to be demangled
+  StringRef CallDemangle(StringRef name) {
+    if (!Demangle) {
+      return StringRef();
+    }
+
+    if (Buf.size() < name.size() * 2) {
+      // Semangled name usually shorter than mangled,
+      // but allocate twice as much memory just in case
+      Buf.resize(name.size() * 2);
+    }
+
+    auto R = Demangle(name.data(), name.size(), Buf.data(), Buf.size());
+    if (!R) {
+      // Demangle failed.
+      return StringRef();
+    }
+
+    auto Demangled = StringRef(Buf.data(), R);
+    if (Demangled == name) {
+      // Do not print anything if demangled name is equal to mangled.
+      return StringRef();
+    }
+
+    return Demangled;
+  }
+
+  void emitFunctionAnnot(const Function *F,
+                         formatted_raw_ostream &OS) override {
+    StringRef Demangled = CallDemangle(F->getName());
+    if (Demangled.empty()) {
+        return;
+    }
+
+    OS << "; " << Demangled << "\n";
+  }
+
+  void emitInstructionAnnot(const Instruction *I,
+                            formatted_raw_ostream &OS) override {
+    const char *Name;
+    const Value *Value;
+    if (const CallInst *CI = dyn_cast<CallInst>(I)) {
+      Name = "call";
+      Value = CI->getCalledValue();
+    } else if (const InvokeInst* II = dyn_cast<InvokeInst>(I)) {
+      Name = "invoke";
+      Value = II->getCalledValue();
+    } else {
+      // Could demangle more operations, e. g.
+      // `store %place, @function`.
+      return;
+    }
+
+    if (!Value->hasName()) {
+      return;
+    }
+
+    StringRef Demangled = CallDemangle(Value->getName());
+    if (Demangled.empty()) {
+      return;
+    }
+
+    OS << "; " << Name << " " << Demangled << "\n";
+  }
+};
+
+class RustPrintModulePass : public ModulePass {
+  raw_ostream* OS;
+  DemangleFn Demangle;
+public:
+  static char ID;
+  RustPrintModulePass() : ModulePass(ID), OS(nullptr), Demangle(nullptr) {}
+  RustPrintModulePass(raw_ostream &OS, DemangleFn Demangle)
+      : ModulePass(ID), OS(&OS), Demangle(Demangle) {}
+
+  bool runOnModule(Module &M) override {
+    RustAssemblyAnnotationWriter AW(Demangle);
+
+    M.print(*OS, &AW, false);
+
+    return false;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+  }
+
+  static StringRef name() { return "RustPrintModulePass"; }
+};
+
+} // namespace
+
+namespace llvm {
+  void initializeRustPrintModulePassPass(PassRegistry&);
+}
+
+char RustPrintModulePass::ID = 0;
+INITIALIZE_PASS(RustPrintModulePass, "print-rust-module",
+                "Print rust module to stderr", false, false)
+
 extern "C" void LLVMRustPrintModule(LLVMPassManagerRef PMR, LLVMModuleRef M,
-                                    const char *Path) {
+                                    const char *Path, DemangleFn Demangle) {
   llvm::legacy::PassManager *PM = unwrap<llvm::legacy::PassManager>(PMR);
   std::string ErrorInfo;
 
@@ -515,7 +639,7 @@ extern "C" void LLVMRustPrintModule(LLVMPassManagerRef PMR, LLVMModuleRef M,
 
   formatted_raw_ostream FOS(OS);
 
-  PM->add(createPrintModulePass(FOS));
+  PM->add(new RustPrintModulePass(FOS, Demangle));
 
   PM->run(*unwrap(M));
 }