about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/miri/src/helpers.rs1
-rw-r--r--src/tools/miri/src/shims/env.rs22
-rw-r--r--src/tools/miri/src/shims/extern_static.rs2
-rw-r--r--src/tools/miri/src/shims/unix/env.rs49
-rw-r--r--src/tools/miri/src/shims/unix/freebsd/foreign_items.rs5
-rw-r--r--src/tools/miri/src/shims/unix/linux/foreign_items.rs4
-rw-r--r--src/tools/miri/src/shims/unix/linux_like/syscall.rs6
-rw-r--r--src/tools/miri/src/shims/unix/macos/foreign_items.rs5
-rw-r--r--src/tools/miri/src/shims/windows/foreign_items.rs17
-rw-r--r--src/tools/miri/src/shims/windows/handle.rs4
-rw-r--r--src/tools/miri/tests/fail-dep/concurrency/windows_thread_invalid.rs9
-rw-r--r--src/tools/miri/tests/fail-dep/concurrency/windows_thread_invalid.stderr13
-rw-r--r--src/tools/miri/tests/pass-dep/shims/gettid.rs183
-rw-r--r--src/tools/miri/tests/pass/alloc-access-tracking.rs4
14 files changed, 312 insertions, 12 deletions
diff --git a/src/tools/miri/src/helpers.rs b/src/tools/miri/src/helpers.rs
index fb34600fa37..4821edb0942 100644
--- a/src/tools/miri/src/helpers.rs
+++ b/src/tools/miri/src/helpers.rs
@@ -1337,7 +1337,6 @@ where
 
 /// Check that the number of varargs is at least the minimum what we expect.
 /// Fixed args should not be included.
-/// Use `check_vararg_fixed_arg_count` to extract the varargs slice from full function arguments.
 pub fn check_min_vararg_count<'a, 'tcx, const N: usize>(
     name: &'a str,
     args: &'a [OpTy<'tcx>],
diff --git a/src/tools/miri/src/shims/env.rs b/src/tools/miri/src/shims/env.rs
index e99a8fd6e8c..9dfb1ebb90a 100644
--- a/src/tools/miri/src/shims/env.rs
+++ b/src/tools/miri/src/shims/env.rs
@@ -110,8 +110,30 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
         }
     }
 
+    /// Get the process identifier.
     fn get_pid(&self) -> u32 {
         let this = self.eval_context_ref();
         if this.machine.communicate() { std::process::id() } else { 1000 }
     }
+
+    /// Get an "OS" thread ID for the current thread.
+    fn get_current_tid(&self) -> u32 {
+        let this = self.eval_context_ref();
+        self.get_tid(this.machine.threads.active_thread())
+    }
+
+    /// Get an "OS" thread ID for any thread.
+    fn get_tid(&self, thread: ThreadId) -> u32 {
+        let this = self.eval_context_ref();
+        let index = thread.to_u32();
+        let target_os = &this.tcx.sess.target.os;
+        if target_os == "linux" || target_os == "netbsd" {
+            // On Linux, the main thread has PID == TID so we uphold this. NetBSD also appears
+            // to exhibit the same behavior, though I can't find a citation.
+            this.get_pid().strict_add(index)
+        } else {
+            // Other platforms do not display any relationship between PID and TID.
+            index
+        }
+    }
 }
diff --git a/src/tools/miri/src/shims/extern_static.rs b/src/tools/miri/src/shims/extern_static.rs
index a2ea3dbd88b..5cf9265c704 100644
--- a/src/tools/miri/src/shims/extern_static.rs
+++ b/src/tools/miri/src/shims/extern_static.rs
@@ -66,7 +66,7 @@ impl<'tcx> MiriMachine<'tcx> {
                     ecx,
                     &["__cxa_thread_atexit_impl", "__clock_gettime64"],
                 )?;
