about summary refs log tree commit diff
path: root/src/libstd
diff options
context:
space:
mode:
authorHuon Wilson <dbau.pp+github@gmail.com>2013-10-12 02:15:22 +1100
committerHuon Wilson <dbau.pp+github@gmail.com>2013-10-23 10:40:06 +1100
commit0bba73c0d156df8f22c64ef4f4c50910fe31cf31 (patch)
tree3a0b98b2985fdf4b4717aecab31464718d7dc162 /src/libstd
parent83aa1abb19ee290655ed03053e26cdb2662242c4 (diff)
downloadrust-0bba73c0d156df8f22c64ef4f4c50910fe31cf31.tar.gz
rust-0bba73c0d156df8f22c64ef4f4c50910fe31cf31.zip
std::rand: move Weighted to distributions.
A user constructs the WeightedChoice distribution and then samples from
it, which allows it to use binary search internally.
Diffstat (limited to 'src/libstd')
-rw-r--r--src/libstd/rand/distributions.rs208
-rw-r--r--src/libstd/rand/mod.rs131
2 files changed, 207 insertions, 132 deletions
diff --git a/src/libstd/rand/distributions.rs b/src/libstd/rand/distributions.rs
index 4a025ae05d7..e7bcf8ce5d3 100644
--- a/src/libstd/rand/distributions.rs
+++ b/src/libstd/rand/distributions.rs
@@ -20,8 +20,11 @@ that do not need to record state.
 
 */
 
+use iter::range;
+use option::{Some, None};
 use num;
 use rand::{Rng,Rand};
+use clone::Clone;
 
 pub use self::range::Range;
 
@@ -61,8 +64,128 @@ impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> {
     }
 }
 
-mod ziggurat_tables;
+/// A value with a particular weight for use with `WeightedChoice`.
+pub struct Weighted<T> {
+    /// The numerical weight of this item
+    weight: uint,
+    /// The actual item which is being weighted
+    item: T,
+}
+
+/// A distribution that selects from a finite collection of weighted items.
+///
+/// Each item has an associated weight that influences how likely it
+/// is to be chosen: higher weight is more likely.
+///
+/// The `Clone` restriction is a limitation of the `Sample` and
+/// `IndepedentSample` traits. Note that `&T` is (cheaply) `Clone` for
+/// all `T`, as is `uint`, so one can store references or indices into
+/// another vector.
+///
+/// # Example
+///
+/// ```rust
+/// use std::rand;
+/// use std::rand::distributions::{Weighted, WeightedChoice, IndepedentSample};
+///
+/// fn main() {
+///     let wc = WeightedChoice::new(~[Weighted { weight: 2, item: 'a' },
+///                                    Weighted { weight: 4, item: 'b' },
+///                                    Weighted { weight: 1, item: 'c' }]);
+///     let rng = rand::task_rng();
+///     for _ in range(0, 16) {
+///          // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
+///          println!("{}", wc.ind_sample(rng));
+///     }
+/// }
+/// ```
+pub struct WeightedChoice<T> {
+    priv items: ~[Weighted<T>],
+    priv weight_range: Range<uint>
+}
+
+impl<T: Clone> WeightedChoice<T> {
+    /// Create a new `WeightedChoice`.
+    ///
+    /// Fails if:
+    /// - `v` is empty
+    /// - the total weight is 0
+    /// - the total weight is larger than a `uint` can contain.
+    pub fn new(mut items: ~[Weighted<T>]) -> WeightedChoice<T> {
+        // strictly speaking, this is subsumed by the total weight == 0 case
+        assert!(!items.is_empty(), "WeightedChoice::new called with no items");
+
+        let mut running_total = 0u;
+
+        // we convert the list from individual weights to cumulative
+        // weights so we can binary search. This *could* drop elements
+        // with weight == 0 as an optimisation.
+        for item in items.mut_iter() {
+            running_total = running_total.checked_add(&item.weight)
+                .expect("WeightedChoice::new called with a total weight larger \
+                        than a uint can contain");
+
+            item.weight = running_total;
+        }
+        assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0");
+
+        WeightedChoice {
+            items: items,
+            // we're likely to be generating numbers in this range
+            // relatively often, so might as well cache it
+            weight_range: Range::new(0, running_total)
+        }
+    }
+}
+
+impl<T: Clone> Sample<T> for WeightedChoice<T> {
+    fn sample<R: Rng>(&mut self, rng: &mut R) -> T { self.ind_sample(rng) }
+}
 
