about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDrMeepster <19316085+DrMeepster@users.noreply.github.com>2022-10-30 14:56:49 -0700
committerDrMeepster <19316085+DrMeepster@users.noreply.github.com>2022-11-05 14:47:51 -0700
commita2f7e8497e12e6d79d7963729ef6746b912a32eb (patch)
tree412d4e7979790ec7519e81bc6c566ceecb3f02aa
parenta1d94d43ac876975e15b5069da1f9b7a81e29733 (diff)
downloadrust-a2f7e8497e12e6d79d7963729ef6746b912a32eb.tar.gz
rust-a2f7e8497e12e6d79d7963729ef6746b912a32eb.zip
use enum for condvar locks
-rw-r--r--src/tools/miri/src/concurrency/sync.rs24
-rw-r--r--src/tools/miri/src/shims/unix/sync.rs23
-rw-r--r--src/tools/miri/src/shims/windows/sync.rs73
3 files changed, 76 insertions, 44 deletions
diff --git a/src/tools/miri/src/concurrency/sync.rs b/src/tools/miri/src/concurrency/sync.rs
index 48f9e605276..ba5ae852c5a 100644
--- a/src/tools/miri/src/concurrency/sync.rs
+++ b/src/tools/miri/src/concurrency/sync.rs
@@ -116,15 +116,25 @@ struct RwLock {
 
 declare_id!(CondvarId);
 
+#[derive(Debug, Copy, Clone)]
+pub enum RwLockMode {
+    Read,
+    Write,
+}
+
+#[derive(Debug)]
+pub enum CondvarLock {
+    Mutex(MutexId),
+    RwLock { id: RwLockId, mode: RwLockMode },
+}
+
 /// A thread waiting on a conditional variable.
 #[derive(Debug)]
 struct CondvarWaiter {
     /// The thread that is waiting on this variable.
     thread: ThreadId,
     /// The mutex or rwlock on which the thread is waiting.
-    lock: u32,
-    /// If the lock is shared or exclusive
-    shared: bool,
+    lock: CondvarLock,
 }
 
 /// The conditional variable state.
@@ -571,16 +581,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
     }
 
     /// Mark that the thread is waiting on the conditional variable.
-    fn condvar_wait(&mut self, id: CondvarId, thread: ThreadId, lock: u32, shared: bool) {
+    fn condvar_wait(&mut self, id: CondvarId, thread: ThreadId, lock: CondvarLock) {
         let this = self.eval_context_mut();
         let waiters = &mut this.machine.threads.sync.condvars[id].waiters;
         assert!(waiters.iter().all(|waiter| waiter.thread != thread), "thread is already waiting");
-        waiters.push_back(CondvarWaiter { thread, lock, shared });
+        waiters.push_back(CondvarWaiter { thread, lock });
     }
 
     /// Wake up some thread (if there is any) sleeping on the conditional
     /// variable.
-    fn condvar_signal(&mut self, id: CondvarId) -> Option<(ThreadId, u32, bool)> {
+    fn condvar_signal(&mut self, id: CondvarId) -> Option<(ThreadId, CondvarLock)> {
         let this = self.eval_context_mut();
         let current_thread = this.get_active_thread();
         let condvar = &mut this.machine.threads.sync.condvars[id];
@@ -594,7 +604,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
             if let Some(data_race) = data_race {
                 data_race.validate_lock_acquire(&condvar.data_race, waiter.thread);
             }
-            (waiter.thread, waiter.lock, waiter.shared)
+            (waiter.thread, waiter.lock)
         })
     }
 
diff --git a/src/tools/miri/src/shims/unix/sync.rs b/src/tools/miri/src/shims/unix/sync.rs
index d24e1a56bd5..a7275646847 100644
--- a/src/tools/miri/src/shims/unix/sync.rs
+++ b/src/tools/miri/src/shims/unix/sync.rs
@@ -3,6 +3,7 @@ use std::time::SystemTime;
 use rustc_hir::LangItem;
 use rustc_middle::ty::{layout::TyAndLayout, query::TyCtxtAt, Ty};
 
+use crate::concurrency::sync::CondvarLock;
 use crate::concurrency::thread::{MachineCallback, Time};
 use crate::*;
 
