about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorRalf Jung <post@ralfj.de>2024-11-19 09:36:07 +0000
committerGitHub <noreply@github.com>2024-11-19 09:36:07 +0000
commit4520ff84e151a8eb904146cdd19353a92b671ba8 (patch)
tree6530726e1c55797da55f08ed46e887b936e92f50 /src
parentf953ed513aea07356bf1aaf63ac0fa887e987977 (diff)
parentcecf2b3eae882f17af2f309587c29fec270ad1fd (diff)
downloadrust-4520ff84e151a8eb904146cdd19353a92b671ba8.tar.gz
rust-4520ff84e151a8eb904146cdd19353a92b671ba8.zip
Merge pull request #4035 from discord9/master
refactor: refine thread variant for windows
Diffstat (limited to 'src')
-rw-r--r--src/tools/miri/src/concurrency/thread.rs5
-rw-r--r--src/tools/miri/src/shims/windows/foreign_items.rs23
-rw-r--r--src/tools/miri/src/shims/windows/handle.rs51
-rw-r--r--src/tools/miri/src/shims/windows/thread.rs10
4 files changed, 57 insertions, 32 deletions
diff --git a/src/tools/miri/src/concurrency/thread.rs b/src/tools/miri/src/concurrency/thread.rs
index 9cf301b78d3..59e2fdd4285 100644
--- a/src/tools/miri/src/concurrency/thread.rs
+++ b/src/tools/miri/src/concurrency/thread.rs
@@ -113,6 +113,11 @@ impl ThreadId {
         self.0
     }
 
+    /// Create a new thread id from a `u32` without checking if this thread exists.
+    pub fn new_unchecked(id: u32) -> Self {
+        Self(id)
+    }
+
     pub const MAIN_THREAD: ThreadId = ThreadId(0);
 }
 
diff --git a/src/tools/miri/src/shims/windows/foreign_items.rs b/src/tools/miri/src/shims/windows/foreign_items.rs
index 504efed3cfd..c145cf3ceb8 100644
--- a/src/tools/miri/src/shims/windows/foreign_items.rs
+++ b/src/tools/miri/src/shims/windows/foreign_items.rs
@@ -7,6 +7,7 @@ use rustc_span::Symbol;
 
 use self::shims::windows::handle::{Handle, PseudoHandle};
 use crate::shims::os_str::bytes_to_os_str;
+use crate::shims::windows::handle::HandleError;
 use crate::shims::windows::*;
 use crate::*;
 
@@ -488,7 +489,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 let thread_id =
                     this.CreateThread(security, stacksize, start, arg, flags, thread)?;
 
