about summary refs log tree commit diff
path: root/src/libstd
diff options
context:
space:
mode:
authorbors <bors@rust-lang.org>2013-08-15 06:56:06 -0700
committerbors <bors@rust-lang.org>2013-08-15 06:56:06 -0700
commit435020ecc42c472302dcbcce64abc1a9943b80b9 (patch)
treec0b61f72985bd9b917fa4a7609220b6d934f6344 /src/libstd
parent77739a70849293f67aba8ce3379e050cd319d1ce (diff)
parent11b3d76fb6ae627aea1b2a3a3de2807c21f9e720 (diff)
downloadrust-435020ecc42c472302dcbcce64abc1a9943b80b9.tar.gz
rust-435020ecc42c472302dcbcce64abc1a9943b80b9.zip
auto merge of #8491 : robertknight/rust/7722-reservoir_sampling, r=graydon
Fixes #7722

I had a couple of queries:
- Should this return an array or an iterator?
- Should this be a method on iterators or on the rng? I implemented it in RngUtils as it seemed to belong with shuffle().
Diffstat (limited to 'src/libstd')
-rw-r--r--src/libstd/rand.rs56
1 files changed, 56 insertions, 0 deletions
diff --git a/src/libstd/rand.rs b/src/libstd/rand.rs
index 500278fddb0..bd2ea1d6ac6 100644
--- a/src/libstd/rand.rs
+++ b/src/libstd/rand.rs
@@ -461,6 +461,26 @@ pub trait RngUtil {
      * ~~~
      */
     fn shuffle_mut<T>(&mut self, values: &mut [T]);
+
+    /**
+     * Sample up to `n` values from an iterator.
+     *
+     * # Example
+     *
+     * ~~~ {.rust}
+     *
+     * use std::rand;
+     * use std::rand::RngUtil;
+     *
+     * fn main() {
+     *     let mut rng = rand::rng();
+     *     let vals = range(1, 100).to_owned_vec();
+     *     let sample = rng.sample(vals.iter(), 5);
+     *     printfln!(sample);
+     * }
+     * ~~~
+     */
+    fn sample<A, T: Iterator<A>>(&mut self, iter: T, n: uint) -> ~[A];
 }
 
 /// Extension methods for random number generators
@@ -607,6 +627,23 @@ impl<R: Rng> RngUtil for R {
             values.swap(i, self.gen_uint_range(0u, i + 1u));
         }
     }
+
+    /// Randomly sample up to `n` elements from an iterator
+    fn sample<A, T: Iterator<A>>(&mut self, iter: T, n: uint) -> ~[A] {
+        let mut reservoir : ~[A] = vec::with_capacity(n);
+        for (i, elem) in iter.enumerate() {
+            if i < n {
+                reservoir.push(elem);
+                loop
+            }
+
+            let k = self.gen_uint_range(0, i + 1);
+            if k < reservoir.len() {
+                reservoir[k] = elem
+            }
+        }
+        reservoir
+    }
 }
 
 /// Create a random number generator with a default algorithm and seed.
@@ -914,6 +951,7 @@ pub fn random<T: Rand>() -> T {
 
 #[cfg(test)]
 mod test {
+    use iterator::{Iterator, range};
     use option::{Option, Some};
     use super::*;
 
@@ -1130,6 +1168,24 @@ mod test {
             }
         }
     }
+
+    #[test]
+    fn test_sample() {
+        let MIN_VAL = 1;
+        let MAX_VAL = 100;
+
+        let mut r = rng();
+        let vals = range(MIN_VAL, MAX_VAL).to_owned_vec();
+        let small_sample = r.sample(vals.iter(), 5);
+        let large_sample = r.sample(vals.iter(), vals.len() + 5);
+
+        assert_eq!(small_sample.len(), 5);
+        assert_eq!(large_sample.len(), vals.len());
+
+        assert!(small_sample.iter().all(|e| {
+            **e >= MIN_VAL && **e <= MAX_VAL
+        }));
+    }
 }
 
 #[cfg(test)]