about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/tools/miri/src/concurrency/thread.rs182
-rw-r--r--src/tools/miri/src/shims/unix/foreign_items.rs3
-rw-r--r--src/tools/miri/src/shims/unix/thread.rs14
-rw-r--r--src/tools/miri/src/shims/windows/foreign_items.rs3
-rw-r--r--src/tools/miri/src/shims/windows/thread.rs13
5 files changed, 111 insertions, 104 deletions
diff --git a/src/tools/miri/src/concurrency/thread.rs b/src/tools/miri/src/concurrency/thread.rs
index 38b5d4c0f06..c8a408fd8ac 100644
--- a/src/tools/miri/src/concurrency/thread.rs
+++ b/src/tools/miri/src/concurrency/thread.rs
@@ -582,88 +582,6 @@ impl<'tcx> ThreadManager<'tcx> {
         interp_ok(())
     }
 
-    /// Mark that the active thread tries to join the thread with `joined_thread_id`.
-    fn join_thread(
-        &mut self,
-        joined_thread_id: ThreadId,
-        data_race_handler: &mut GlobalDataRaceHandler,
-    ) -> InterpResult<'tcx> {
-        if self.threads[joined_thread_id].join_status == ThreadJoinStatus::Detached {
-            // On Windows this corresponds to joining on a closed handle.
-            throw_ub_format!("trying to join a detached thread");
-        }
-
-        fn after_join<'tcx>(
-            threads: &mut ThreadManager<'_>,
-            joined_thread_id: ThreadId,
-            data_race_handler: &mut GlobalDataRaceHandler,
-        ) -> InterpResult<'tcx> {
-            match data_race_handler {
-                GlobalDataRaceHandler::None => {}
-                GlobalDataRaceHandler::Vclocks(data_race) =>
-                    data_race.thread_joined(threads, joined_thread_id),
-                GlobalDataRaceHandler::Genmc(genmc_ctx) =>
-                    genmc_ctx.handle_thread_join(threads.active_thread, joined_thread_id)?,
-            }
-            interp_ok(())
-        }
-
-        // Mark the joined thread as being joined so that we detect if other
-        // threads try to join it.
-        self.threads[joined_thread_id].join_status = ThreadJoinStatus::Joined;
-        if !self.threads[joined_thread_id].state.is_terminated() {
-            trace!(
-                "{:?} blocked on {:?} when trying to join",
-                self.active_thread, joined_thread_id
-            );
-            // The joined thread is still running, we need to wait for it.
-            // Unce we get unblocked, perform the appropriate synchronization.
-            self.block_thread(
-                BlockReason::Join(joined_thread_id),
-                None,
-                callback!(
-                    @capture<'tcx> {
-                        joined_thread_id: ThreadId,
-                    }
-                    |this, unblock: UnblockKind| {
-                        assert_eq!(unblock, UnblockKind::Ready);
-                        after_join(&mut this.machine.threads, joined_thread_id, &mut this.machine.data_race)
-                    }
-                ),
-            );
-        } else {
-            // The thread has already terminated - establish happens-before
-            after_join(self, joined_thread_id, data_race_handler)?;
-        }
-        interp_ok(())
-    }
-
-    /// Mark that the active thread tries to exclusively join the thread with `joined_thread_id`.
-    /// If the thread is already joined by another thread, it will throw UB
-    fn join_thread_exclusive(
-        &mut self,
-        joined_thread_id: ThreadId,
-        data_race_handler: &mut GlobalDataRaceHandler,
-    ) -> InterpResult<'tcx> {
-        if self.threads[joined_thread_id].join_status == ThreadJoinStatus::Joined {
-            throw_ub_format!("trying to join an already joined thread");
-        }
-
-        if joined_thread_id == self.active_thread {
-            throw_ub_format!("trying to join itself");
-        }
-
-        // Sanity check `join_status`.
-        assert!(
-            self.threads
-                .iter()
-                .all(|thread| { !thread.state.is_blocked_on(BlockReason::Join(joined_thread_id)) }),
-            "this thread already has threads waiting for its termination"
-        );
-
-        self.join_thread(joined_thread_id, data_race_handler)
-    }
-
     /// Set the name of the given thread.
     pub fn set_thread_name(&mut self, thread: ThreadId, new_thread_name: Vec<u8>) {
         self.threads[thread].thread_name = Some(new_thread_name);
@@ -1114,20 +1032,102 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
         this.machine.threads.detach_thread(thread_id, allow_terminated_joined)
     }
 