@@ -696,9 +697,12 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
     fn pthread_cond_signal(&mut self, cond_op: &OpTy<'tcx, Provenance>) -> InterpResult<'tcx, i32> {
         let this = self.eval_context_mut();
         let id = this.condvar_get_or_create_id(cond_op, CONDVAR_ID_OFFSET)?;
-        if let Some((thread, mutex, shared)) = this.condvar_signal(id) {
-            assert!(!shared);
-            post_cond_signal(this, thread, MutexId::from_u32(mutex))?;
+        if let Some((thread, lock)) = this.condvar_signal(id) {
+            if let CondvarLock::Mutex(mutex) = lock {
+                post_cond_signal(this, thread, mutex)?;
+            } else {
+                panic!("condvar should not have an rwlock on unix");
+            }
         }
 
         Ok(0)
@@ -711,9 +715,12 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         let this = self.eval_context_mut();
         let id = this.condvar_get_or_create_id(cond_op, CONDVAR_ID_OFFSET)?;
 
-        while let Some((thread, mutex, shared)) = this.condvar_signal(id) {
-            assert!(!shared);
-            post_cond_signal(this, thread, MutexId::from_u32(mutex))?;
+        while let Some((thread, lock)) = this.condvar_signal(id) {
+            if let CondvarLock::Mutex(mutex) = lock {
+                post_cond_signal(this, thread, mutex)?;
+            } else {
+                panic!("condvar should not have an rwlock on unix");
+            }
         }
 
         Ok(0)
@@ -731,7 +738,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         let active_thread = this.get_active_thread();
 
         release_cond_mutex_and_block(this, active_thread, mutex_id)?;
-        this.condvar_wait(id, active_thread, mutex_id.to_u32(), false);
+        this.condvar_wait(id, active_thread, CondvarLock::Mutex(mutex_id));
 
         Ok(0)
     }
@@ -770,7 +777,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         };
 
         release_cond_mutex_and_block(this, active_thread, mutex_id)?;
-        this.condvar_wait(id, active_thread, mutex_id.to_u32(), false);
+        this.condvar_wait(id, active_thread, CondvarLock::Mutex(mutex_id));
 
         // We return success for now and override it in the timeout callback.
         this.write_scalar(Scalar::from_i32(0), dest)?;
diff --git a/src/tools/miri/src/shims/windows/sync.rs b/src/tools/miri/src/shims/windows/sync.rs
index 2eab1794c4f..a34d142f4a5 100644
--- a/src/tools/miri/src/shims/windows/sync.rs
+++ b/src/tools/miri/src/shims/windows/sync.rs
@@ -3,6 +3,7 @@ use std::time::Duration;
 use rustc_target::abi::Size;
 
 use crate::concurrency::init_once::InitOnceStatus;
+use crate::concurrency::sync::{CondvarLock, RwLockMode};
 use crate::concurrency::thread::MachineCallback;
 use crate::*;
 
@@ -18,23 +19,24 @@ pub trait EvalContextExtPriv<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tc
         &mut self,
         thread: ThreadId,
         lock: RwLockId,
-        shared: bool,
+        mode: RwLockMode,
     ) -> InterpResult<'tcx> {
         let this = self.eval_context_mut();
         this.unblock_thread(thread);
 
-        if shared {
-            if this.rwlock_is_locked(lock) {
-                this.rwlock_enqueue_and_block_reader(lock, thread);
-            } else {
-                this.rwlock_reader_lock(lock, thread);
-            }
-        } else {
-            if this.rwlock_is_write_locked(lock) {
-                this.rwlock_enqueue_and_block_writer(lock, thread);
-            } else {
-                this.rwlock_writer_lock(lock, thread);
-            }
+        match mode {
+            RwLockMode::Read =>
+                if this.rwlock_is_locked(lock) {
+                    this.rwlock_enqueue_and_block_reader(lock, thread);
+                } else {
+                    this.rwlock_reader_lock(lock, thread);
+                },
+            RwLockMode::Write =>
+                if this.rwlock_is_write_locked(lock) {
+                    this.rwlock_enqueue_and_block_writer(lock, thread);
+                } else {
+                    this.rwlock_writer_lock(lock, thread);
+                },
         }
 
         Ok(())
@@ -383,14 +385,19 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         };
 
         let shared_mode = 0x1; // CONDITION_VARIABLE_LOCKMODE_SHARED is not in std
-        let shared = flags == shared_mode;
+        let mode = if flags == 0 {
+            RwLockMode::Write
+        } else if flags == shared_mode {
+            RwLockMode::Read
+        } else {
+            throw_unsup_format!("unsupported `Flags` {flags} in `SleepConditionVariableSRW`");
+        };
 
         let active_thread = this.get_active_thread();
 
-        let was_locked = if shared {
-            this.rwlock_reader_unlock(lock_id, active_thread)
-        } else {
-            this.rwlock_writer_unlock(lock_id, active_thread)
+        let was_locked = match mode {
+            RwLockMode::Read => this.rwlock_reader_unlock(lock_id, active_thread),
+            RwLockMode::Write => this.rwlock_writer_unlock(lock_id, active_thread),
         };
 
         if !was_locked {
@@ -400,27 +407,27 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         }
 
         this.block_thread(active_thread);
-        this.condvar_wait(condvar_id, active_thread, lock_id.to_u32(), shared);
+        this.condvar_wait(condvar_id, active_thread, CondvarLock::RwLock { id: lock_id, mode });
 
         if let Some(timeout_time) = timeout_time {
             struct Callback<'tcx> {
                 thread: ThreadId,
                 condvar_id: CondvarId,
                 lock_id: RwLockId,
-                shared: bool,
+                mode: RwLockMode,
                 dest: PlaceTy<'tcx, Provenance>,
             }
 
             impl<'tcx> VisitTags for Callback<'tcx> {
                 fn visit_tags(&self, visit: &mut dyn FnMut(SbTag)) {
-                    let Callback { thread: _, condvar_id: _, lock_id: _, shared: _, dest } = self;
+                    let Callback { thread: _, condvar_id: _, lock_id: _, mode: _, dest } = self;
                     dest.visit_tags(visit);
                 }
             }
 
             impl<'mir, 'tcx: 'mir> MachineCallback<'mir, 'tcx> for Callback<'tcx> {
                 fn call(&self, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
-                    this.reacquire_cond_lock(self.thread, self.lock_id, self.shared)?;
+                    this.reacquire_cond_lock(self.thread, self.lock_id, self.mode)?;
 
                     this.condvar_remove_waiter(self.condvar_id, self.thread);
 
@@ -438,7 +445,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
                     thread: active_thread,
                     condvar_id,
                     lock_id,
-                    shared,
+                    mode,
                     dest: dest.clone(),
                 }),
             );
@@ -451,9 +458,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         let this = self.eval_context_mut();
         let condvar_id = this.condvar_get_or_create_id(condvar_op, CONDVAR_ID_OFFSET)?;
 
-        if let Some((thread, lock, shared)) = this.condvar_signal(condvar_id) {
-            this.reacquire_cond_lock(thread, RwLockId::from_u32(lock), shared)?;
-            this.unregister_timeout_callback_if_exists(thread);
+        if let Some((thread, lock)) = this.condvar_signal(condvar_id) {
+            if let CondvarLock::RwLock { id, mode } = lock {
+                this.reacquire_cond_lock(thread, id, mode)?;
+                this.unregister_timeout_callback_if_exists(thread);
+            } else {
+                panic!("mutexes should not exist on windows");
+            }
         }
 
         Ok(())
@@ -466,9 +477,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
         let this = self.eval_context_mut();
         let condvar_id = this.condvar_get_or_create_id(condvar_op, CONDVAR_ID_OFFSET)?;
 
-        while let Some((thread, lock, shared)) = this.condvar_signal(condvar_id) {
-            this.reacquire_cond_lock(thread, RwLockId::from_u32(lock), shared)?;
-            this.unregister_timeout_callback_if_exists(thread);
+        while let Some((thread, lock)) = this.condvar_signal(condvar_id) {
+            if let CondvarLock::RwLock { id, mode } = lock {
+                this.reacquire_cond_lock(thread, id, mode)?;
+                this.unregister_timeout_callback_if_exists(thread);
+            } else {
+                panic!("mutexes should not exist on windows");
+            }
         }
 
         Ok(())