about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--library/std/src/sys/windows/c.rs23
-rw-r--r--library/std/src/sys/windows/compat.rs246
2 files changed, 205 insertions, 64 deletions
diff --git a/library/std/src/sys/windows/c.rs b/library/std/src/sys/windows/c.rs
index c7a42ef9a93..e9c20850420 100644
--- a/library/std/src/sys/windows/c.rs
+++ b/library/std/src/sys/windows/c.rs
@@ -4,6 +4,7 @@
 #![cfg_attr(test, allow(dead_code))]
 #![unstable(issue = "none", feature = "windows_c")]
 
+use crate::ffi::CStr;
 use crate::mem;
 use crate::os::raw::{c_char, c_int, c_long, c_longlong, c_uint, c_ulong, c_ushort};
 use crate::os::windows::io::{BorrowedHandle, HandleOrInvalid, HandleOrNull};
@@ -1219,8 +1220,8 @@ extern "system" {
 
 // Functions that aren't available on every version of Windows that we support,
 // but we still use them and just provide some form of a fallback implementation.
-compat_fn! {
-    "kernel32":
+compat_fn_with_fallback! {
+    pub static KERNEL32: &CStr = ansi_str!("kernel32");
 
     // >= Win10 1607
     // https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-setthreaddescription
@@ -1243,8 +1244,8 @@ compat_fn! {
     }
 }
 
-compat_fn! {
-    "api-ms-win-core-synch-l1-2-0":
+compat_fn_optional! {
+    pub static SYNCH_API: &CStr = ansi_str!("api-ms-win-core-synch-l1-2-0");
 
     // >= Windows 8 / Server 2012
     // https://docs.microsoft.com/en-us/windows/win32/api/synchapi/nf-synchapi-waitonaddress
@@ -1253,17 +1254,13 @@ compat_fn! {
         CompareAddress: LPVOID,
         AddressSize: SIZE_T,
         dwMilliseconds: DWORD
-    ) -> BOOL {
-        panic!("WaitOnAddress not available")
-    }
-    pub fn WakeByAddressSingle(Address: LPVOID) -> () {
-        // If this api is unavailable, there cannot be anything waiting, because
-        // WaitOnAddress would've panicked. So it's fine to do nothing here.
-    }
+    ) -> BOOL;
+    pub fn WakeByAddressSingle(Address: LPVOID) -> ();
 }
 
-compat_fn! {
-    "ntdll":
+compat_fn_with_fallback! {
+    pub static NTDLL: &CStr = ansi_str!("ntdll");
+
     pub fn NtCreateFile(
         FileHandle: *mut HANDLE,
         DesiredAccess: ACCESS_MASK,
diff --git a/library/std/src/sys/windows/compat.rs b/library/std/src/sys/windows/compat.rs
index ded97bb7eaa..b4ffbdc9609 100644
--- a/library/std/src/sys/windows/compat.rs
+++ b/library/std/src/sys/windows/compat.rs
@@ -49,81 +49,225 @@
 //! * call any Rust function or CRT function that touches any static
 //!   (global) state.
 
-macro_rules! compat_fn {
-    ($module:literal: $(
+use crate::ffi::{c_void, CStr};
+use crate::ptr::NonNull;
+use crate::sys::c;
+
+/// Helper macro for creating CStrs from literals and symbol names.
+macro_rules! ansi_str {
+    (sym $ident:ident) => {{
+        #[allow(unused_unsafe)]
+        crate::sys::compat::const_cstr_from_bytes(concat!(stringify!($ident), "\0").as_bytes())
+    }};
+    ($lit:literal) => {{ crate::sys::compat::const_cstr_from_bytes(concat!($lit, "\0").as_bytes()) }};
+}
+
+/// Creates a C string wrapper from a byte slice, in a constant context.
+///
+/// This is a utility function used by the [`ansi_str`] macro.
+///
+/// # Panics
+///
+/// Panics if the slice is not null terminated or contains nulls, except as the last item
+pub(crate) const fn const_cstr_from_bytes(bytes: &'static [u8]) -> &'static CStr {
+    if !matches!(bytes.last(), Some(&0)) {
+        panic!("A CStr must be null terminated");
+    }
+    let mut i = 0;
+    // At this point `len()` is at least 1.
+    while i < bytes.len() - 1 {
+        if bytes[i] == 0 {
+            panic!("A CStr must not have interior nulls")
+        }
+        i += 1;
+    }
+    // SAFETY: The safety is ensured by the above checks.
+    unsafe { crate::ffi::CStr::from_bytes_with_nul_unchecked(bytes) }
+}
+
+#[used]
+#[link_section = ".CRT$XCU"]
+static INIT_TABLE_ENTRY: unsafe extern "C" fn() = init;
+
+/// This is where the magic preloading of symbols happens.
+///
+/// Note that any functions included here will be unconditionally included in
+/// the final binary, regardless of whether or not they're actually used.
+///
+/// Therefore, this is limited to `compat_fn_optional` functions which must be
+/// preloaded and any functions which may be more time sensitive, even for the first call.
+unsafe extern "C" fn init() {
+    // There is no locking here. This code is executed before main() is entered, and
+    // is guaranteed to be single-threaded.
+    //
+    // DO NOT do anything interesting or complicated in this function! DO NOT call
+    // any Rust functions or CRT functions if those functions touch any global state,
+    // because this function runs during global initialization. For example, DO NOT
+    // do any dynamic allocation, don't call LoadLibrary, etc.
+
+    if let Some(synch) = Module::new(c::SYNCH_API) {
+        // These are optional and so we must manually attempt to load them
+        // before they can be used.
+        c::WaitOnAddress::preload(synch);
+        c::WakeByAddressSingle::preload(synch);
+    }
+
+    if let Some(kernel32) = Module::new(c::KERNEL32) {
+        // Preloading this means getting a precise time will be as fast as possible.
+        c::GetSystemTimePreciseAsFileTime::preload(kernel32);
+    }
+}
+
+/// Represents a loaded module.
+///
+/// Note that the modules std depends on must not be unloaded.
+/// Therefore a `Module` is always valid for the lifetime of std.
+#[derive(Copy, Clone)]
+pub(in crate::sys) struct Module(NonNull<c_void>);
+impl Module {
+    /// Try to get a handle to a loaded module.
+    ///
+    /// # SAFETY
+    ///
+    /// This should only be use for modules that exist for the lifetime of std
+    /// (e.g. kernel32 and ntdll).
+    pub unsafe fn new(name: &CStr) -> Option<Self> {
+        // SAFETY: A CStr is always null terminated.
+        let module = c::GetModuleHandleA(name.as_ptr());
+        NonNull::new(module).map(Self)
+    }
+
+    // Try to get the address of a function.
+    pub fn proc_address(self, name: &CStr) -> Option<NonNull<c_void>> {
+        // SAFETY:
+        // `self.0` will always be a valid module.
+        // A CStr is always null terminated.
+        let proc = unsafe { c::GetProcAddress(self.0.as_ptr(), name.as_ptr()) };
+        NonNull::new(proc)
+    }
+}
+
+/// Load a function or use a fallback implementation if that fails.
+macro_rules! compat_fn_with_fallback {
+    (pub static $module:ident: &CStr = $name:expr; $(
         $(#[$meta:meta])*
         pub fn $symbol:ident($($argname:ident: $argtype:ty),*) -> $rettype:ty $fallback_body:block
-    )*) => ($(
+    )*) => (
+        pub static $module: &CStr = $name;
+    $(
         $(#[$meta])*
         pub mod $symbol {
             #[allow(unused_imports)]
             use super::*;
             use crate::mem;
+            use crate::ffi::CStr;
+            use crate::sync::atomic::{AtomicPtr, Ordering};
+            use crate::sys::compat::Module;
 
             type F = unsafe extern "system" fn($($argtype),*) -> $rettype;
 
-            /// Points to the DLL import, or the fallback function.
-            ///
-            /// This static can be an ordinary, unsynchronized, mutable static because
-            /// we guarantee that all of the writes finish during CRT initialization,
-            /// and all of the reads occur after CRT initialization.
-            static mut PTR: Option<F> = None;
-
-            /// This symbol is what allows the CRT to find the `init` function and call it.
-            /// It is marked `#[used]` because otherwise Rust would assume that it was not
-            /// used, and would remove it.
-            #[used]
-            #[link_section = ".CRT$XCU"]
-            static INIT_TABLE_ENTRY: unsafe extern "C" fn() = init;
-
-            unsafe extern "C" fn init() {
-                PTR = get_f();
+            /// `PTR` contains a function pointer to one of three functions.
+            /// It starts with the `load` function.
+            /// When that is called it attempts to load the requested symbol.
+            /// If it succeeds, `PTR` is set to the address of that symbol.
+            /// If it fails, then `PTR` is set to `fallback`.
+            static PTR: AtomicPtr<c_void> = AtomicPtr::new(load as *mut _);
+
+            unsafe extern "system" fn load($($argname: $argtype),*) -> $rettype {
+                let func = load_from_module(Module::new($module));
+                func($($argname),*)
             }
 
-            unsafe extern "C" fn get_f() -> Option<F> {
-                // There is no locking here. This code is executed before main() is entered, and
-                // is guaranteed to be single-threaded.
-                //
-                // DO NOT do anything interesting or complicated in this function! DO NOT call
-                // any Rust functions or CRT functions, if those functions touch any global state,
-                // because this function runs during global initialization. For example, DO NOT
-                // do any dynamic allocation, don't call LoadLibrary, etc.
-                let module_name: *const u8 = concat!($module, "\0").as_ptr();
-                let symbol_name: *const u8 = concat!(stringify!($symbol), "\0").as_ptr();
-                let module_handle = $crate::sys::c::GetModuleHandleA(module_name as *const i8);
-                if !module_handle.is_null() {
-                    let ptr = $crate::sys::c::GetProcAddress(module_handle, symbol_name as *const i8);
-                    if !ptr.is_null() {
-                        // Transmute to the right function pointer type.
-                        return Some(mem::transmute(ptr));
+            fn load_from_module(module: Option<Module>) -> F {
+                unsafe {
+                    static symbol_name: &CStr = ansi_str!(sym $symbol);
+                    if let Some(f) = module.and_then(|m| m.proc_address(symbol_name)) {
+                        PTR.store(f.as_ptr(), Ordering::Relaxed);
+                        mem::transmute(f)
+                    } else {
+                        PTR.store(fallback as *mut _, Ordering::Relaxed);
+                        fallback
                     }
                 }
-                return None;
             }
 
-            #[allow(dead_code)]
+            #[allow(unused_variables)]
+            unsafe extern "system" fn fallback($($argname: $argtype),*) -> $rettype {
+                $fallback_body
+            }
+
+            #[allow(unused)]
+            pub(in crate::sys) fn preload(module: Module) {
+                load_from_module(Some(module));
+            }
+
+            #[inline(always)]
+            pub unsafe fn call($($argname: $argtype),*) -> $rettype {
+                let func: F = mem::transmute(PTR.load(Ordering::Relaxed));
+                func($($argname),*)
+            }
+        }
+        $(#[$meta])*
+        pub use $symbol::call as $symbol;
+    )*)
+}
+
+/// A function that either exists or doesn't.
+///
+/// NOTE: Optional functions must be preloaded in the `init` function above, or they will always be None.
+macro_rules! compat_fn_optional {
+    (pub static $module:ident: &CStr = $name:expr; $(
+        $(#[$meta:meta])*
+        pub fn $symbol:ident($($argname:ident: $argtype:ty),*) -> $rettype:ty;
+    )*) => (
+        pub static $module: &CStr = $name;
+    $(
+        $(#[$meta])*
+        pub mod $symbol {
+            #[allow(unused_imports)]
+            use super::*;
+            use crate::mem;
+            use crate::sync::atomic::{AtomicPtr, Ordering};
+            use crate::sys::compat::Module;
+            use crate::ptr::{self, NonNull};
+
+            type F = unsafe extern "system" fn($($argtype),*) -> $rettype;
+
+            /// `PTR` will either be `null()` or set to the loaded function.
+            static PTR: AtomicPtr<c_void> = AtomicPtr::new(ptr::null_mut());
+
+            /// Only allow access to the function if it has loaded successfully.
             #[inline(always)]
+            #[cfg(not(miri))]
             pub fn option() -> Option<F> {
                 unsafe {
-                    if cfg!(miri) {
-                        // Miri does not run `init`, so we just call `get_f` each time.
-                        get_f()
-                    } else {
-                        PTR
-                    }
+                    NonNull::new(PTR.load(Ordering::Relaxed)).map(|f| mem::transmute(f))
                 }
             }
 
-            #[allow(dead_code)]
-            pub unsafe fn call($($argname: $argtype),*) -> $rettype {
-                if let Some(ptr) = option() {
-                    return ptr($($argname),*);
+            // Miri does not understand the way we do preloading
+            // therefore load the function here instead.
+            #[cfg(miri)]
+            pub fn option() -> Option<F> {
+                let mut func = NonNull::new(PTR.load(Ordering::Relaxed));
+                if func.is_none() {
+                    Module::new($module).map(preload);
+                    func = NonNull::new(PTR.load(Ordering::Relaxed));
+                }
+                unsafe {
+                    func.map(|f| mem::transmute(f))
                 }
-                $fallback_body
             }
-        }
 
-        $(#[$meta])*
-        pub use $symbol::call as $symbol;
+            #[allow(unused)]
+            pub(in crate::sys) fn preload(module: Module) {
+                unsafe {
+                    let symbol_name = ansi_str!(sym $symbol);
+                    if let Some(f) = module.proc_address(symbol_name) {
+                        PTR.store(f.as_ptr(), Ordering::Relaxed);
+                    }
+                }
+            }
+        }
     )*)
 }