-    #[inline]
-    fn join_thread(&mut self, joined_thread_id: ThreadId) -> InterpResult<'tcx> {
+    /// Mark that the active thread tries to join the thread with `joined_thread_id`.
+    ///
+    /// When the join is successful (immediately, or as soon as the joined thread finishes), `success_retval` will be written to `return_dest`.
+    fn join_thread(
+        &mut self,
+        joined_thread_id: ThreadId,
+        success_retval: Scalar,
+        return_dest: &MPlaceTy<'tcx>,
+    ) -> InterpResult<'tcx> {
         let this = self.eval_context_mut();
-        this.machine.threads.join_thread(joined_thread_id, &mut this.machine.data_race)?;
+        let thread_mgr = &mut this.machine.threads;
+        if thread_mgr.threads[joined_thread_id].join_status == ThreadJoinStatus::Detached {
+            // On Windows this corresponds to joining on a closed handle.
+            throw_ub_format!("trying to join a detached thread");
+        }
+
+        fn after_join<'tcx>(
+            this: &mut InterpCx<'tcx, MiriMachine<'tcx>>,
+            joined_thread_id: ThreadId,
+            success_retval: Scalar,
+            return_dest: &MPlaceTy<'tcx>,
+        ) -> InterpResult<'tcx> {
+            let threads = &this.machine.threads;
+            match &mut this.machine.data_race {
+                GlobalDataRaceHandler::None => {}
+                GlobalDataRaceHandler::Vclocks(data_race) =>
+                    data_race.thread_joined(threads, joined_thread_id),
+                GlobalDataRaceHandler::Genmc(genmc_ctx) =>
+                    genmc_ctx.handle_thread_join(threads.active_thread, joined_thread_id)?,
+            }
+            this.write_scalar(success_retval, return_dest)?;
+            interp_ok(())
+        }
+
+        // Mark the joined thread as being joined so that we detect if other
+        // threads try to join it.
+        thread_mgr.threads[joined_thread_id].join_status = ThreadJoinStatus::Joined;
+        if !thread_mgr.threads[joined_thread_id].state.is_terminated() {
+            trace!(
+                "{:?} blocked on {:?} when trying to join",
+                thread_mgr.active_thread, joined_thread_id
+            );
+            // The joined thread is still running, we need to wait for it.
+            // Once we get unblocked, perform the appropriate synchronization and write the return value.
+            let dest = return_dest.clone();
+            thread_mgr.block_thread(
+                BlockReason::Join(joined_thread_id),
+                None,
+                callback!(
+                    @capture<'tcx> {
+                        joined_thread_id: ThreadId,
+                        dest: MPlaceTy<'tcx>,
+                        success_retval: Scalar,
+                    }
+                    |this, unblock: UnblockKind| {
+                        assert_eq!(unblock, UnblockKind::Ready);
+                        after_join(this, joined_thread_id, success_retval, &dest)
+                    }
+                ),
+            );
+        } else {
+            // The thread has already terminated - establish happens-before and write the return value.
+            after_join(this, joined_thread_id, success_retval, return_dest)?;
+        }
         interp_ok(())
     }
 
-    #[inline]
-    fn join_thread_exclusive(&mut self, joined_thread_id: ThreadId) -> InterpResult<'tcx> {
+    /// Mark that the active thread tries to exclusively join the thread with `joined_thread_id`.
+    /// If the thread is already joined by another thread, it will throw UB.
+    ///
+    /// When the join is successful (immediately, or as soon as the joined thread finishes), `success_retval` will be written to `return_dest`.
+    fn join_thread_exclusive(
+        &mut self,
+        joined_thread_id: ThreadId,
+        success_retval: Scalar,
+        return_dest: &MPlaceTy<'tcx>,
+    ) -> InterpResult<'tcx> {
         let this = self.eval_context_mut();
-        this.machine
-            .threads
-            .join_thread_exclusive(joined_thread_id, &mut this.machine.data_race)?;
-        interp_ok(())
+        let threads = &this.machine.threads.threads;
+        if threads[joined_thread_id].join_status == ThreadJoinStatus::Joined {
+            throw_ub_format!("trying to join an already joined thread");
+        }
+
+        if joined_thread_id == this.machine.threads.active_thread {
+            throw_ub_format!("trying to join itself");
+        }
+
+        // Sanity check `join_status`.
+        assert!(
+            threads
+                .iter()
+                .all(|thread| { !thread.state.is_blocked_on(BlockReason::Join(joined_thread_id)) }),
+            "this thread already has threads waiting for its termination"
+        );
+
+        this.join_thread(joined_thread_id, success_retval, return_dest)
     }
 
     #[inline]