+impl<T: Clone> IndependentSample<T> for WeightedChoice<T> {
+    fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
+        // we want to find the first element that has cumulative
+        // weight > sample_weight, which we do by binary since the
+        // cumulative weights of self.items are sorted.
+
+        // choose a weight in [0, total_weight)
+        let sample_weight = self.weight_range.ind_sample(rng);
+
+        // short circuit when it's the first item
+        if sample_weight < self.items[0].weight {
+            return self.items[0].item.clone();
+        }
+
+        let mut idx = 0;
+        let mut modifier = self.items.len();
+
+        // now we know that every possibility has an element to the
+        // left, so we can just search for the last element that has
+        // cumulative weight <= sample_weight, then the next one will
+        // be "it". (Note that this greatest element will never be the
+        // last element of the vector, since sample_weight is chosen
+        // in [0, total_weight) and the cumulative weight of the last
+        // one is exactly the total weight.)
+        while modifier > 1 {
+            let i = idx + modifier / 2;
+            if self.items[i].weight <= sample_weight {
+                // we're small, so look to the right, but allow this
+                // exact element still.
+                idx = i;
+                // we need the `/ 2` to round up otherwise we'll drop
+                // the trailing elements when `modifier` is odd.
+                modifier += 1;
+            } else {
+                // otherwise we're too big, so go left. (i.e. do
+                // nothing)
+            }
+            modifier /= 2;
+        }
+        return self.items[idx + 1].item.clone();
+    }
+}
+
+mod ziggurat_tables;
 
 /// Sample a random number using the Ziggurat method (specifically the
 /// ZIGNOR variant from Doornik 2005). Most of the arguments are
@@ -302,6 +425,18 @@ mod tests {
         }
     }
 
