diff options
| author | Huon Wilson <dbau.pp+github@gmail.com> | 2013-10-12 02:15:22 +1100 |
|---|---|---|
| committer | Huon Wilson <dbau.pp+github@gmail.com> | 2013-10-23 10:40:06 +1100 |
| commit | 0bba73c0d156df8f22c64ef4f4c50910fe31cf31 (patch) | |
| tree | 3a0b98b2985fdf4b4717aecab31464718d7dc162 /src/libstd | |
| parent | 83aa1abb19ee290655ed03053e26cdb2662242c4 (diff) | |
| download | rust-0bba73c0d156df8f22c64ef4f4c50910fe31cf31.tar.gz rust-0bba73c0d156df8f22c64ef4f4c50910fe31cf31.zip | |
std::rand: move Weighted to distributions.
A user constructs the WeightedChoice distribution and then samples from it, which allows it to use binary search internally.
Diffstat (limited to 'src/libstd')
| -rw-r--r-- | src/libstd/rand/distributions.rs | 208 | ||||
| -rw-r--r-- | src/libstd/rand/mod.rs | 131 |
2 files changed, 207 insertions, 132 deletions
diff --git a/src/libstd/rand/distributions.rs b/src/libstd/rand/distributions.rs index 4a025ae05d7..e7bcf8ce5d3 100644 --- a/src/libstd/rand/distributions.rs +++ b/src/libstd/rand/distributions.rs @@ -20,8 +20,11 @@ that do not need to record state. */ +use iter::range; +use option::{Some, None}; use num; use rand::{Rng,Rand}; +use clone::Clone; pub use self::range::Range; @@ -61,8 +64,128 @@ impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> { } } -mod ziggurat_tables; +/// A value with a particular weight for use with `WeightedChoice`. +pub struct Weighted<T> { + /// The numerical weight of this item + weight: uint, + /// The actual item which is being weighted + item: T, +} + +/// A distribution that selects from a finite collection of weighted items. +/// +/// Each item has an associated weight that influences how likely it +/// is to be chosen: higher weight is more likely. +/// +/// The `Clone` restriction is a limitation of the `Sample` and +/// `IndependentSample` traits. Note that `&T` is (cheaply) `Clone` for +/// all `T`, as is `uint`, so one can store references or indices into +/// another vector. +/// +/// # Example +/// +/// ```rust +/// use std::rand; +/// use std::rand::distributions::{Weighted, WeightedChoice, IndependentSample}; +/// +/// fn main() { +/// let wc = WeightedChoice::new(~[Weighted { weight: 2, item: 'a' }, +/// Weighted { weight: 4, item: 'b' }, +/// Weighted { weight: 1, item: 'c' }]); +/// let rng = rand::task_rng(); +/// for _ in range(0, 16) { +/// // on average prints 'a' 4 times, 'b' 8 and 'c' twice. +/// println!("{}", wc.ind_sample(rng)); +/// } +/// } +/// ``` +pub struct WeightedChoice<T> { + priv items: ~[Weighted<T>], + priv weight_range: Range<uint> +} + +impl<T: Clone> WeightedChoice<T> { + /// Create a new `WeightedChoice`. + /// + /// Fails if: + /// - `items` is empty + /// - the total weight is 0 + /// - the total weight is larger than a `uint` can contain. 
+ pub fn new(mut items: ~[Weighted<T>]) -> WeightedChoice<T> { + // strictly speaking, this is subsumed by the total weight == 0 case + assert!(!items.is_empty(), "WeightedChoice::new called with no items"); + + let mut running_total = 0u; + + // we convert the list from individual weights to cumulative + // weights so we can binary search. This *could* drop elements + // with weight == 0 as an optimisation. + for item in items.mut_iter() { + running_total = running_total.checked_add(&item.weight) + .expect("WeightedChoice::new called with a total weight larger \ + than a uint can contain"); + + item.weight = running_total; + } + assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0"); + + WeightedChoice { + items: items, + // we're likely to be generating numbers in this range + // relatively often, so might as well cache it + weight_range: Range::new(0, running_total) + } + } +} + +impl<T: Clone> Sample<T> for WeightedChoice<T> { + fn sample<R: Rng>(&mut self, rng: &mut R) -> T { self.ind_sample(rng) } +} +impl<T: Clone> IndependentSample<T> for WeightedChoice<T> { + fn ind_sample<R: Rng>(&self, rng: &mut R) -> T { + // we want to find the first element that has cumulative + // weight > sample_weight, which we do by binary search since the + // cumulative weights of self.items are sorted. + + // choose a weight in [0, total_weight) + let sample_weight = self.weight_range.ind_sample(rng); + + // short circuit when it's the first item + if sample_weight < self.items[0].weight { + return self.items[0].item.clone(); + } + + let mut idx = 0; + let mut modifier = self.items.len(); + + // now we know that every possibility has an element to the + // left, so we can just search for the last element that has + // cumulative weight <= sample_weight, then the next one will + // be "it". 
(Note that this greatest element will never be the + // last element of the vector, since sample_weight is chosen + // in [0, total_weight) and the cumulative weight of the last + // one is exactly the total weight.) + while modifier > 1 { + let i = idx + modifier / 2; + if self.items[i].weight <= sample_weight { + // we're small, so look to the right, but allow this + // exact element still. + idx = i; + // we need the `/ 2` to round up otherwise we'll drop + // the trailing elements when `modifier` is odd. + modifier += 1; + } else { + // otherwise we're too big, so go left. (i.e. do + // nothing) + } + modifier /= 2; + } + return self.items[idx + 1].item.clone(); + } +} + +mod ziggurat_tables; /// Sample a random number using the Ziggurat method (specifically the /// ZIGNOR variant from Doornik 2005). Most of the arguments are @@ -302,6 +425,18 @@ mod tests { } } + // 0, 1, 2, 3, ... + struct CountingRng { i: u32 } + impl Rng for CountingRng { + fn next_u32(&mut self) -> u32 { + self.i += 1; + self.i - 1 + } + fn next_u64(&mut self) -> u64 { + self.next_u32() as u64 + } + } + #[test] fn test_rand_sample() { let mut rand_sample = RandSample::<ConstRand>; @@ -344,6 +479,77 @@ mod tests { fn test_exp_invalid_lambda_neg() { Exp::new(-10.0); } + + #[test] + fn test_weighted_choice() { + // this makes assumptions about the internal implementation of + // WeightedChoice, specifically: it doesn't reorder the items, + // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to + // 1, internally; modulo a modulo operation). + + macro_rules! 
t ( + ($items:expr, $expected:expr) => {{ + let wc = WeightedChoice::new($items); + let expected = $expected; + + let mut rng = CountingRng { i: 0 }; + + for &val in expected.iter() { + assert_eq!(wc.ind_sample(&mut rng), val) + } + }} + ); + + t!(~[Weighted { weight: 1, item: 10}], ~[10]); + + // skip some + t!(~[Weighted { weight: 0, item: 20}, + Weighted { weight: 2, item: 21}, + Weighted { weight: 0, item: 22}, + Weighted { weight: 1, item: 23}], + ~[21,21, 23]); + + // different weights + t!(~[Weighted { weight: 4, item: 30}, + Weighted { weight: 3, item: 31}], + ~[30,30,30,30, 31,31,31]); + + // check that we're binary searching + // correctly with some vectors of odd + // length. + t!(~[Weighted { weight: 1, item: 40}, + Weighted { weight: 1, item: 41}, + Weighted { weight: 1, item: 42}, + Weighted { weight: 1, item: 43}, + Weighted { weight: 1, item: 44}], + ~[40, 41, 42, 43, 44]); + t!(~[Weighted { weight: 1, item: 50}, + Weighted { weight: 1, item: 51}, + Weighted { weight: 1, item: 52}, + Weighted { weight: 1, item: 53}, + Weighted { weight: 1, item: 54}, + Weighted { weight: 1, item: 55}, + Weighted { weight: 1, item: 56}], + ~[50, 51, 52, 53, 54, 55, 56]); + } + + #[test] #[should_fail] + fn test_weighted_choice_no_items() { + WeightedChoice::<int>::new(~[]); + } + #[test] #[should_fail] + fn test_weighted_choice_zero_weight() { + WeightedChoice::new(~[Weighted { weight: 0, item: 0}, + Weighted { weight: 0, item: 1}]); + } + #[test] #[should_fail] + fn test_weighted_choice_weight_overflows() { + let x = (-1) as uint / 2; // x + x + 2 is the overflow + WeightedChoice::new(~[Weighted { weight: x, item: 0 }, + Weighted { weight: 1, item: 1 }, + Weighted { weight: x, item: 2 }, + Weighted { weight: 1, item: 3 }]); + } } #[cfg(test)] diff --git a/src/libstd/rand/mod.rs b/src/libstd/rand/mod.rs index 2d2c6e794f8..a372eb1f11a 100644 --- a/src/libstd/rand/mod.rs +++ b/src/libstd/rand/mod.rs @@ -100,14 +100,6 @@ pub trait Rand { fn rand<R: Rng>(rng: &mut R) -> 
Self; } -/// A value with a particular weight compared to other values -pub struct Weighted<T> { - /// The numerical weight of this item - weight: uint, - /// The actual item which is being weighted - item: T, -} - /// A random number generator pub trait Rng { /// Return the next random u32. This rarely needs to be called @@ -334,91 +326,6 @@ pub trait Rng { } } - /// Choose an item respecting the relative weights, failing if the sum of - /// the weights is 0 - /// - /// # Example - /// - /// ```rust - /// use std::rand; - /// use std::rand::Rng; - /// - /// fn main() { - /// let mut rng = rand::rng(); - /// let x = [rand::Weighted {weight: 4, item: 'a'}, - /// rand::Weighted {weight: 2, item: 'b'}, - /// rand::Weighted {weight: 2, item: 'c'}]; - /// println!("{}", rng.choose_weighted(x)); - /// } - /// ``` - fn choose_weighted<T:Clone>(&mut self, v: &[Weighted<T>]) -> T { - self.choose_weighted_option(v).expect("Rng.choose_weighted: total weight is 0") - } - - /// Choose Some(item) respecting the relative weights, returning none if - /// the sum of the weights is 0 - /// - /// # Example - /// - /// ```rust - /// use std::rand; - /// use std::rand::Rng; - /// - /// fn main() { - /// let mut rng = rand::rng(); - /// let x = [rand::Weighted {weight: 4, item: 'a'}, - /// rand::Weighted {weight: 2, item: 'b'}, - /// rand::Weighted {weight: 2, item: 'c'}]; - /// println!("{:?}", rng.choose_weighted_option(x)); - /// } - /// ``` - fn choose_weighted_option<T:Clone>(&mut self, v: &[Weighted<T>]) - -> Option<T> { - let mut total = 0u; - for item in v.iter() { - total += item.weight; - } - if total == 0u { - return None; - } - let chosen = self.gen_range(0u, total); - let mut so_far = 0u; - for item in v.iter() { - so_far += item.weight; - if so_far > chosen { - return Some(item.item.clone()); - } - } - unreachable!(); - } - - /// Return a vec containing copies of the items, in order, where - /// the weight of the item determines how many copies there are - /// - /// # 
Example - /// - /// ```rust - /// use std::rand; - /// use std::rand::Rng; - /// - /// fn main() { - /// let mut rng = rand::rng(); - /// let x = [rand::Weighted {weight: 4, item: 'a'}, - /// rand::Weighted {weight: 2, item: 'b'}, - /// rand::Weighted {weight: 2, item: 'c'}]; - /// println!("{}", rng.weighted_vec(x)); - /// } - /// ``` - fn weighted_vec<T:Clone>(&mut self, v: &[Weighted<T>]) -> ~[T] { - let mut r = ~[]; - for item in v.iter() { - for _ in range(0u, item.weight) { - r.push(item.item.clone()); - } - } - r - } - /// Shuffle a vec /// /// # Example @@ -861,44 +768,6 @@ mod test { } #[test] - fn test_choose_weighted() { - let mut r = rng(); - assert!(r.choose_weighted([ - Weighted { weight: 1u, item: 42 }, - ]) == 42); - assert!(r.choose_weighted([ - Weighted { weight: 0u, item: 42 }, - Weighted { weight: 1u, item: 43 }, - ]) == 43); - } - - #[test] - fn test_choose_weighted_option() { - let mut r = rng(); - assert!(r.choose_weighted_option([ - Weighted { weight: 1u, item: 42 }, - ]) == Some(42)); - assert!(r.choose_weighted_option([ - Weighted { weight: 0u, item: 42 }, - Weighted { weight: 1u, item: 43 }, - ]) == Some(43)); - let v: Option<int> = r.choose_weighted_option([]); - assert!(v.is_none()); - } - - #[test] - fn test_weighted_vec() { - let mut r = rng(); - let empty: ~[int] = ~[]; - assert_eq!(r.weighted_vec([]), empty); - assert!(r.weighted_vec([ - Weighted { weight: 0u, item: 3u }, - Weighted { weight: 1u, item: 2u }, - Weighted { weight: 2u, item: 1u }, - ]) == ~[2u, 1u, 1u]); - } - - #[test] fn test_shuffle() { let mut r = rng(); let empty: ~[int] = ~[]; |
