Diffstat (limited to 'compiler/rustc_thread_pool/src/scope/mod.rs')
 compiler/rustc_thread_pool/src/scope/mod.rs | 76 ++++++++++++++++++-------
 1 file changed, 57 insertions(+), 19 deletions(-)
diff --git a/compiler/rustc_thread_pool/src/scope/mod.rs b/compiler/rustc_thread_pool/src/scope/mod.rs
index 55e58b3509d..b6601d0cbcc 100644
--- a/compiler/rustc_thread_pool/src/scope/mod.rs
+++ b/compiler/rustc_thread_pool/src/scope/mod.rs
@@ -8,12 +8,14 @@
 use std::any::Any;
 use std::marker::PhantomData;
 use std::mem::ManuallyDrop;
-use std::sync::Arc;
 use std::sync::atomic::{AtomicPtr, Ordering};
+use std::sync::{Arc, Mutex};
 use std::{fmt, ptr};
 
+use indexmap::IndexSet;
+
 use crate::broadcast::BroadcastContext;
-use crate::job::{ArcJob, HeapJob, JobFifo, JobRef};
+use crate::job::{ArcJob, HeapJob, JobFifo, JobRef, JobRefId};
 use crate::latch::{CountLatch, Latch};
 use crate::registry::{Registry, WorkerThread, global_registry, in_worker};
 use crate::tlv::{self, Tlv};
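The new imports carry the whole mechanism: pending jobs are tracked in an `indexmap::IndexSet` behind a `Mutex`. An `IndexSet` gives set semantics plus O(1) removal via `swap_remove`, which swaps the removed value with the last element instead of shifting. A minimal standalone sketch of the operations the patch relies on (requires only the `indexmap` crate imported above):

```rust
use indexmap::IndexSet;

fn main() {
    let mut pending: IndexSet<u64> = IndexSet::new();
    pending.insert(1);
    pending.insert(2);
    pending.insert(3);

    // O(1): removes 1 by swapping the last element (3) into its slot.
    // Returns true if the value was present.
    assert!(pending.swap_remove(&1));
    assert!(!pending.swap_remove(&1)); // already gone

    // swap_remove_full additionally reports the index the value held.
    assert_eq!(pending.swap_remove_full(&2), Some((1, 2)));
    assert_eq!(pending.len(), 1);
}
```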
@@ -52,6 +54,12 @@ struct ScopeBase<'scope> {
     /// latch to track job counts
     job_completed_latch: CountLatch,
 
+    /// Jobs that have been spawned, but not yet started.
+    pending_jobs: Mutex<IndexSet<JobRefId>>,
+
+    /// The worker which will wait on scope completion, if any.
+    worker: Option<usize>,
+
     /// You can think of a scope as containing a list of closures to execute,
     /// all of which outlive `'scope`. They're not actually required to be
     /// `Sync`, but it's still safe to let the `Scope` implement `Sync` because
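Taken together, the two new fields let a scope distinguish "spawned" from "started" jobs and remember which worker, if any, blocks on scope completion. A reduced model of the protocol, with `JobId` standing in for the real `JobRefId`:

```rust
use std::sync::Mutex;

use indexmap::IndexSet;

type JobId = usize; // stand-in for the real `JobRefId`

struct PendingModel {
    /// Jobs that have been spawned, but not yet started.
    pending_jobs: Mutex<IndexSet<JobId>>,
    /// The worker which will wait on scope completion, if any.
    worker: Option<usize>,
}

impl PendingModel {
    fn mark_pending(&self, id: JobId) {
        self.pending_jobs.lock().unwrap().insert(id);
    }

    fn mark_started(&self, id: JobId) {
        self.pending_jobs.lock().unwrap().swap_remove(&id);
    }

    /// True once every spawned job has begun executing somewhere.
    fn all_started(&self) -> bool {
        self.pending_jobs.lock().unwrap().is_empty()
    }
}

fn main() {
    let model = PendingModel { pending_jobs: Mutex::new(IndexSet::new()), worker: Some(0) };
    assert_eq!(model.worker, Some(0));
    model.mark_pending(7);
    assert!(!model.all_started());
    model.mark_started(7);
    assert!(model.all_started());
}
```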
@@ -525,13 +533,19 @@ impl<'scope> Scope<'scope> {
         BODY: FnOnce(&Scope<'scope>) + Send + 'scope,
     {
         let scope_ptr = ScopePtr(self);
-        let job = HeapJob::new(self.base.tlv, move || unsafe {
+        let job = HeapJob::new(self.base.tlv, move |id| unsafe {
             // SAFETY: this job will execute before the scope ends.
             let scope = scope_ptr.as_ref();
+
+            // Mark this job as started.
+            scope.base.pending_jobs.lock().unwrap().swap_remove_full(&id);
+
             ScopeBase::execute_job(&scope.base, move || body(scope))
         });
         let job_ref = self.base.heap_job_ref(job);
 
+        // Mark this job as pending.
+        self.base.pending_jobs.lock().unwrap().insert(job_ref.id());
         // Since `Scope` implements `Sync`, we can't be sure that we're still in a
         // thread of this pool, so we can't just push to the local worker thread.
         // Also, this might be an in-place scope.
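Note the ordering in `spawn`: the id goes into `pending_jobs` before `inject_or_push` publishes the job, so whichever thread picks the job up is guaranteed to find the entry it is supposed to remove. A standalone model of that register-before-publish handshake:

```rust
use std::sync::{Arc, Mutex};
use std::thread;

use indexmap::IndexSet;

fn main() {
    let pending = Arc::new(Mutex::new(IndexSet::new()));
    let id = 42usize;

    // Register first ("mark this job as pending")...
    pending.lock().unwrap().insert(id);

    // ...then publish; the executing thread always sees the entry.
    let seen_by_worker = Arc::clone(&pending);
    thread::spawn(move || {
        // "mark this job as started"
        assert!(seen_by_worker.lock().unwrap().swap_remove(&id));
    })
    .join()
    .unwrap();

    assert!(pending.lock().unwrap().is_empty());
}
```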
@@ -547,10 +561,17 @@ impl<'scope> Scope<'scope> {
         BODY: Fn(&Scope<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
     {
         let scope_ptr = ScopePtr(self);
-        let job = ArcJob::new(move || unsafe {
+        let job = ArcJob::new(move |id| unsafe {
             // SAFETY: this job will execute before the scope ends.
             let scope = scope_ptr.as_ref();
             let body = &body;
+
+            let current_index = WorkerThread::current().as_ref().map(|worker| worker.index());
+            if current_index == scope.base.worker {
+                // Mark this job as started on the scope's worker thread.
+                scope.base.pending_jobs.lock().unwrap().swap_remove(&id);
+            }
+
             let func = move || BroadcastContext::with(move |ctx| body(scope, ctx));
             ScopeBase::execute_job(&scope.base, func)
         });
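A broadcast job runs once on every worker, but only the copy that lands on the scope's waiting worker clears the pending entry; `inject_broadcast` further down inserts that entry exactly once. A reduced model of the owner check:

```rust
use std::sync::Mutex;

use indexmap::IndexSet;

// `current` is the executing worker's index (None if outside the pool);
// `scope_worker` mirrors the new `ScopeBase::worker` field.
fn broadcast_copy_started(
    current: Option<usize>,
    scope_worker: Option<usize>,
    pending: &Mutex<IndexSet<usize>>,
    id: usize,
) {
    if current == scope_worker {
        // Only the scope's waiting worker removes the entry.
        pending.lock().unwrap().swap_remove(&id);
    }
}

fn main() {
    let pending = Mutex::new(IndexSet::new());
    pending.lock().unwrap().insert(7usize);

    broadcast_copy_started(Some(1), Some(0), &pending, 7); // other worker: no-op
    assert!(pending.lock().unwrap().contains(&7));

    broadcast_copy_started(Some(0), Some(0), &pending, 7); // owner: clears it
    assert!(pending.lock().unwrap().is_empty());
}
```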
@@ -585,23 +606,24 @@ impl<'scope> ScopeFifo<'scope> {
         BODY: FnOnce(&ScopeFifo<'scope>) + Send + 'scope,
     {
         let scope_ptr = ScopePtr(self);
-        let job = HeapJob::new(self.base.tlv, move || unsafe {
+        let job = HeapJob::new(self.base.tlv, move |id| unsafe {
             // SAFETY: this job will execute before the scope ends.
             let scope = scope_ptr.as_ref();
+
+            // Mark this job as started.
+            scope.base.pending_jobs.lock().unwrap().swap_remove(&id);
+
             ScopeBase::execute_job(&scope.base, move || body(scope))
         });
         let job_ref = self.base.heap_job_ref(job);
 
-        // If we're in the pool, use our scope's private fifo for this thread to execute
-        // in a locally-FIFO order. Otherwise, just use the pool's global injector.
-        match self.base.registry.current_thread() {
-            Some(worker) => {
-                let fifo = &self.fifos[worker.index()];
-                // SAFETY: this job will execute before the scope ends.
-                unsafe { worker.push(fifo.push(job_ref)) };
-            }
-            None => self.base.registry.inject(job_ref),
-        }
+        // Mark this job as pending.
+        self.base.pending_jobs.lock().unwrap().insert(job_ref.id());
+
+        // Since `ScopeFifo` implements `Sync`, we can't be sure that we're still in a
+        // thread of this pool, so we can't just push to the local worker thread.
+        // Also, this might be an in-place scope.
+        self.base.registry.inject_or_push(job_ref);
     }
 
     /// Spawns a job into every thread of the fork-join scope `self`. This job will
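This hunk changes scheduling as well as bookkeeping: `spawn_fifo` no longer routes through the scope's per-thread FIFO queues (which, per the removed comment, gave locally-FIFO execution order) and instead takes the same `inject_or_push` path as `Scope::spawn`. For reference, a plausible shape of `inject_or_push`, sketched with stand-in types since the real implementation lives in `registry.rs`: push to the local worker deque when already on a worker of this pool, otherwise inject into the global queue.

```rust
// Stand-ins for the internal types, to keep the sketch self-contained.
struct JobRef;
struct Registry {
    id: usize,
}
struct WorkerThread {
    registry_id: usize,
}

impl Registry {
    // Hypothetical mirror of `Registry::inject_or_push`: prefer the
    // local worker deque when called from inside this same pool.
    fn inject_or_push(&self, current: Option<&WorkerThread>, job: JobRef) {
        match current {
            Some(worker) if worker.registry_id == self.id => self.push_local(worker, job),
            _ => self.inject(job),
        }
    }

    fn push_local(&self, _worker: &WorkerThread, _job: JobRef) {
        println!("pushed to local worker deque");
    }

    fn inject(&self, _job: JobRef) {
        println!("injected into global queue");
    }
}

fn main() {
    let registry = Registry { id: 1 };
    let worker = WorkerThread { registry_id: 1 };
    registry.inject_or_push(Some(&worker), JobRef); // local push
    registry.inject_or_push(None, JobRef); // global inject
}
```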
@@ -613,9 +635,15 @@ impl<'scope> ScopeFifo<'scope> {
         BODY: Fn(&ScopeFifo<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
     {
         let scope_ptr = ScopePtr(self);
-        let job = ArcJob::new(move || unsafe {
+        let job = ArcJob::new(move |id| unsafe {
             // SAFETY: this job will execute before the scope ends.
             let scope = scope_ptr.as_ref();
+
+            let current_index = WorkerThread::current().as_ref().map(|worker| worker.index());
+            if current_index == scope.base.worker {
+                // Mark this job as started on the scope's worker thread.
+                scope.base.pending_jobs.lock().unwrap().swap_remove(&id);
+            }
             let body = &body;
             let func = move || BroadcastContext::with(move |ctx| body(scope, ctx));
             ScopeBase::execute_job(&scope.base, func)
@@ -636,6 +664,8 @@ impl<'scope> ScopeBase<'scope> {
             registry: Arc::clone(registry),
             panic: AtomicPtr::new(ptr::null_mut()),
             job_completed_latch: CountLatch::new(owner),
+            pending_jobs: Mutex::new(IndexSet::new()),
+            worker: owner.map(|w| w.index()),
             marker: PhantomData,
             tlv: tlv::get(),
         }
@@ -643,7 +673,7 @@ impl<'scope> ScopeBase<'scope> {
 
     fn heap_job_ref<FUNC>(&self, job: Box<HeapJob<FUNC>>) -> JobRef
     where
-        FUNC: FnOnce() + Send + 'scope,
+        FUNC: FnOnce(JobRefId) + Send + 'scope,
     {
         unsafe {
             self.job_completed_latch.increment();
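The `FnOnce()` to `FnOnce(JobRefId)` bound change runs through every job constructor: a job body now receives its own id when executed, which is what lets the closures above deregister themselves from `pending_jobs`. A minimal model of the pattern, with `JobId` again standing in for `JobRefId`:

```rust
type JobId = usize; // stand-in for the real `JobRefId`

// A job that is handed its own identifier on execution.
struct HeapJobModel<F: FnOnce(JobId)> {
    body: F,
}

impl<F: FnOnce(JobId)> HeapJobModel<F> {
    fn new(body: F) -> Self {
        Self { body }
    }

    fn execute(self, id: JobId) {
        (self.body)(id)
    }
}

fn main() {
    let job = HeapJobModel::new(|id| println!("job {id} is running and can deregister itself"));
    job.execute(7);
}
```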
@@ -653,8 +683,12 @@ impl<'scope> ScopeBase<'scope> {
 
     fn inject_broadcast<FUNC>(&self, job: Arc<ArcJob<FUNC>>)
     where
-        FUNC: Fn() + Send + Sync + 'scope,
+        FUNC: Fn(JobRefId) + Send + Sync + 'scope,
     {
+        if self.worker.is_some() {
+            let id = unsafe { ArcJob::as_job_ref(&job).id() };
+            self.pending_jobs.lock().unwrap().insert(id);
+        }
         let n_threads = self.registry.num_threads();
         let job_refs = (0..n_threads).map(|_| unsafe {
             self.job_completed_latch.increment();
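`inject_broadcast` increments the completion latch once per thread but inserts a single pending entry, and only when a worker is actually waiting (`self.worker.is_some()`); that one entry is what the owner check in the broadcast closures removes. A reduced model of the asymmetry:

```rust
use std::sync::Mutex;

use indexmap::IndexSet;

struct BroadcastModel {
    latch_count: Mutex<usize>,       // models CountLatch increments
    pending: Mutex<IndexSet<usize>>, // models pending_jobs
}

impl BroadcastModel {
    fn inject_broadcast(&self, id: usize, n_threads: usize, worker_waiting: bool) {
        // One pending entry for the whole broadcast, and only if
        // somebody will consult the set while waiting.
        if worker_waiting {
            self.pending.lock().unwrap().insert(id);
        }
        // But one latch increment per per-thread copy.
        *self.latch_count.lock().unwrap() += n_threads;
    }
}

fn main() {
    let model = BroadcastModel {
        latch_count: Mutex::new(0),
        pending: Mutex::new(IndexSet::new()),
    };
    model.inject_broadcast(7, 4, true);
    assert_eq!(*model.latch_count.lock().unwrap(), 4);
    assert_eq!(model.pending.lock().unwrap().len(), 1);
}
```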
@@ -671,7 +705,11 @@ impl<'scope> ScopeBase<'scope> {
         FUNC: FnOnce() -> R,
     {
         let result = unsafe { Self::execute_job_closure(self, func) };
-        self.job_completed_latch.wait(owner);
+        self.job_completed_latch.wait(
+            owner,
+            || self.pending_jobs.lock().unwrap().is_empty(),
+            |job| self.pending_jobs.lock().unwrap().contains(&job.id()),
+        );
 
         // Restore the TLV if we ran some jobs while waiting
         tlv::set(self.tlv);
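The call site shows `CountLatch::wait` gaining two closure arguments; from their bodies they read as "have all spawned jobs started?" and "is this particular job one of ours that is still pending?". How the latch actually uses them lives in `latch.rs`, so the following loop is only a hedged sketch of one plausible contract: keep executing work while the count is nonzero, and only block once every pending job has started somewhere.

```rust
// Hypothetical model of a latch wait loop parameterized the way the
// call site suggests; names and structure are assumptions, not the
// real `CountLatch` implementation.
fn wait_model(
    counter_is_zero: impl Fn() -> bool,
    mut find_local_work: impl FnMut() -> Option<usize>, // hypothetical work source
    all_pending_started: impl Fn() -> bool,
    is_scope_job: impl Fn(usize) -> bool,
) {
    while !counter_is_zero() {
        if let Some(job) = find_local_work() {
            // Work we still hold locally may include our own scope's
            // not-yet-started jobs; those must run before blocking.
            let _keeps_scope_alive = is_scope_job(job);
            // ...execute `job` here...
        } else if all_pending_started() {
            // Every spawned job has started elsewhere, so it is safe
            // to sleep until the completion count reaches zero.
            std::thread::yield_now(); // stand-in for parking the thread
        }
    }
}

fn main() {
    // Trivial exercise of the model: count already zero, no work left.
    wait_model(|| true, || None, || true, |_| false);
}
```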