+    // 0, 1, 2, 3, ...
+    struct CountingRng { i: u32 }
+    impl Rng for CountingRng {
+        fn next_u32(&mut self) -> u32 {
+            self.i += 1;
+            self.i - 1
+        }
+        fn next_u64(&mut self) -> u64 {
+            self.next_u32() as u64
+        }
+    }
+
     #[test]
     fn test_rand_sample() {
         let mut rand_sample = RandSample::<ConstRand>;
@@ -344,6 +479,77 @@ mod tests {
     fn test_exp_invalid_lambda_neg() {
         Exp::new(-10.0);
     }
+
+    #[test]
+    fn test_weighted_choice() {
+        // this makes assumptions about the internal implementation of
+        // WeightedChoice, specifically: it doesn't reorder the items,
+        // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
+        // 1, internally; modulo a modulo operation).
+
+        macro_rules! t (
+            ($items:expr, $expected:expr) => {{
+                let wc = WeightedChoice::new($items);
+                let expected = $expected;
+
+                let mut rng = CountingRng { i: 0 };
+
+                for &val in expected.iter() {
+                    assert_eq!(wc.ind_sample(&mut rng), val)
+                }
+            }}
+        );
+
+        t!(~[Weighted { weight: 1, item: 10}], ~[10]);
+
+        // skip some
+        t!(~[Weighted { weight: 0, item: 20},
+             Weighted { weight: 2, item: 21},
+             Weighted { weight: 0, item: 22},
+             Weighted { weight: 1, item: 23}],
+           ~[21,21, 23]);
+
+        // different weights
+        t!(~[Weighted { weight: 4, item: 30},
+             Weighted { weight: 3, item: 31}],
+           ~[30,30,30,30, 31,31,31]);
+
+        // check that we're binary searching
+        // correctly with some vectors of odd
+        // length.
+        t!(~[Weighted { weight: 1, item: 40},
+             Weighted { weight: 1, item: 41},
+             Weighted { weight: 1, item: 42},
+             Weighted { weight: 1, item: 43},
+             Weighted { weight: 1, item: 44}],
+           ~[40, 41, 42, 43, 44]);
+        t!(~[Weighted { weight: 1, item: 50},
+             Weighted { weight: 1, item: 51},
+             Weighted { weight: 1, item: 52},
+             Weighted { weight: 1, item: 53},
+             Weighted { weight: 1, item: 54},
+             Weighted { weight: 1, item: 55},
+             Weighted { weight: 1, item: 56}],
+           ~[50, 51, 52, 53, 54, 55, 56]);
+    }
+
+    #[test] #[should_fail]
+    fn test_weighted_choice_no_items() {
+        WeightedChoice::<int>::new(~[]);
+    }
+    #[test] #[should_fail]
+    fn test_weighted_choice_zero_weight() {
+        WeightedChoice::new(~[Weighted { weight: 0, item: 0},
+                              Weighted { weight: 0, item: 1}]);
+    }
+    #[test] #[should_fail]
+    fn test_weighted_choice_weight_overflows() {
+        let x = (-1) as uint / 2; // x + x + 2 is the overflow
+        WeightedChoice::new(~[Weighted { weight: x, item: 0 },
+                              Weighted { weight: 1, item: 1 },
+                              Weighted { weight: x, item: 2 },
+                              Weighted { weight: 1, item: 3 }]);
+    }
 }
 
 #[cfg(test)]
diff --git a/src/libstd/rand/mod.rs b/src/libstd/rand/mod.rs
index 2d2c6e794f8..a372eb1f11a 100644
--- a/src/libstd/rand/mod.rs
+++ b/src/libstd/rand/mod.rs
@@ -100,14 +100,6 @@ pub trait Rand {
     fn rand<R: Rng>(rng: &mut R) -> Self;
 }
 
-/// A value with a particular weight compared to other values
-pub struct Weighted<T> {
-    /// The numerical weight of this item
-    weight: uint,
-    /// The actual item which is being weighted
-    item: T,
-}
-
 /// A random number generator
 pub trait Rng {
     /// Return the next random u32. This rarely needs to be called
@@ -334,91 +326,6 @@ pub trait Rng {
         }
     }
 
-    /// Choose an item respecting the relative weights, failing if the sum of
-    /// the weights is 0
-    ///
-    /// # Example
-    ///
-    /// ```rust
-    /// use std::rand;
-    /// use std::rand::Rng;
-    ///
-    /// fn main() {
-    ///     let mut rng = rand::rng();
-    ///     let x = [rand::Weighted {weight: 4, item: 'a'},
-    ///              rand::Weighted {weight: 2, item: 'b'},
-    ///              rand::Weighted {weight: 2, item: 'c'}];
-    ///     println!("{}", rng.choose_weighted(x));
-    /// }
-    /// ```
-    fn choose_weighted<T:Clone>(&mut self, v: &[Weighted<T>]) -> T {
-        self.choose_weighted_option(v).expect("Rng.choose_weighted: total weight is 0")
-    }
-
-    /// Choose Some(item) respecting the relative weights, returning none if
-    /// the sum of the weights is 0
-    ///
-    /// # Example
-    ///
-    /// ```rust
-    /// use std::rand;
-    /// use std::rand::Rng;
-    ///
-    /// fn main() {
-    ///     let mut rng = rand::rng();
-    ///     let x = [rand::Weighted {weight: 4, item: 'a'},
-    ///              rand::Weighted {weight: 2, item: 'b'},
-    ///              rand::Weighted {weight: 2, item: 'c'}];
-    ///     println!("{:?}", rng.choose_weighted_option(x));
-    /// }
-    /// ```
-    fn choose_weighted_option<T:Clone>(&mut self, v: &[Weighted<T>])
-                                       -> Option<T> {
-        let mut total = 0u;
-        for item in v.iter() {
-            total += item.weight;
-        }
-        if total == 0u {
-            return None;
-        }
-        let chosen = self.gen_range(0u, total);
-        let mut so_far = 0u;
-        for item in v.iter() {
-            so_far += item.weight;
-            if so_far > chosen {
-                return Some(item.item.clone());
-            }
-        }
-        unreachable!();
-    }
-
-    /// Return a vec containing copies of the items, in order, where
-    /// the weight of the item determines how many copies there are
-    ///
-    /// # Example
-    ///
-    /// ```rust
-    /// use std::rand;
-    /// use std::rand::Rng;
-    ///
-    /// fn main() {
-    ///     let mut rng = rand::rng();
-    ///     let x = [rand::Weighted {weight: 4, item: 'a'},
-    ///              rand::Weighted {weight: 2, item: 'b'},
-    ///              rand::Weighted {weight: 2, item: 'c'}];
-    ///     println!("{}", rng.weighted_vec(x));
-    /// }
-    /// ```
-    fn weighted_vec<T:Clone>(&mut self, v: &[Weighted<T>]) -> ~[T] {
-        let mut r = ~[];
-        for item in v.iter() {
-            for _ in range(0u, item.weight) {
-                r.push(item.item.clone());
-            }
-        }
-        r
-    }
-
     /// Shuffle a vec
     ///
     /// # Example
@@ -861,44 +768,6 @@ mod test {
     }
 
     #[test]
-    fn test_choose_weighted() {
-        let mut r = rng();
-        assert!(r.choose_weighted([
-            Weighted { weight: 1u, item: 42 },
-        ]) == 42);
-        assert!(r.choose_weighted([
-            Weighted { weight: 0u, item: 42 },
-            Weighted { weight: 1u, item: 43 },
-        ]) == 43);
-    }
-
-    #[test]
-    fn test_choose_weighted_option() {
-        let mut r = rng();
-        assert!(r.choose_weighted_option([
-            Weighted { weight: 1u, item: 42 },
-        ]) == Some(42));
-        assert!(r.choose_weighted_option([
-            Weighted { weight: 0u, item: 42 },
-            Weighted { weight: 1u, item: 43 },
-        ]) == Some(43));
-        let v: Option<int> = r.choose_weighted_option([]);
-        assert!(v.is_none());
-    }
-
-    #[test]
-    fn test_weighted_vec() {
-        let mut r = rng();
-        let empty: ~[int] = ~[];
-        assert_eq!(r.weighted_vec([]), empty);
-        assert!(r.weighted_vec([
-            Weighted { weight: 0u, item: 3u },
-            Weighted { weight: 1u, item: 2u },
-            Weighted { weight: 2u, item: 1u },
-        ]) == ~[2u, 1u, 1u]);
-    }
-
-    #[test]
     fn test_shuffle() {
         let mut r = rng();
         let empty: ~[int] = ~[];