diff options
Diffstat (limited to 'src/libstd/sync/task_pool.rs')
| -rw-r--r-- | src/libstd/sync/task_pool.rs | 230 |
1 files changed, 167 insertions, 63 deletions
diff --git a/src/libstd/sync/task_pool.rs b/src/libstd/sync/task_pool.rs index d4a60fb5844..2682582d708 100644 --- a/src/libstd/sync/task_pool.rs +++ b/src/libstd/sync/task_pool.rs @@ -1,4 +1,4 @@ -// Copyright 2012 The Rust Project Developers. See the COPYRIGHT +// Copyright 2014 The Rust Project Developers. See the COPYRIGHT // file at the top-level directory of this distribution and at // http://rust-lang.org/COPYRIGHT. // @@ -12,91 +12,195 @@ use core::prelude::*; -use task; use task::spawn; -use vec::Vec; -use comm::{channel, Sender}; +use comm::{channel, Sender, Receiver}; +use sync::{Arc, Mutex}; -enum Msg<T> { - Execute(proc(&T):Send), - Quit +struct Sentinel<'a> { + jobs: &'a Arc<Mutex<Receiver<proc(): Send>>>, + active: bool } -/// A task pool used to execute functions in parallel. -pub struct TaskPool<T> { - channels: Vec<Sender<Msg<T>>>, - next_index: uint, +impl<'a> Sentinel<'a> { + fn new(jobs: &Arc<Mutex<Receiver<proc(): Send>>>) -> Sentinel { + Sentinel { + jobs: jobs, + active: true + } + } + + // Cancel and destroy this sentinel. + fn cancel(mut self) { + self.active = false; + } } #[unsafe_destructor] -impl<T> Drop for TaskPool<T> { +impl<'a> Drop for Sentinel<'a> { fn drop(&mut self) { - for channel in self.channels.iter_mut() { - channel.send(Quit); + if self.active { + spawn_in_pool(self.jobs.clone()) } } } -impl<T> TaskPool<T> { - /// Spawns a new task pool with `n_tasks` tasks. The provided - /// `init_fn_factory` returns a function which, given the index of the - /// task, should return local data to be kept around in that task. +/// A task pool used to execute functions in parallel. +/// +/// Spawns `n` worker tasks and replenishes the pool if any worker tasks +/// panic. +/// +/// # Example +/// +/// ```rust +/// # use sync::TaskPool; +/// # use iter::AdditiveIterator; +/// +/// let pool = TaskPool::new(4u); +/// +/// let (tx, rx) = channel(); +/// for _ in range(0, 8u) { +/// let tx = tx.clone(); +/// pool.execute(proc() { +/// tx.send(1u); +/// }); +/// } +/// +/// assert_eq!(rx.iter().take(8u).sum(), 8u); +/// ``` +pub struct TaskPool { + // How the taskpool communicates with subtasks. + // + // This is the only such Sender, so when it is dropped all subtasks will + // quit. + jobs: Sender<proc(): Send> +} + +impl TaskPool { + /// Spawns a new task pool with `tasks` tasks. /// /// # Panics /// - /// This function will panic if `n_tasks` is less than 1. - pub fn new(n_tasks: uint, - init_fn_factory: || -> proc(uint):Send -> T) - -> TaskPool<T> { - assert!(n_tasks >= 1); - - let channels = Vec::from_fn(n_tasks, |i| { - let (tx, rx) = channel::<Msg<T>>(); - let init_fn = init_fn_factory(); - - let task_body = proc() { - let local_data = init_fn(i); - loop { - match rx.recv() { - Execute(f) => f(&local_data), - Quit => break - } - } - }; + /// This function will panic if `tasks` is 0. + pub fn new(tasks: uint) -> TaskPool { + assert!(tasks >= 1); - // Run on this scheduler. - task::spawn(task_body); + let (tx, rx) = channel::<proc(): Send>(); + let rx = Arc::new(Mutex::new(rx)); - tx - }); + // Taskpool tasks. + for _ in range(0, tasks) { + spawn_in_pool(rx.clone()); + } - return TaskPool { - channels: channels, - next_index: 0, - }; + TaskPool { jobs: tx } } - /// Executes the function `f` on a task in the pool. The function - /// receives a reference to the local data returned by the `init_fn`. - pub fn execute(&mut self, f: proc(&T):Send) { - self.channels[self.next_index].send(Execute(f)); - self.next_index += 1; - if self.next_index == self.channels.len() { self.next_index = 0; } + /// Executes the function `job` on a task in the pool. + pub fn execute(&self, job: proc():Send) { + self.jobs.send(job); } } -#[test] -fn test_task_pool() { - let f: || -> proc(uint):Send -> uint = || { proc(i) i }; - let mut pool = TaskPool::new(4, f); - for _ in range(0u, 8) { - pool.execute(proc(i) println!("Hello from thread {}!", *i)); - } +fn spawn_in_pool(jobs: Arc<Mutex<Receiver<proc(): Send>>>) { + spawn(proc() { + // Will spawn a new task on panic unless it is cancelled. + let sentinel = Sentinel::new(&jobs); + + loop { + let message = { + // Only lock jobs for the time it takes + // to get a job, not run it. + let lock = jobs.lock(); + lock.recv_opt() + }; + + match message { + Ok(job) => job(), + + // The Taskpool was dropped. + Err(..) => break + } + } + + sentinel.cancel(); + }) } -#[test] -#[should_fail] -fn test_zero_tasks_panic() { - let f: || -> proc(uint):Send -> uint = || { proc(i) i }; - TaskPool::new(0, f); +#[cfg(test)] +mod test { + use core::prelude::*; + use super::*; + use comm::channel; + use iter::range; + + const TEST_TASKS: uint = 4u; + + #[test] + fn test_works() { + use iter::AdditiveIterator; + + let pool = TaskPool::new(TEST_TASKS); + + let (tx, rx) = channel(); + for _ in range(0, TEST_TASKS) { + let tx = tx.clone(); + pool.execute(proc() { + tx.send(1u); + }); + } + + assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS); + } + + #[test] + #[should_fail] + fn test_zero_tasks_panic() { + TaskPool::new(0); + } + + #[test] + fn test_recovery_from_subtask_panic() { + use iter::AdditiveIterator; + + let pool = TaskPool::new(TEST_TASKS); + + // Panic all the existing tasks. + for _ in range(0, TEST_TASKS) { + pool.execute(proc() { panic!() }); + } + + // Ensure new tasks were spawned to compensate. + let (tx, rx) = channel(); + for _ in range(0, TEST_TASKS) { + let tx = tx.clone(); + pool.execute(proc() { + tx.send(1u); + }); + } + + assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS); + } + + #[test] + fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() { + use sync::{Arc, Barrier}; + + let pool = TaskPool::new(TEST_TASKS); + let waiter = Arc::new(Barrier::new(TEST_TASKS + 1)); + + // Panic all the existing tasks in a bit. + for _ in range(0, TEST_TASKS) { + let waiter = waiter.clone(); + pool.execute(proc() { + waiter.wait(); + panic!(); + }); + } + + drop(pool); + + // Kick off the failure. + waiter.wait(); + } } + |
