about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/miri/src/shims/windows/foreign_items.rs66
-rw-r--r--src/tools/miri/src/shims/windows/handle.rs55
-rw-r--r--src/tools/miri/src/shims/windows/thread.rs8
3 files changed, 78 insertions, 51 deletions
diff --git a/src/tools/miri/src/shims/windows/foreign_items.rs b/src/tools/miri/src/shims/windows/foreign_items.rs
index fae6170a9e7..dda30209275 100644
--- a/src/tools/miri/src/shims/windows/foreign_items.rs
+++ b/src/tools/miri/src/shims/windows/foreign_items.rs
@@ -9,14 +9,9 @@ use rustc_target::callconv::{Conv, FnAbi};
 
 use self::shims::windows::handle::{Handle, PseudoHandle};
 use crate::shims::os_str::bytes_to_os_str;
-use crate::shims::windows::handle::HandleError;
 use crate::shims::windows::*;
 use crate::*;
 
-// The NTSTATUS STATUS_INVALID_HANDLE (0xC0000008) encoded as a HRESULT by setting the N bit.
-// (https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-erref/0642cb2f-2075-4469-918c-4441e69c548a)
-const STATUS_INVALID_HANDLE: u32 = 0xD0000008;
-
 pub fn is_dyn_sym(name: &str) -> bool {
     // std does dynamic detection for these symbols
     matches!(
@@ -498,52 +493,37 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
             "SetThreadDescription" => {
                 let [handle, name] = this.check_shim(abi, sys_conv, link_name, args)?;
 
-                let handle = this.read_scalar(handle)?;
+                let handle = this.read_handle(handle)?;
                 let name = this.read_wide_str(this.read_pointer(name)?)?;
 
-                let thread = match Handle::try_from_scalar(handle, this)? {
-                    Ok(Handle::Thread(thread)) => Ok(thread),
-                    Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()),
-                    Ok(_) | Err(HandleError::InvalidHandle) =>
-                        this.invalid_handle("SetThreadDescription")?,
-                    Err(HandleError::ThreadNotFound(e)) => Err(e),
-                };
-                let res = match thread {
-                    Ok(thread) => {
-                        // FIXME: use non-lossy conversion
-                        this.set_thread_name(thread, String::from_utf16_lossy(&name).into_bytes());
-                        Scalar::from_u32(0)
-                    }
-                    Err(_) => Scalar::from_u32(STATUS_INVALID_HANDLE),
+                let thread = match handle {
+                    Handle::Thread(thread) => thread,
+                    Handle::Pseudo(PseudoHandle::CurrentThread) => this.active_thread(),
+                    _ => this.invalid_handle("SetThreadDescription")?,
                 };
-
-                this.write_scalar(res, dest)?;
+                // FIXME: use non-lossy conversion
+                this.set_thread_name(thread, String::from_utf16_lossy(&name).into_bytes());
+                this.write_scalar(Scalar::from_u32(0), dest)?;
             }
             "GetThreadDescription" => {
                 let [handle, name_ptr] = this.check_shim(abi, sys_conv, link_name, args)?;
 
-                let handle = this.read_scalar(handle)?;
+                let handle = this.read_handle(handle)?;
                 let name_ptr = this.deref_pointer_as(name_ptr, this.machine.layouts.mut_raw_ptr)?; // the pointer where we should store the ptr to the name
 
-                let thread = match Handle::try_from_scalar(handle, this)? {
-                    Ok(Handle::Thread(thread)) => Ok(thread),
-                    Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()),
-                    Ok(_) | Err(HandleError::InvalidHandle) =>
-                        this.invalid_handle("GetThreadDescription")?,
-                    Err(HandleError::ThreadNotFound(e)) => Err(e),
-                };
-                let (name, res) = match thread {
-                    Ok(thread) => {
-                        // Looks like the default thread name is empty.
-                        let name = this.get_thread_name(thread).unwrap_or(b"").to_owned();
-                        let name = this.alloc_os_str_as_wide_str(
-                            bytes_to_os_str(&name)?,
-                            MiriMemoryKind::WinLocal.into(),
-                        )?;
-                        (Scalar::from_maybe_pointer(name, this), Scalar::from_u32(0))
-                    }
-                    Err(_) => (Scalar::null_ptr(this), Scalar::from_u32(STATUS_INVALID_HANDLE)),
+                let thread = match handle {
+                    Handle::Thread(thread) => thread,
+                    Handle::Pseudo(PseudoHandle::CurrentThread) => this.active_thread(),
+                    _ => this.invalid_handle("GetThreadDescription")?,
                 };
+                // Looks like the default thread name is empty.
+                let name = this.get_thread_name(thread).unwrap_or(b"").to_owned();
+                let name = this.alloc_os_str_as_wide_str(
+                    bytes_to_os_str(&name)?,
+                    MiriMemoryKind::WinLocal.into(),
+                )?;
+                let name = Scalar::from_maybe_pointer(name, this);
+                let res = Scalar::from_u32(0);
 
                 this.write_scalar(name, &name_ptr)?;
                 this.write_scalar(res, dest)?;
@@ -638,11 +618,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 let [handle, filename, size] = this.check_shim(abi, sys_conv, link_name, args)?;
                 this.check_no_isolation("`GetModuleFileNameW`")?;
 
-                let handle = this.read_target_usize(handle)?;
+                let handle = this.read_handle(handle)?;
                 let filename = this.read_pointer(filename)?;
                 let size = this.read_scalar(size)?.to_u32()?;
 
-                if handle != 0 {
+                if handle != Handle::Null {
                     throw_unsup_format!("`GetModuleFileNameW` only supports the NULL handle");
                 }
 
diff --git a/src/tools/miri/src/shims/windows/handle.rs b/src/tools/miri/src/shims/windows/handle.rs
index c4eb11fbd3f..cac67c888f8 100644
--- a/src/tools/miri/src/shims/windows/handle.rs
+++ b/src/tools/miri/src/shims/windows/handle.rs
@@ -1,4 +1,5 @@
 use std::mem::variant_count;
+use std::panic::Location;
 
 use rustc_abi::HasDataLayout;
 
@@ -16,6 +17,8 @@ pub enum Handle {
     Null,
     Pseudo(PseudoHandle),
     Thread(ThreadId),
+    File(i32),
+    Invalid,
 }
 
 impl PseudoHandle {
@@ -47,12 +50,16 @@ impl Handle {
     const NULL_DISCRIMINANT: u32 = 0;
     const PSEUDO_DISCRIMINANT: u32 = 1;
     const THREAD_DISCRIMINANT: u32 = 2;
+    const FILE_DISCRIMINANT: u32 = 3;
+    const INVALID_DISCRIMINANT: u32 = 7;
 
     fn discriminant(self) -> u32 {
         match self {
             Self::Null => Self::NULL_DISCRIMINANT,
             Self::Pseudo(_) => Self::PSEUDO_DISCRIMINANT,
             Self::Thread(_) => Self::THREAD_DISCRIMINANT,
+            Self::File(_) => Self::FILE_DISCRIMINANT,
+            Self::Invalid => Self::INVALID_DISCRIMINANT,
         }
     }
 
@@ -61,11 +68,16 @@ impl Handle {
             Self::Null => 0,
             Self::Pseudo(pseudo_handle) => pseudo_handle.value(),
             Self::Thread(thread) => thread.to_u32(),
+            #[expect(clippy::cast_sign_loss)]
+            Self::File(fd) => fd as u32,
+            Self::Invalid => 0x1FFFFFFF,
         }
     }
 
     fn packed_disc_size() -> u32 {
         // ceil(log2(x)) is how many bits it takes to store x numbers
+        // We ensure that INVALID_HANDLE_VALUE (0xFFFFFFFF) decodes to Handle::Invalid
+        // see https://devblogs.microsoft.com/oldnewthing/20230914-00/?p=108766
         let variant_count = variant_count::<Self>();
 
         // however, std's ilog2 is floor(log2(x))
@@ -93,7 +105,7 @@ impl Handle {
         assert!(discriminant < 2u32.pow(disc_size));
 
         // make sure the data fits into `data_size` bits
-        assert!(data < 2u32.pow(data_size));
+        assert!(data <= 2u32.pow(data_size));
 
         // packs the data into the lower `data_size` bits
         // and packs the discriminant right above the data
@@ -105,6 +117,9 @@ impl Handle {
             Self::NULL_DISCRIMINANT if data == 0 => Some(Self::Null),
             Self::PSEUDO_DISCRIMINANT => Some(Self::Pseudo(PseudoHandle::from_value(data)?)),
             Self::THREAD_DISCRIMINANT => Some(Self::Thread(ThreadId::new_unchecked(data))),
+            #[expect(clippy::cast_possible_wrap)]
+            Self::FILE_DISCRIMINANT => Some(Self::File(data as i32)),
+            Self::INVALID_DISCRIMINANT => Some(Self::Invalid),
             _ => None,
         }
     }
@@ -171,6 +186,26 @@ impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
 
 #[allow(non_snake_case)]
 pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
+    #[track_caller]
+    fn read_handle(&self, handle: &OpTy<'tcx>) -> InterpResult<'tcx, Handle> {
+        let this = self.eval_context_ref();
+        let handle = this.read_scalar(handle)?;
+        match Handle::try_from_scalar(handle, this)? {
+            Ok(handle) => interp_ok(handle),
+            Err(HandleError::InvalidHandle) =>
+                throw_machine_stop!(TerminationInfo::Abort(format!(
+                    "invalid handle {} at {}",
+                    handle.to_target_isize(this)?,
+                    Location::caller(),
+                ))),
+            Err(HandleError::ThreadNotFound(_)) =>
+                throw_machine_stop!(TerminationInfo::Abort(format!(
+                    "invalid thread ID: {}",
+                    Location::caller()
+                ))),
+        }
+    }
+
     fn invalid_handle(&mut self, function_name: &str) -> InterpResult<'tcx, !> {
         throw_machine_stop!(TerminationInfo::Abort(format!(
             "invalid handle passed to `{function_name}`"
@@ -180,12 +215,24 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
     fn CloseHandle(&mut self, handle_op: &OpTy<'tcx>) -> InterpResult<'tcx, Scalar> {
         let this = self.eval_context_mut();
 
-        let handle = this.read_scalar(handle_op)?;
-        let ret = match Handle::try_from_scalar(handle, this)? {
-            Ok(Handle::Thread(thread)) => {
+        let handle = this.read_handle(handle_op)?;
+        let ret = match handle {
+            Handle::Thread(thread) => {
                 this.detach_thread(thread, /*allow_terminated_joined*/ true)?;
                 this.eval_windows("c", "TRUE")
             }
+            Handle::File(fd) =>
+                if let Some(file) = this.machine.fds.get(fd) {
+                    let err = file.close(this.machine.communicate(), this)?;
+                    if let Err(e) = err {
+                        this.set_last_error(e)?;
+                        this.eval_windows("c", "FALSE")
+                    } else {
+                        this.eval_windows("c", "TRUE")
+                    }
+                } else {
+                    this.invalid_handle("CloseHandle")?
+                },
             _ => this.invalid_handle("CloseHandle")?,
         };
 
diff --git a/src/tools/miri/src/shims/windows/thread.rs b/src/tools/miri/src/shims/windows/thread.rs
index 5db55404422..8289eea3412 100644
--- a/src/tools/miri/src/shims/windows/thread.rs
+++ b/src/tools/miri/src/shims/windows/thread.rs
@@ -62,14 +62,14 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
     ) -> InterpResult<'tcx, Scalar> {
         let this = self.eval_context_mut();
 
-        let handle = this.read_scalar(handle_op)?;
+        let handle = this.read_handle(handle_op)?;
         let timeout = this.read_scalar(timeout_op)?.to_u32()?;
 
-        let thread = match Handle::try_from_scalar(handle, this)? {
-            Ok(Handle::Thread(thread)) => thread,
+        let thread = match handle {
+            Handle::Thread(thread) => thread,
             // Unlike on posix, the outcome of joining the current thread is not documented.
             // On current Windows, it just deadlocks.
-            Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => this.active_thread(),
+            Handle::Pseudo(PseudoHandle::CurrentThread) => this.active_thread(),
             _ => this.invalid_handle("WaitForSingleObject")?,
         };