diff --git a/src/tools/miri/src/shims/unix/foreign_items.rs b/src/tools/miri/src/shims/unix/foreign_items.rs
index 9106ef94c43..f34b95e730b 100644
--- a/src/tools/miri/src/shims/unix/foreign_items.rs
+++ b/src/tools/miri/src/shims/unix/foreign_items.rs
@@ -946,8 +946,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
             }
             "pthread_join" => {
                 let [thread, retval] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
-                let res = this.pthread_join(thread, retval)?;
-                this.write_scalar(res, dest)?;
+                this.pthread_join(thread, retval, dest)?;
             }
             "pthread_detach" => {
                 let [thread] = this.check_shim(abi, CanonAbi::C, link_name, args)?;
diff --git a/src/tools/miri/src/shims/unix/thread.rs b/src/tools/miri/src/shims/unix/thread.rs
index 4b6615b3ea8..a438e71a41d 100644
--- a/src/tools/miri/src/shims/unix/thread.rs
+++ b/src/tools/miri/src/shims/unix/thread.rs
@@ -41,7 +41,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
         &mut self,
         thread: &OpTy<'tcx>,
         retval: &OpTy<'tcx>,
-    ) -> InterpResult<'tcx, Scalar> {
+        return_dest: &MPlaceTy<'tcx>,
+    ) -> InterpResult<'tcx> {
         let this = self.eval_context_mut();
 
         if !this.ptr_is_null(this.read_pointer(retval)?)? {
@@ -51,12 +52,15 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
 
         let thread = this.read_scalar(thread)?.to_int(this.libc_ty_layout("pthread_t").size)?;
         let Ok(thread) = this.thread_id_try_from(thread) else {
-            return interp_ok(this.eval_libc("ESRCH"));
+            this.write_scalar(this.eval_libc("ESRCH"), return_dest)?;
+            return interp_ok(());
         };
 
-        this.join_thread_exclusive(thread)?;
-
-        interp_ok(Scalar::from_u32(0))
+        this.join_thread_exclusive(
+            thread,
+            /* success_retval */ Scalar::from_u32(0),
+            return_dest,
+        )
     }
 
     fn pthread_detach(&mut self, thread: &OpTy<'tcx>) -> InterpResult<'tcx, Scalar> {
diff --git a/src/tools/miri/src/shims/windows/foreign_items.rs b/src/tools/miri/src/shims/windows/foreign_items.rs
index 10f6df67ad4..de10357f5fa 100644
--- a/src/tools/miri/src/shims/windows/foreign_items.rs
+++ b/src/tools/miri/src/shims/windows/foreign_items.rs
@@ -573,8 +573,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
             "WaitForSingleObject" => {
                 let [handle, timeout] = this.check_shim(abi, sys_conv, link_name, args)?;
 
-                let ret = this.WaitForSingleObject(handle, timeout)?;
-                this.write_scalar(ret, dest)?;
+                this.WaitForSingleObject(handle, timeout, dest)?;
             }
             "GetCurrentProcess" => {
                 let [] = this.check_shim(abi, sys_conv, link_name, args)?;
diff --git a/src/tools/miri/src/shims/windows/thread.rs b/src/tools/miri/src/shims/windows/thread.rs
index d5f9ed4e968..981742391b9 100644
--- a/src/tools/miri/src/shims/windows/thread.rs
+++ b/src/tools/miri/src/shims/windows/thread.rs
@@ -59,13 +59,14 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
         &mut self,
         handle_op: &OpTy<'tcx>,
         timeout_op: &OpTy<'tcx>,
-    ) -> InterpResult<'tcx, Scalar> {
+        return_dest: &MPlaceTy<'tcx>,
+    ) -> InterpResult<'tcx> {
         let this = self.eval_context_mut();
 
         let handle = this.read_handle(handle_op, "WaitForSingleObject")?;
         let timeout = this.read_scalar(timeout_op)?.to_u32()?;
 
-        let thread = match handle {
+        let joined_thread_id = match handle {
             Handle::Thread(thread) => thread,
             // Unlike on posix, the outcome of joining the current thread is not documented.
             // On current Windows, it just deadlocks.
@@ -77,8 +78,12 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
             throw_unsup_format!("`WaitForSingleObject` with non-infinite timeout");
         }
 
-        this.join_thread(thread)?;
+        this.join_thread(
+            joined_thread_id,
+            /* success_retval */ this.eval_windows("c", "WAIT_OBJECT_0"),
+            return_dest,
+        )?;
 
-        interp_ok(this.eval_windows("c", "WAIT_OBJECT_0"))
+        interp_ok(())
     }
 }