-                this.write_scalar(Handle::Thread(thread_id.to_u32()).to_scalar(this), dest)?;
+                this.write_scalar(Handle::Thread(thread_id).to_scalar(this), dest)?;
             }
             "WaitForSingleObject" => {
                 let [handle, timeout] =
@@ -513,10 +514,12 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 let handle = this.read_scalar(handle)?;
                 let name = this.read_wide_str(this.read_pointer(name)?)?;
 
-                let thread = match Handle::from_scalar(handle, this)? {
-                    Some(Handle::Thread(thread)) => this.thread_id_try_from(thread),
-                    Some(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()),
-                    _ => this.invalid_handle("SetThreadDescription")?,
+                let thread = match Handle::try_from_scalar(handle, this)? {
+                    Ok(Handle::Thread(thread)) => Ok(thread),
+                    Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()),
+                    Ok(_) | Err(HandleError::InvalidHandle) =>
+                        this.invalid_handle("SetThreadDescription")?,
+                    Err(HandleError::ThreadNotFound(e)) => Err(e),
                 };
                 let res = match thread {
                     Ok(thread) => {
@@ -536,10 +539,12 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 let handle = this.read_scalar(handle)?;
                 let name_ptr = this.deref_pointer(name_ptr)?; // the pointer where we should store the ptr to the name
 
-                let thread = match Handle::from_scalar(handle, this)? {
-                    Some(Handle::Thread(thread)) => this.thread_id_try_from(thread),
-                    Some(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()),
-                    _ => this.invalid_handle("GetThreadDescription")?,
+                let thread = match Handle::try_from_scalar(handle, this)? {
+                    Ok(Handle::Thread(thread)) => Ok(thread),
+                    Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()),
+                    Ok(_) | Err(HandleError::InvalidHandle) =>
+                        this.invalid_handle("GetThreadDescription")?,
+                    Err(HandleError::ThreadNotFound(e)) => Err(e),
                 };
                 let (name, res) = match thread {
                     Ok(thread) => {
diff --git a/src/tools/miri/src/shims/windows/handle.rs b/src/tools/miri/src/shims/windows/handle.rs
index b40c00efedd..3d872b65a63 100644
--- a/src/tools/miri/src/shims/windows/handle.rs
+++ b/src/tools/miri/src/shims/windows/handle.rs
@@ -2,6 +2,7 @@ use std::mem::variant_count;
 
 use rustc_abi::HasDataLayout;
 
+use crate::concurrency::thread::ThreadNotFound;
 use crate::*;
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -14,7 +15,7 @@ pub enum PseudoHandle {
 pub enum Handle {
     Null,
     Pseudo(PseudoHandle),
-    Thread(u32),
+    Thread(ThreadId),
 }
 
 impl PseudoHandle {
@@ -34,6 +35,14 @@ impl PseudoHandle {
     }
 }
 
+/// Errors that can occur when constructing a [`Handle`] from a Scalar.
+pub enum HandleError {
+    /// There is no thread with the given ID.
+    ThreadNotFound(ThreadNotFound),
+    /// Can't convert scalar to handle because it is structurally invalid.
+    InvalidHandle,
+}
+
 impl Handle {
     const NULL_DISCRIMINANT: u32 = 0;
     const PSEUDO_DISCRIMINANT: u32 = 1;
@@ -51,7 +60,7 @@ impl Handle {
         match self {
             Self::Null => 0,
             Self::Pseudo(pseudo_handle) => pseudo_handle.value(),
-            Self::Thread(thread) => thread,
+            Self::Thread(thread) => thread.to_u32(),
         }
     }
 
@@ -95,7 +104,7 @@ impl Handle {
         match discriminant {
             Self::NULL_DISCRIMINANT if data == 0 => Some(Self::Null),
             Self::PSEUDO_DISCRIMINANT => Some(Self::Pseudo(PseudoHandle::from_value(data)?)),
-            Self::THREAD_DISCRIMINANT => Some(Self::Thread(data)),
+            Self::THREAD_DISCRIMINANT => Some(Self::Thread(ThreadId::new_unchecked(data))),
             _ => None,
         }
     }
@@ -126,10 +135,14 @@ impl Handle {
         Scalar::from_target_isize(signed_handle.into(), cx)
     }
 
-    pub fn from_scalar<'tcx>(
+    /// Convert a scalar into a structured `Handle`.
+    /// Structurally invalid handles return [`HandleError::InvalidHandle`].
+    /// If the handle is structurally valid but semantically invalid, e.g. a for non-existent thread
+    /// ID, returns [`HandleError::ThreadNotFound`].
+    pub fn try_from_scalar<'tcx>(
         handle: Scalar,
-        cx: &impl HasDataLayout,
-    ) -> InterpResult<'tcx, Option<Self>> {
+        cx: &MiriInterpCx<'tcx>,
+    ) -> InterpResult<'tcx, Result<Self, HandleError>> {
         let sign_extended_handle = handle.to_target_isize(cx)?;
 
         #[expect(clippy::cast_sign_loss)] // we want to lose the sign
@@ -137,10 +150,20 @@ impl Handle {
             signed_handle as u32
         } else {
             // if a handle doesn't fit in an i32, it isn't valid.
-            return interp_ok(None);
+            return interp_ok(Err(HandleError::InvalidHandle));
         };
 
-        interp_ok(Self::from_packed(handle))
+        match Self::from_packed(handle) {
+            Some(Self::Thread(thread)) => {
+                // validate the thread id
+                match cx.machine.threads.thread_id_try_from(thread.to_u32()) {
+                    Ok(id) => interp_ok(Ok(Self::Thread(id))),
+                    Err(e) => interp_ok(Err(HandleError::ThreadNotFound(e))),
+                }
+            }
+            Some(handle) => interp_ok(Ok(handle)),
+            None => interp_ok(Err(HandleError::InvalidHandle)),
+        }
     }
 }
 
@@ -158,14 +181,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
         let this = self.eval_context_mut();
 
         let handle = this.read_scalar(handle_op)?;
-        let ret = match Handle::from_scalar(handle, this)? {
-            Some(Handle::Thread(thread)) => {
-                if let Ok(thread) = this.thread_id_try_from(thread) {
-                    this.detach_thread(thread, /*allow_terminated_joined*/ true)?;
-                    this.eval_windows("c", "TRUE")
-                } else {
-                    this.invalid_handle("CloseHandle")?
-                }
+        let ret = match Handle::try_from_scalar(handle, this)? {
+            Ok(Handle::Thread(thread)) => {
+                this.detach_thread(thread, /*allow_terminated_joined*/ true)?;
+                this.eval_windows("c", "TRUE")
             }
             _ => this.invalid_handle("CloseHandle")?,
         };
diff --git a/src/tools/miri/src/shims/windows/thread.rs b/src/tools/miri/src/shims/windows/thread.rs
index 7af15fc647c..efc1c2286bc 100644
--- a/src/tools/miri/src/shims/windows/thread.rs
+++ b/src/tools/miri/src/shims/windows/thread.rs
@@ -65,15 +65,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
         let handle = this.read_scalar(handle_op)?;
         let timeout = this.read_scalar(timeout_op)?.to_u32()?;
 
-        let thread = match Handle::from_scalar(handle, this)? {
-            Some(Handle::Thread(thread)) =>
-                match this.thread_id_try_from(thread) {
-                    Ok(thread) => thread,
-                    Err(_) => this.invalid_handle("WaitForSingleObject")?,
-                },
+        let thread = match Handle::try_from_scalar(handle, this)? {
+            Ok(Handle::Thread(thread)) => thread,
             // Unlike on posix, the outcome of joining the current thread is not documented.
             // On current Windows, it just deadlocks.
-            Some(Handle::Pseudo(PseudoHandle::CurrentThread)) => this.active_thread(),
+            Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => this.active_thread(),
             _ => this.invalid_handle("WaitForSingleObject")?,
         };