-                Self::weak_symbol_extern_statics(ecx, &["getrandom", "statx"])?;
+                Self::weak_symbol_extern_statics(ecx, &["getrandom", "gettid", "statx"])?;
             }
             "freebsd" => {
                 Self::null_ptr_extern_statics(ecx, &["__cxa_thread_atexit_impl"])?;
diff --git a/src/tools/miri/src/shims/unix/env.rs b/src/tools/miri/src/shims/unix/env.rs
index 604fb0974d2..a0e5d3f0127 100644
--- a/src/tools/miri/src/shims/unix/env.rs
+++ b/src/tools/miri/src/shims/unix/env.rs
@@ -274,15 +274,52 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
         interp_ok(Scalar::from_u32(this.get_pid()))
     }
 
-    fn linux_gettid(&mut self) -> InterpResult<'tcx, Scalar> {
+    /// The `gettid`-like function for Unix platforms that take no parameters and return a 32-bit
+    /// integer. It is not always named "gettid".
+    fn unix_gettid(&mut self, link_name: &str) -> InterpResult<'tcx, Scalar> {
         let this = self.eval_context_ref();
-        this.assert_target_os("linux", "gettid");
+        this.assert_target_os_is_unix(link_name);
 
-        let index = this.machine.threads.active_thread().to_u32();
+        // For most platforms the return type is an `i32`, but some are unsigned. The TID
+        // will always be positive so we don't need to differentiate.
+        interp_ok(Scalar::from_u32(this.get_current_tid()))
+    }
+
+    /// The Apple-specific `int pthread_threadid_np(pthread_t thread, uint64_t *thread_id)`, which
+    /// allows querying the ID for arbitrary threads, identified by their pthread_t.
+    ///
+    /// API documentation: <https://www.manpagez.com/man/3/pthread_threadid_np/>.
+    fn apple_pthread_threadip_np(
+        &mut self,
+        thread_op: &OpTy<'tcx>,
+        tid_op: &OpTy<'tcx>,
+    ) -> InterpResult<'tcx, Scalar> {
+        let this = self.eval_context_mut();
+        this.assert_target_os("macos", "pthread_threadip_np");
+
+        let tid_dest = this.read_pointer(tid_op)?;
+        if this.ptr_is_null(tid_dest)? {
+            // If NULL is passed, an error is immediately returned
+            return interp_ok(this.eval_libc("EINVAL"));
+        }
+
+        let thread = this.read_scalar(thread_op)?.to_int(this.libc_ty_layout("pthread_t").size)?;
+        let thread = if thread == 0 {
+            // Null thread ID indicates that we are querying the active thread.
+            this.machine.threads.active_thread()
+        } else {
+            // Our pthread_t is just the raw ThreadId.
+            let Ok(thread) = this.thread_id_try_from(thread) else {
+                return interp_ok(this.eval_libc("ESRCH"));
+            };
+            thread
+        };
 
-        // Compute a TID for this thread, ensuring that the main thread has PID == TID.
-        let tid = this.get_pid().strict_add(index);
+        let tid = this.get_tid(thread);
+        let tid_dest = this.deref_pointer_as(tid_op, this.machine.layouts.u64)?;
+        this.write_int(tid, &tid_dest)?;
 
-        interp_ok(Scalar::from_u32(tid))
+        // Possible errors have been handled, return success.
+        interp_ok(Scalar::from_u32(0))
     }
 }
diff --git a/src/tools/miri/src/shims/unix/freebsd/foreign_items.rs b/src/tools/miri/src/shims/unix/freebsd/foreign_items.rs
index 42502d5bf09..33564a2f84c 100644
--- a/src/tools/miri/src/shims/unix/freebsd/foreign_items.rs
+++ b/src/tools/miri/src/shims/unix/freebsd/foreign_items.rs
@@ -56,6 +56,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 };
                 this.write_scalar(res, dest)?;
             }
