about summary refs log tree commit diff
path: root/src/libstd/sync/task_pool.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/libstd/sync/task_pool.rs')
-rw-r--r--src/libstd/sync/task_pool.rs230
1 files changed, 167 insertions, 63 deletions
diff --git a/src/libstd/sync/task_pool.rs b/src/libstd/sync/task_pool.rs
index d4a60fb5844..2682582d708 100644
--- a/src/libstd/sync/task_pool.rs
+++ b/src/libstd/sync/task_pool.rs
@@ -1,4 +1,4 @@
-// Copyright 2012 The Rust Project Developers. See the COPYRIGHT
+// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
 // file at the top-level directory of this distribution and at
 // http://rust-lang.org/COPYRIGHT.
 //
@@ -12,91 +12,195 @@
 
 use core::prelude::*;
 
-use task;
 use task::spawn;
-use vec::Vec;
-use comm::{channel, Sender};
+use comm::{channel, Sender, Receiver};
+use sync::{Arc, Mutex};
 
-enum Msg<T> {
-    Execute(proc(&T):Send),
-    Quit
+struct Sentinel<'a> {
+    jobs: &'a Arc<Mutex<Receiver<proc(): Send>>>,
+    active: bool
 }
 
-/// A task pool used to execute functions in parallel.
-pub struct TaskPool<T> {
-    channels: Vec<Sender<Msg<T>>>,
-    next_index: uint,
+impl<'a> Sentinel<'a> {
+    fn new(jobs: &Arc<Mutex<Receiver<proc(): Send>>>) -> Sentinel {
+        Sentinel {
+            jobs: jobs,
+            active: true
+        }
+    }
+
+    // Cancel and destroy this sentinel.
+    fn cancel(mut self) {
+        self.active = false;
+    }
 }
 
 #[unsafe_destructor]
-impl<T> Drop for TaskPool<T> {
+impl<'a> Drop for Sentinel<'a> {
     fn drop(&mut self) {
-        for channel in self.channels.iter_mut() {
-            channel.send(Quit);
+        if self.active {
+            spawn_in_pool(self.jobs.clone())
         }
     }
 }
 
-impl<T> TaskPool<T> {
-    /// Spawns a new task pool with `n_tasks` tasks. The provided
-    /// `init_fn_factory` returns a function which, given the index of the
-    /// task, should return local data to be kept around in that task.
+/// A task pool used to execute functions in parallel.
+///
+/// Spawns `n` worker tasks and replenishes the pool if any worker tasks
+/// panic.
+///
+/// # Example
+///
+/// ```rust
+/// # use sync::TaskPool;
+/// # use iter::AdditiveIterator;
+///
+/// let pool = TaskPool::new(4u);
+///
+/// let (tx, rx) = channel();
+/// for _ in range(0, 8u) {
+///     let tx = tx.clone();
+///     pool.execute(proc() {
+///         tx.send(1u);
+///     });
+/// }
+///
+/// assert_eq!(rx.iter().take(8u).sum(), 8u);
+/// ```
+pub struct TaskPool {
+    // How the taskpool communicates with subtasks.
+    //
+    // This is the only such Sender, so when it is dropped all subtasks will
+    // quit.
+    jobs: Sender<proc(): Send>
+}
+
+impl TaskPool {
+    /// Spawns a new task pool with `tasks` tasks.
     ///
     /// # Panics
     ///
-    /// This function will panic if `n_tasks` is less than 1.
-    pub fn new(n_tasks: uint,
-               init_fn_factory: || -> proc(uint):Send -> T)
-               -> TaskPool<T> {
-        assert!(n_tasks >= 1);
-
-        let channels = Vec::from_fn(n_tasks, |i| {
-            let (tx, rx) = channel::<Msg<T>>();
-            let init_fn = init_fn_factory();
-
-            let task_body = proc() {
-                let local_data = init_fn(i);
-                loop {
-                    match rx.recv() {
-                        Execute(f) => f(&local_data),
-                        Quit => break
-                    }
-                }
-            };
+    /// This function will panic if `tasks` is 0.
+    pub fn new(tasks: uint) -> TaskPool {
+        assert!(tasks >= 1);
 
-            // Run on this scheduler.
-            task::spawn(task_body);
+        let (tx, rx) = channel::<proc(): Send>();
+        let rx = Arc::new(Mutex::new(rx));
 
-            tx
-        });
+        // Taskpool tasks.
+        for _ in range(0, tasks) {
+            spawn_in_pool(rx.clone());
+        }
 
-        return TaskPool {
-            channels: channels,
-            next_index: 0,
-        };
+        TaskPool { jobs: tx }
     }
 
-    /// Executes the function `f` on a task in the pool. The function
-    /// receives a reference to the local data returned by the `init_fn`.
-    pub fn execute(&mut self, f: proc(&T):Send) {
-        self.channels[self.next_index].send(Execute(f));
-        self.next_index += 1;
-        if self.next_index == self.channels.len() { self.next_index = 0; }
+    /// Executes the function `job` on a task in the pool.
+    pub fn execute(&self, job: proc():Send) {
+        self.jobs.send(job);
     }
 }
 
-#[test]
-fn test_task_pool() {
-    let f: || -> proc(uint):Send -> uint = || { proc(i) i };
-    let mut pool = TaskPool::new(4, f);
-    for _ in range(0u, 8) {
-        pool.execute(proc(i) println!("Hello from thread {}!", *i));
-    }
+fn spawn_in_pool(jobs: Arc<Mutex<Receiver<proc(): Send>>>) {
+    spawn(proc() {
+        // Will spawn a new task on panic unless it is cancelled.
+        let sentinel = Sentinel::new(&jobs);
+
+        loop {
+            let message = {
+                // Only lock jobs for the time it takes
+                // to get a job, not run it.
+                let lock = jobs.lock();
+                lock.recv_opt()
+            };
+
+            match message {
+                Ok(job) => job(),
+
+                // The Taskpool was dropped.
+                Err(..) => break
+            }
+        }
+
+        sentinel.cancel();
+    })
 }
 
-#[test]
-#[should_fail]
-fn test_zero_tasks_panic() {
-    let f: || -> proc(uint):Send -> uint = || { proc(i) i };
-    TaskPool::new(0, f);
+#[cfg(test)]
+mod test {
+    use core::prelude::*;
+    use super::*;
+    use comm::channel;
+    use iter::range;
+
+    const TEST_TASKS: uint = 4u;
+
+    #[test]
+    fn test_works() {
+        use iter::AdditiveIterator;
+
+        let pool = TaskPool::new(TEST_TASKS);
+
+        let (tx, rx) = channel();
+        for _ in range(0, TEST_TASKS) {
+            let tx = tx.clone();
+            pool.execute(proc() {
+                tx.send(1u);
+            });
+        }
+
+        assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
+    }
+
+    #[test]
+    #[should_fail]
+    fn test_zero_tasks_panic() {
+        TaskPool::new(0);
+    }
+
+    #[test]
+    fn test_recovery_from_subtask_panic() {
+        use iter::AdditiveIterator;
+
+        let pool = TaskPool::new(TEST_TASKS);
+
+        // Panic all the existing tasks.
+        for _ in range(0, TEST_TASKS) {
+            pool.execute(proc() { panic!() });
+        }
+
+        // Ensure new tasks were spawned to compensate.
+        let (tx, rx) = channel();
+        for _ in range(0, TEST_TASKS) {
+            let tx = tx.clone();
+            pool.execute(proc() {
+                tx.send(1u);
+            });
+        }
+
+        assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
+    }
+
+    #[test]
+    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
+        use sync::{Arc, Barrier};
+
+        let pool = TaskPool::new(TEST_TASKS);
+        let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
+
+        // Panic all the existing tasks in a bit.
+        for _ in range(0, TEST_TASKS) {
+            let waiter = waiter.clone();
+            pool.execute(proc() {
+                waiter.wait();
+                panic!();
+            });
+        }
+
+        drop(pool);
+
+        // Kick off the failure.
+        waiter.wait();
+    }
 }
+