about summary refs log tree commit diff
diff options
context:
space:
mode:
authorLukas Wirth <lukastw97@gmail.com>2024-12-12 13:20:08 +0100
committerLukas Wirth <lukastw97@gmail.com>2024-12-12 13:23:25 +0100
commit1428cf6032cb70bb8c93d6aafdd05245c48c9be6 (patch)
treed812173b0f1e5ba71725b84f7e08e2b8fe767623
parentaef05d468e67f7a103738694e643b2e57c499b10 (diff)
downloadrust-1428cf6032cb70bb8c93d6aafdd05245c48c9be6.tar.gz
rust-1428cf6032cb70bb8c93d6aafdd05245c48c9be6.zip
Only parse the object file once
-rw-r--r--src/tools/rust-analyzer/crates/proc-macro-srv/src/dylib.rs16
-rw-r--r--src/tools/rust-analyzer/crates/proc-macro-srv/src/dylib/version.rs24
2 files changed, 19 insertions, 21 deletions
diff --git a/src/tools/rust-analyzer/crates/proc-macro-srv/src/dylib.rs b/src/tools/rust-analyzer/crates/proc-macro-srv/src/dylib.rs
index 828d49e6a21..977ca8bafd8 100644
--- a/src/tools/rust-analyzer/crates/proc-macro-srv/src/dylib.rs
+++ b/src/tools/rust-analyzer/crates/proc-macro-srv/src/dylib.rs
@@ -6,7 +6,6 @@ use proc_macro::bridge;
 use std::{fmt, fs, io, time::SystemTime};
 
 use libloading::Library;
-use memmap2::Mmap;
 use object::Object;
 use paths::{Utf8Path, Utf8PathBuf};
 use proc_macro_api::ProcMacroKind;
@@ -23,8 +22,8 @@ fn is_derive_registrar_symbol(symbol: &str) -> bool {
     symbol.contains(NEW_REGISTRAR_SYMBOL)
 }
 
-fn find_registrar_symbol(buffer: &[u8]) -> object::Result<Option<String>> {
-    Ok(object::File::parse(buffer)?
+fn find_registrar_symbol(obj: &object::File<'_>) -> object::Result<Option<String>> {
+    Ok(obj
         .exports()?
         .into_iter()
         .map(|export| export.name())
@@ -109,15 +108,16 @@ struct ProcMacroLibraryLibloading {
 
 impl ProcMacroLibraryLibloading {
     fn open(path: &Utf8Path) -> Result<Self, LoadProcMacroDylibError> {
-        let buffer = unsafe { Mmap::map(&fs::File::open(path)?)? };
+        let file = fs::File::open(path)?;
+        let file = unsafe { memmap2::Mmap::map(&file) }?;
+        let obj = object::File::parse(&*file)
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+        let version_info = version::read_dylib_info(&obj)?;
         let symbol_name =
-            find_registrar_symbol(&buffer).map_err(invalid_data_err)?.ok_or_else(|| {
+            find_registrar_symbol(&obj).map_err(invalid_data_err)?.ok_or_else(|| {
                 invalid_data_err(format!("Cannot find registrar symbol in file {path}"))
             })?;
 
-        let version_info = version::read_dylib_info(&buffer)?;
-        drop(buffer);
-
         let lib = load_library(path).map_err(invalid_data_err)?;
         let proc_macros = crate::proc_macros::ProcMacros::from_lib(
             &lib,
diff --git a/src/tools/rust-analyzer/crates/proc-macro-srv/src/dylib/version.rs b/src/tools/rust-analyzer/crates/proc-macro-srv/src/dylib/version.rs
index c1804e4fef7..4e28aaced9b 100644
--- a/src/tools/rust-analyzer/crates/proc-macro-srv/src/dylib/version.rs
+++ b/src/tools/rust-analyzer/crates/proc-macro-srv/src/dylib/version.rs
@@ -2,7 +2,7 @@
 
 use std::io::{self, Read};
 
-use object::read::{File as BinaryFile, Object, ObjectSection};
+use object::read::{Object, ObjectSection};
 
 #[derive(Debug)]
 #[allow(dead_code)]
@@ -16,14 +16,14 @@ pub struct RustCInfo {
 }
 
 /// Read rustc dylib information
-pub fn read_dylib_info(buffer: &[u8]) -> io::Result<RustCInfo> {
+pub fn read_dylib_info(obj: &object::File<'_>) -> io::Result<RustCInfo> {
     macro_rules! err {
         ($e:literal) => {
             io::Error::new(io::ErrorKind::InvalidData, $e)
         };
     }
 
-    let ver_str = read_version(buffer)?;
+    let ver_str = read_version(obj)?;
     let mut items = ver_str.split_whitespace();
     let tag = items.next().ok_or_else(|| err!("version format error"))?;
     if tag != "rustc" {
@@ -70,10 +70,8 @@ pub fn read_dylib_info(buffer: &[u8]) -> io::Result<RustCInfo> {
 
 /// This is used inside read_version() to locate the ".rustc" section
 /// from a proc macro crate's binary file.
-fn read_section<'a>(dylib_binary: &'a [u8], section_name: &str) -> io::Result<&'a [u8]> {
-    BinaryFile::parse(dylib_binary)
-        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
-        .section_by_name(section_name)
+fn read_section<'a>(obj: &object::File<'a>, section_name: &str) -> io::Result<&'a [u8]> {
+    obj.section_by_name(section_name)
         .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "section read error"))?
         .data()
         .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
@@ -101,8 +99,8 @@ fn read_section<'a>(dylib_binary: &'a [u8], section_name: &str) -> io::Result<&'
 ///
 /// Check this issue for more about the bytes layout:
 /// <https://github.com/rust-lang/rust-analyzer/issues/6174>
-pub fn read_version(buffer: &[u8]) -> io::Result<String> {
-    let dot_rustc = read_section(buffer, ".rustc")?;
+pub fn read_version(obj: &object::File<'_>) -> io::Result<String> {
+    let dot_rustc = read_section(obj, ".rustc")?;
 
     // check if magic is valid
     if &dot_rustc[0..4] != b"rust" {
@@ -151,10 +149,10 @@ pub fn read_version(buffer: &[u8]) -> io::Result<String> {
 
 #[test]
 fn test_version_check() {
-    let info = read_dylib_info(&unsafe {
-        memmap2::Mmap::map(&std::fs::File::open(crate::proc_macro_test_dylib_path()).unwrap())
-            .unwrap()
-    })
+    let info = read_dylib_info(
+        &object::File::parse(&*std::fs::read(crate::proc_macro_test_dylib_path()).unwrap())
+            .unwrap(),
+    )
     .unwrap();
 
     assert_eq!(