+            "pthread_getthreadid_np" => {
+                let [] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
+                let result = this.unix_gettid(link_name.as_str())?;
+                this.write_scalar(result, dest)?;
+            }
 
             "cpuset_getaffinity" => {
                 // The "same" kind of api as `sched_getaffinity` but more fine grained control for FreeBSD specifically.
diff --git a/src/tools/miri/src/shims/unix/linux/foreign_items.rs b/src/tools/miri/src/shims/unix/linux/foreign_items.rs
index aeaff1cb13a..b3e99e6cc68 100644
--- a/src/tools/miri/src/shims/unix/linux/foreign_items.rs
+++ b/src/tools/miri/src/shims/unix/linux/foreign_items.rs
@@ -18,7 +18,7 @@ use crate::*;
 const TASK_COMM_LEN: u64 = 16;
 
 pub fn is_dyn_sym(name: &str) -> bool {
-    matches!(name, "statx")
+    matches!(name, "gettid" | "statx")
 }
 
 impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -117,7 +117,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
             }
             "gettid" => {
                 let [] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
-                let result = this.linux_gettid()?;
+                let result = this.unix_gettid(link_name.as_str())?;
                 this.write_scalar(result, dest)?;
             }
 
diff --git a/src/tools/miri/src/shims/unix/linux_like/syscall.rs b/src/tools/miri/src/shims/unix/linux_like/syscall.rs
index d42d6b9073e..d3534e6e1bc 100644
--- a/src/tools/miri/src/shims/unix/linux_like/syscall.rs
+++ b/src/tools/miri/src/shims/unix/linux_like/syscall.rs
@@ -4,6 +4,7 @@ use rustc_span::Symbol;
 use rustc_target::callconv::FnAbi;
 
 use crate::helpers::check_min_vararg_count;
+use crate::shims::unix::env::EvalContextExt;
 use crate::shims::unix::linux_like::eventfd::EvalContextExt as _;
 use crate::shims::unix::linux_like::sync::futex;
 use crate::*;
@@ -24,6 +25,7 @@ pub fn syscall<'tcx>(
     let sys_getrandom = ecx.eval_libc("SYS_getrandom").to_target_usize(ecx)?;
     let sys_futex = ecx.eval_libc("SYS_futex").to_target_usize(ecx)?;
     let sys_eventfd2 = ecx.eval_libc("SYS_eventfd2").to_target_usize(ecx)?;
+    let sys_gettid = ecx.eval_libc("SYS_gettid").to_target_usize(ecx)?;
 
     match ecx.read_target_usize(op)? {
         // `libc::syscall(NR_GETRANDOM, buf.as_mut_ptr(), buf.len(), GRND_NONBLOCK)`
@@ -53,6 +55,10 @@ pub fn syscall<'tcx>(
             let result = ecx.eventfd(initval, flags)?;
             ecx.write_int(result.to_i32()?, dest)?;
         }
+        num if num == sys_gettid => {
+            let result = ecx.unix_gettid("SYS_gettid")?;
+            ecx.write_int(result.to_u32()?, dest)?;
+        }
         num => {
             throw_unsup_format!("syscall: unsupported syscall number {num}");
         }
diff --git a/src/tools/miri/src/shims/unix/macos/foreign_items.rs b/src/tools/miri/src/shims/unix/macos/foreign_items.rs
index ae921a013a4..23303718091 100644
--- a/src/tools/miri/src/shims/unix/macos/foreign_items.rs
+++ b/src/tools/miri/src/shims/unix/macos/foreign_items.rs
@@ -222,6 +222,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 };
                 this.write_scalar(res, dest)?;
             }
+            "pthread_threadid_np" => {
+                let [thread, tid_ptr] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
+                let res = this.apple_pthread_threadip_np(thread, tid_ptr)?;
+                this.write_scalar(res, dest)?;
+            }
 
             // Synchronization primitives
             "os_sync_wait_on_address" => {
diff --git a/src/tools/miri/src/shims/windows/foreign_items.rs b/src/tools/miri/src/shims/windows/foreign_items.rs
index de10357f5fa..959abc0baca 100644
--- a/src/tools/miri/src/shims/windows/foreign_items.rs
+++ b/src/tools/miri/src/shims/windows/foreign_items.rs
@@ -629,6 +629,23 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 this.write_scalar(name, &name_ptr)?;
                 this.write_scalar(res, dest)?;
             }
+            "GetThreadId" => {
+                let [handle] = this.check_shim(abi, sys_conv, link_name, args)?;
+                let handle = this.read_handle(handle, "GetThreadId")?;
+                let thread = match handle {
+                    Handle::Thread(thread) => thread,
+                    Handle::Pseudo(PseudoHandle::CurrentThread) => this.active_thread(),
+                    _ => this.invalid_handle("GetThreadDescription")?,
+                };
+                let tid = this.get_tid(thread);
+                this.write_scalar(Scalar::from_u32(tid), dest)?;
+            }
+            "GetCurrentThreadId" => {
+                let [] = this.check_shim(abi, sys_conv, link_name, args)?;
+                let thread = this.active_thread();
+                let tid = this.get_tid(thread);
+                this.write_scalar(Scalar::from_u32(tid), dest)?;
+            }
 
             // Miscellaneous
             "ExitProcess" => {
diff --git a/src/tools/miri/src/shims/windows/handle.rs b/src/tools/miri/src/shims/windows/handle.rs
index 1e30bf25ed9..8a965ea316d 100644
--- a/src/tools/miri/src/shims/windows/handle.rs
+++ b/src/tools/miri/src/shims/windows/handle.rs
@@ -166,6 +166,10 @@ impl Handle {
     /// Structurally invalid handles return [`HandleError::InvalidHandle`].
     /// If the handle is structurally valid but semantically invalid, e.g. a for non-existent thread
     /// ID, returns [`HandleError::ThreadNotFound`].
+    ///
+    /// This function is deliberately private; shims should always use `read_handle`.
+    /// That enforces handle validity even when Windows does not: for now, we argue invalid
+    /// handles are always a bug and programmers likely want to know about them.
     fn try_from_scalar<'tcx>(
         handle: Scalar,
         cx: &MiriInterpCx<'tcx>,
diff --git a/src/tools/miri/tests/fail-dep/concurrency/windows_thread_invalid.rs b/src/tools/miri/tests/fail-dep/concurrency/windows_thread_invalid.rs
new file mode 100644
index 00000000000..2e0729c9b49
--- /dev/null
+++ b/src/tools/miri/tests/fail-dep/concurrency/windows_thread_invalid.rs
@@ -0,0 +1,9 @@
+//! Ensure we error if thread functions are called with invalid handles
+//@only-target: windows # testing Windows API
+
+use windows_sys::Win32::System::Threading::GetThreadId;
+
+fn main() {
+    let _tid = unsafe { GetThreadId(std::ptr::dangling_mut()) };
+    //~^ ERROR: invalid handle
+}
diff --git a/src/tools/miri/tests/fail-dep/concurrency/windows_thread_invalid.stderr b/src/tools/miri/tests/fail-dep/concurrency/windows_thread_invalid.stderr
new file mode 100644
index 00000000000..8d4b049b740
--- /dev/null
+++ b/src/tools/miri/tests/fail-dep/concurrency/windows_thread_invalid.stderr
@@ -0,0 +1,13 @@
+error: abnormal termination: invalid handle 1 passed to GetThreadId
+  --> tests/fail-dep/concurrency/windows_thread_invalid.rs:LL:CC
+   |
+LL |     let _tid = unsafe { GetThreadId(std::ptr::dangling_mut()) };
+   |                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ abnormal termination occurred here
+   |
+   = note: BACKTRACE:
+   = note: inside `main` at tests/fail-dep/concurrency/windows_thread_invalid.rs:LL:CC
+
+note: some details are omitted, run with `MIRIFLAGS=-Zmiri-backtrace=full` for a verbose backtrace
+
+error: aborting due to 1 previous error
+
diff --git a/src/tools/miri/tests/pass-dep/shims/gettid.rs b/src/tools/miri/tests/pass-dep/shims/gettid.rs
new file mode 100644
index 00000000000..b7a2fa49ef8
--- /dev/null
+++ b/src/tools/miri/tests/pass-dep/shims/gettid.rs
@@ -0,0 +1,183 @@
+//! Test for `gettid` and similar functions for retrieving an OS thread ID.
+//@ revisions: with_isolation without_isolation
+//@ [without_isolation] compile-flags: -Zmiri-disable-isolation
+
+#![feature(linkage)]
+
+fn gettid() -> u64 {
+    cfg_if::cfg_if! {
+        if #[cfg(any(target_os = "android", target_os = "linux"))] {
+            gettid_linux_like()
+        } else if #[cfg(target_os = "nto")] {
+            unsafe { libc::gettid() as u64 }
+        } else if #[cfg(target_os = "openbsd")] {
+            unsafe { libc::getthrid() as u64 }
+        } else if #[cfg(target_os = "freebsd")] {
+            unsafe { libc::pthread_getthreadid_np() as u64 }
+        } else if #[cfg(target_os = "netbsd")] {
+            unsafe { libc::_lwp_self() as u64 }
+        } else if #[cfg(any(target_os = "solaris", target_os = "illumos"))] {
+            // On Solaris and Illumos, the `pthread_t` is the OS TID.
+            unsafe { libc::pthread_self() as u64 }
+        } else if #[cfg(target_vendor = "apple")] {
+            let mut id = 0u64;
+            let status: libc::c_int = unsafe { libc::pthread_threadid_np(0, &mut id) };
+            assert_eq!(status, 0);
+            id
+        } else if #[cfg(windows)] {
+            use windows_sys::Win32::System::Threading::GetCurrentThreadId;
+            unsafe { GetCurrentThreadId() as u64 }
+        } else {
+            compile_error!("platform has no gettid")
+        }
+    }
+}
+
+/// Test the libc function, the syscall, and the extern symbol.
+#[cfg(any(target_os = "android", target_os = "linux"))]
+fn gettid_linux_like() -> u64 {
+    unsafe extern "C" {
+        #[linkage = "extern_weak"]
+        static gettid: Option<unsafe extern "C" fn() -> libc::pid_t>;
+    }
+
+    let from_libc = unsafe { libc::gettid() as u64 };
+    let from_sys = unsafe { libc::syscall(libc::SYS_gettid) as u64 };
+    let from_static = unsafe { gettid.unwrap()() as u64 };
+
+    assert_eq!(from_libc, from_sys);
+    assert_eq!(from_libc, from_static);
+
+    from_libc
+}
+
+/// Specific platforms can query the tid of arbitrary threads from a `pthread_t` / `HANDLE`
+#[cfg(any(target_vendor = "apple", windows))]
+mod queried {
+    use std::ffi::c_void;
+    use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
+    use std::{ptr, thread};
+
+    use super::*;
+
+    static SPAWNED_TID: AtomicU64 = AtomicU64::new(0);
+    static CAN_JOIN: AtomicBool = AtomicBool::new(false);
+
+    /// Save this thread's TID, give the spawning thread a chance to query it separately before
+    /// being joined.
+    fn thread_body() {
+        SPAWNED_TID.store(gettid(), Ordering::Relaxed);
+
+        // Spin until the main thread has a chance to read this thread's ID
+        while !CAN_JOIN.load(Ordering::Relaxed) {
+            thread::yield_now();
+        }
+    }
+
+    /// Spawn a thread, query then return its TID.
+    #[cfg(target_vendor = "apple")]
+    fn spawn_update_join() -> u64 {
+        extern "C" fn thread_start(_data: *mut c_void) -> *mut c_void {
+            thread_body();
+            ptr::null_mut()
+        }
+
+        let mut t: libc::pthread_t = 0;
+        let mut spawned_tid_from_handle = 0u64;
+
+        unsafe {
+            let res = libc::pthread_create(&mut t, ptr::null(), thread_start, ptr::null_mut());
+            assert_eq!(res, 0, "failed to spawn thread");
+
+            let res = libc::pthread_threadid_np(t, &mut spawned_tid_from_handle);
+            assert_eq!(res, 0, "failed to query thread ID");
+            CAN_JOIN.store(true, Ordering::Relaxed);
+
+            let res = libc::pthread_join(t, ptr::null_mut());
+            assert_eq!(res, 0, "failed to join thread");
+
+            // Apple also has two documented return values for invalid threads and null pointers
+            let res = libc::pthread_threadid_np(libc::pthread_t::MAX, &mut 0);
+            assert_eq!(res, libc::ESRCH, "expected ESRCH for invalid TID");
+            let res = libc::pthread_threadid_np(0, ptr::null_mut());
+            assert_eq!(res, libc::EINVAL, "invalid EINVAL for a null pointer");
+        }
+
+        spawned_tid_from_handle
+    }
+
+    /// Spawn a thread, query then return its TID.
+    #[cfg(windows)]
+    fn spawn_update_join() -> u64 {
+        use windows_sys::Win32::Foundation::WAIT_FAILED;
+        use windows_sys::Win32::System::Threading::{
+            CreateThread, GetThreadId, INFINITE, WaitForSingleObject,
+        };
+
+        extern "system" fn thread_start(_data: *mut c_void) -> u32 {
+            thread_body();
+            0
+        }
+
+        let spawned_tid_from_handle;
+        let mut tid_at_spawn = 0u32;
+
+        unsafe {
+            let handle =
+                CreateThread(ptr::null(), 0, Some(thread_start), ptr::null(), 0, &mut tid_at_spawn);
+            assert!(!handle.is_null(), "failed to spawn thread");
+
+            spawned_tid_from_handle = GetThreadId(handle);
+            assert_ne!(spawned_tid_from_handle, 0, "failed to query thread ID");
+            CAN_JOIN.store(true, Ordering::Relaxed);
+
+            let res = WaitForSingleObject(handle, INFINITE);
+            assert_ne!(res, WAIT_FAILED, "failed to join thread");
+        }
+
+        // Windows also indirectly returns the TID from `CreateThread`, ensure that matches up.
+        assert_eq!(spawned_tid_from_handle, tid_at_spawn);
+
+        spawned_tid_from_handle.into()
+    }
+
+    pub fn check() {
+        let spawned_tid_from_handle = spawn_update_join();
+        let spawned_tid_from_self = SPAWNED_TID.load(Ordering::Relaxed);
+        let current_tid = gettid();
+
+        // Ensure that we got a different thread ID.
+        assert_ne!(spawned_tid_from_handle, current_tid);
+
+        // Ensure that we get the same result from `gettid` and from querying a thread's handle
+        assert_eq!(spawned_tid_from_handle, spawned_tid_from_self);
+    }
+}
+
+fn main() {
+    let tid = gettid();
+
+    std::thread::spawn(move || {
+        assert_ne!(gettid(), tid);
+    });
+
+    // Test that in isolation mode a deterministic value will be returned.
+    // The value is not important, we only care that whatever the value is,
+    // won't change from execution to execution.
+    if cfg!(with_isolation) {
+        if cfg!(target_os = "linux") {
+            // Linux starts the TID at the PID, which is 1000.
+            assert_eq!(tid, 1000);
+        } else {
+            // Other platforms start counting from 0.
+            assert_eq!(tid, 0);
+        }
+    }
+
+    // On Linux and NetBSD, the first TID is the PID.
+    #[cfg(any(target_os = "linux", target_os = "netbsd"))]
+    assert_eq!(tid, unsafe { libc::getpid() } as u64);
+
+    #[cfg(any(target_vendor = "apple", windows))]
+    queried::check();
+}
diff --git a/src/tools/miri/tests/pass/alloc-access-tracking.rs b/src/tools/miri/tests/pass/alloc-access-tracking.rs
index 9eba0ca171b..c47063bef03 100644
--- a/src/tools/miri/tests/pass/alloc-access-tracking.rs
+++ b/src/tools/miri/tests/pass/alloc-access-tracking.rs
@@ -1,7 +1,7 @@
 #![no_std]
 #![no_main]
-//@compile-flags: -Zmiri-track-alloc-id=19 -Zmiri-track-alloc-accesses -Cpanic=abort
-//@normalize-stderr-test: "id 19" -> "id $$ALLOC"
+//@compile-flags: -Zmiri-track-alloc-id=21 -Zmiri-track-alloc-accesses -Cpanic=abort
+//@normalize-stderr-test: "id 21" -> "id $$ALLOC"
 //@only-target: linux # alloc IDs differ between OSes (due to extern static allocations)
 
 extern "Rust" {