about summary refs log tree commit diff
diff options
context:
space:
mode:
authorHuon Wilson <dbau.pp+github@gmail.com>2013-10-10 20:18:07 +1100
committerHuon Wilson <dbau.pp+github@gmail.com>2013-10-23 10:40:06 +1100
commit148f737c199a5c9dd6d349751072add3cc458533 (patch)
tree6d11d443541b4076a8336b791fa374d00d259432
parent1420272ddc174996c532e14623d9f897ba5e7a9d (diff)
downloadrust-148f737c199a5c9dd6d349751072add3cc458533.tar.gz
rust-148f737c199a5c9dd6d349751072add3cc458533.zip
std::rand: add distributions::Range for generating [lo, hi).
This reifies the computations required for uniformity done by
(the old) `Rng.gen_integer_range` (now Rng.gen_range), so that they can
be amortised over many invocations, if it is called in a loop.

Also, it makes it correct, but using a trait + impls for each type,
rather than trying to coerce `Int` + `u64` to do the right thing. This
also makes it more extensible, e.g. big integers could & should
implement SampleRange.
-rw-r--r--src/libextra/base64.rs2
-rw-r--r--src/libextra/crypto/cryptoutil.rs2
-rw-r--r--src/libextra/flate.rs2
-rw-r--r--src/libextra/sort.rs12
-rw-r--r--src/libextra/treemap.rs2
-rw-r--r--src/libstd/rand/distributions.rs4
-rw-r--r--src/libstd/rand/mod.rs70
-rw-r--r--src/libstd/rand/range.rs235
-rw-r--r--src/libstd/rt/comm.rs2
-rw-r--r--src/libstd/rt/sched.rs2
-rw-r--r--src/test/bench/core-std.rs6
11 files changed, 287 insertions, 52 deletions
diff --git a/src/libextra/base64.rs b/src/libextra/base64.rs
index 3960be46686..429f7adf5f5 100644
--- a/src/libextra/base64.rs
+++ b/src/libextra/base64.rs
@@ -318,7 +318,7 @@ mod test {
         use std::vec;
 
         do 1000.times {
-            let times = task_rng().gen_integer_range(1u, 100);
+            let times = task_rng().gen_range(1u, 100);
             let v = vec::from_fn(times, |_| random::<u8>());
             assert_eq!(v.to_base64(STANDARD).from_base64().unwrap(), v);
         }
diff --git a/src/libextra/crypto/cryptoutil.rs b/src/libextra/crypto/cryptoutil.rs
index 97b82383d84..bb3524a7d49 100644
--- a/src/libextra/crypto/cryptoutil.rs
+++ b/src/libextra/crypto/cryptoutil.rs
@@ -365,7 +365,7 @@ pub mod test {
         digest.reset();
 
         while count < total_size {
-            let next: uint = rng.gen_integer_range(0, 2 * blocksize + 1);
+            let next: uint = rng.gen_range(0, 2 * blocksize + 1);
             let remaining = total_size - count;
             let size = if next > remaining { remaining } else { next };
             digest.input(buffer.slice_to(size));
diff --git a/src/libextra/flate.rs b/src/libextra/flate.rs
index 9d6c2e8aa82..3f3b588e8bc 100644
--- a/src/libextra/flate.rs
+++ b/src/libextra/flate.rs
@@ -113,7 +113,7 @@ mod tests {
         let mut r = rand::rng();
         let mut words = ~[];
         do 20.times {
-            let range = r.gen_integer_range(1u, 10);
+            let range = r.gen_range(1u, 10);
             words.push(r.gen_vec::<u8>(range));
         }
         do 20.times {
diff --git a/src/libextra/sort.rs b/src/libextra/sort.rs
index d884f4f05c1..2a456f8de3e 100644
--- a/src/libextra/sort.rs
+++ b/src/libextra/sort.rs
@@ -1069,8 +1069,8 @@ mod big_tests {
             isSorted(arr);
 
             do 3.times {
-                let i1 = rng.gen_integer_range(0u, n);
-                let i2 = rng.gen_integer_range(0u, n);
+                let i1 = rng.gen_range(0u, n);
+                let i2 = rng.gen_range(0u, n);
                 arr.swap(i1, i2);
             }
             tim_sort(arr); // 3sort
@@ -1088,7 +1088,7 @@ mod big_tests {
             isSorted(arr);
 
             do (n/100).times {
-                let idx = rng.gen_integer_range(0u, n);
+                let idx = rng.gen_range(0u, n);
                 arr[idx] = rng.gen();
             }
             tim_sort(arr);
@@ -1141,8 +1141,8 @@ mod big_tests {
             isSorted(arr);
 
             do 3.times {
-                let i1 = rng.gen_integer_range(0u, n);
-                let i2 = rng.gen_integer_range(0u, n);
+                let i1 = rng.gen_range(0u, n);
+                let i2 = rng.gen_range(0u, n);
                 arr.swap(i1, i2);
             }
             tim_sort(arr); // 3sort
@@ -1160,7 +1160,7 @@ mod big_tests {
             isSorted(arr);
 
             do (n/100).times {
-                let idx = rng.gen_integer_range(0u, n);
+                let idx = rng.gen_range(0u, n);
                 arr[idx] = @rng.gen();
             }
             tim_sort(arr);
diff --git a/src/libextra/treemap.rs b/src/libextra/treemap.rs
index ad196b32fb2..7ef9ba76b99 100644
--- a/src/libextra/treemap.rs
+++ b/src/libextra/treemap.rs
@@ -1028,7 +1028,7 @@ mod test_treemap {
             }
 
             do 30.times {
-                let r = rng.gen_integer_range(0, ctrl.len());
+                let r = rng.gen_range(0, ctrl.len());
                 let (key, _) = ctrl.remove(r);
                 assert!(map.remove(&key));
                 check_structure(&map);
diff --git a/src/libstd/rand/distributions.rs b/src/libstd/rand/distributions.rs
index 6b23bff4c45..b31e72bc697 100644
--- a/src/libstd/rand/distributions.rs
+++ b/src/libstd/rand/distributions.rs
@@ -23,6 +23,10 @@
 use num;
 use rand::{Rng,Rand};
 
+pub use self::range::Range;
+
+pub mod range;
+
 /// Things that can be used to create a random instance of `Support`.
 pub trait Sample<Support> {
     /// Generate a random value of `Support`, using `rng` as the
diff --git a/src/libstd/rand/mod.rs b/src/libstd/rand/mod.rs
index f5c60417bac..178f5106d28 100644
--- a/src/libstd/rand/mod.rs
+++ b/src/libstd/rand/mod.rs
@@ -55,17 +55,20 @@ fn main () {
 use mem::size_of;
 use unstable::raw::Slice;
 use cast;
+use cmp::Ord;
 use container::Container;
 use iter::{Iterator, range};
 use local_data;
 use prelude::*;
 use str;
-use u64;
 use vec;
 
 pub use self::isaac::{IsaacRng, Isaac64Rng};
 pub use self::os::OSRng;
 
+use self::distributions::{Range, IndependentSample};
+use self::distributions::range::SampleRange;
+
 pub mod distributions;
 pub mod isaac;
 pub mod os;
@@ -218,14 +221,14 @@ pub trait Rng {
         vec::from_fn(len, |_| self.gen())
     }
 
-    /// Generate a random primitive integer in the range [`low`,
-    /// `high`). Fails if `low >= high`.
+    /// Generate a random value in the range [`low`, `high`). Fails if
+    /// `low >= high`.
     ///
-    /// This gives a uniform distribution (assuming this RNG is itself
-    /// uniform), even for edge cases like `gen_integer_range(0u8,
-    /// 170)`, which a naive modulo operation would return numbers
-    /// less than 85 with double the probability to those greater than
-    /// 85.
+    /// This is a convenience wrapper around
+    /// `distributions::Range`. If this function will be called
+    /// repeatedly with the same arguments, one should use `Range`, as
+    /// that will amortize the computations that allow for perfect
+    /// uniformity, as they only happen on initialization.
     ///
     /// # Example
     ///
@@ -235,22 +238,15 @@ pub trait Rng {
     ///
     /// fn main() {
     ///    let mut rng = rand::task_rng();
-    ///    let n: uint = rng.gen_integer_range(0u, 10);
+    ///    let n: uint = rng.gen_range(0u, 10);
     ///    println!("{}", n);
-    ///    let m: int = rng.gen_integer_range(-40, 400);
+    ///    let m: float = rng.gen_range(-40.0, 1.3e5);
     ///    println!("{}", m);
     /// }
     /// ```
-    fn gen_integer_range<T: Rand + Int>(&mut self, low: T, high: T) -> T {
-        assert!(low < high, "RNG.gen_integer_range called with low >= high");
-        let range = (high - low).to_u64().unwrap();
-        let accept_zone = u64::max_value - u64::max_value % range;
-        loop {
-            let rand = self.gen::<u64>();
-            if rand < accept_zone {
-                return low + NumCast::from(rand % range).unwrap();
-            }
-        }
+    fn gen_range<T: Ord + SampleRange>(&mut self, low: T, high: T) -> T {
+        assert!(low < high, "Rng.gen_range called with low >= high");
+        Range::new(low, high).ind_sample(self)
     }
 
     /// Return a bool with a 1 in n chance of true
@@ -267,7 +263,7 @@ pub trait Rng {
     /// }
     /// ```
     fn gen_weighted_bool(&mut self, n: uint) -> bool {
-        n == 0 || self.gen_integer_range(0, n) == 0
+        n == 0 || self.gen_range(0, n) == 0
     }
 
     /// Return a random string of the specified length composed of
@@ -317,7 +313,7 @@ pub trait Rng {
         if values.is_empty() {
             None
         } else {
-            Some(&values[self.gen_integer_range(0u, values.len())])
+            Some(&values[self.gen_range(0u, values.len())])
         }
     }
 
@@ -368,7 +364,7 @@ pub trait Rng {
         if total == 0u {
             return None;
         }
-        let chosen = self.gen_integer_range(0u, total);
+        let chosen = self.gen_range(0u, total);
         let mut so_far = 0u;
         for item in v.iter() {
             so_far += item.weight;
@@ -447,7 +443,7 @@ pub trait Rng {
             // invariant: elements with index >= i have been locked in place.
             i -= 1u;
             // lock element i in place.
-            values.swap(i, self.gen_integer_range(0u, i + 1u));
+            values.swap(i, self.gen_range(0u, i + 1u));
         }
     }
 
@@ -473,7 +469,7 @@ pub trait Rng {
                 continue
             }
 
-            let k = self.gen_integer_range(0, i + 1);
+            let k = self.gen_range(0, i + 1);
             if k < reservoir.len() {
                 reservoir[k] = elem
             }
@@ -760,36 +756,36 @@ mod test {
     }
 
     #[test]
-    fn test_gen_integer_range() {
+    fn test_gen_range() {
         let mut r = rng();
         for _ in range(0, 1000) {
-            let a = r.gen_integer_range(-3i, 42);
+            let a = r.gen_range(-3i, 42);
             assert!(a >= -3 && a < 42);
-            assert_eq!(r.gen_integer_range(0, 1), 0);
-            assert_eq!(r.gen_integer_range(-12, -11), -12);
+            assert_eq!(r.gen_range(0, 1), 0);
+            assert_eq!(r.gen_range(-12, -11), -12);
         }
 
         for _ in range(0, 1000) {
-            let a = r.gen_integer_range(10, 42);
+            let a = r.gen_range(10, 42);
             assert!(a >= 10 && a < 42);
-            assert_eq!(r.gen_integer_range(0, 1), 0);
-            assert_eq!(r.gen_integer_range(3_000_000u, 3_000_001), 3_000_000);
+            assert_eq!(r.gen_range(0, 1), 0);
+            assert_eq!(r.gen_range(3_000_000u, 3_000_001), 3_000_000);
         }
 
     }
 
     #[test]
     #[should_fail]
-    fn test_gen_integer_range_fail_int() {
+    fn test_gen_range_fail_int() {
         let mut r = rng();
-        r.gen_integer_range(5i, -2);
+        r.gen_range(5i, -2);
     }
 
     #[test]
     #[should_fail]
-    fn test_gen_integer_range_fail_uint() {
+    fn test_gen_range_fail_uint() {
         let mut r = rng();
-        r.gen_integer_range(5u, 2u);
+        r.gen_range(5u, 2u);
     }
 
     #[test]
@@ -894,7 +890,7 @@ mod test {
         let mut r = task_rng();
         r.gen::<int>();
         assert_eq!(r.shuffle(~[1, 1, 1]), ~[1, 1, 1]);
-        assert_eq!(r.gen_integer_range(0u, 1u), 0u);
+        assert_eq!(r.gen_range(0u, 1u), 0u);
     }
 
     #[test]
diff --git a/src/libstd/rand/range.rs b/src/libstd/rand/range.rs
new file mode 100644
index 00000000000..1b805a0b8f7
--- /dev/null
+++ b/src/libstd/rand/range.rs
@@ -0,0 +1,235 @@
+// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+//! Generating numbers between two others.
+
+// this is surprisingly complicated to be both generic & correct
+
+use cmp::Ord;
+use num::Bounded;
+use rand::Rng;
+use rand::distributions::{Sample, IndependentSample};
+
+/// Sample values uniformly between two bounds.
+///
+/// This gives a uniform distribution (assuming the RNG used to sample
+/// it is itself uniform & the `SampleRange` implementation for the
+/// given type is correct), even for edge cases like `low = 0u8`,
+/// `high = 170u8`, for which a naive modulo operation would return
+/// numbers less than 85 with double the probability to those greater
+/// than 85.
+///
+/// Types should attempt to sample in `[low, high)`, i.e., not
+/// including `high`, but this may be very difficult. All the
+/// primitive integer types satisfy this property, and the float types
+/// normally satisfy it, but rounding may mean `high` can occur.
+///
+/// # Example
+///
+/// ```rust
+/// use std::rand;
+/// use std::rand::distributions::{IndependentSample, Range};
+///
+/// fn main() {
+///     let between = Range::new(10u, 10000u);
+///     let rng = rand::task_rng();
+///     let mut sum = 0;
+///     for _ in range(0, 1000) {
+///         sum += between.ind_sample(rng);
+///     }
+///     println!("{}", sum);
+/// }
+/// ```
+pub struct Range<X> {
+    priv low: X,
+    priv range: X,
+    priv accept_zone: X
+}
+
+impl<X: SampleRange + Ord> Range<X> {
+    /// Create a new `Range` instance that samples uniformly from
+    /// `[low, high)`. Fails if `low >= high`.
+    pub fn new(low: X, high: X) -> Range<X> {
+        assert!(low < high, "Range::new called with `low >= high`");
+        SampleRange::construct_range(low, high)
+    }
+}
+
+impl<Sup: SampleRange> Sample<Sup> for Range<Sup> {
+    #[inline]
+    fn sample<R: Rng>(&mut self, rng: &mut R) -> Sup { self.ind_sample(rng) }
+}
+impl<Sup: SampleRange> IndependentSample<Sup> for Range<Sup> {
+    fn ind_sample<R: Rng>(&self, rng: &mut R) -> Sup {
+        SampleRange::sample_range(self, rng)
+    }
+}
+
+/// The helper trait for types that have a sensible way to sample
+/// uniformly between two values. This should not be used directly,
+/// and is only to facilitate `Range`.
+pub trait SampleRange {
+    /// Construct the `Range` object that `sample_range`
+    /// requires. This should not ever be called directly, only via
+    /// `Range::new`, which will check that `low < high`, so this
+    /// function doesn't have to repeat the check.
+    fn construct_range(low: Self, high: Self) -> Range<Self>;
+
+    /// Sample a value from the given `Range` with the given `Rng` as
+    /// a source of randomness.
+    fn sample_range<R: Rng>(r: &Range<Self>, rng: &mut R) -> Self;
+}
+
+macro_rules! integer_impl {
+    ($ty:ty, $unsigned:ty) => {
+        impl SampleRange for $ty {
+            // we play free and fast with unsigned vs signed here
+            // (when $ty is signed), but that's fine, since the
+            // contract of this macro is for $ty and $unsigned to be
+            // "bit-equal", so casting between them is a no-op & a
+            // bijection.
+
+            fn construct_range(low: $ty, high: $ty) -> Range<$ty> {
+                let range = high as $unsigned - low as $unsigned;
+                let unsigned_max: $unsigned = Bounded::max_value();
+
+                // this is the largest number that fits into $unsigned
+                // that `range` divides evenly, so, if we've sampled
+                // `n` uniformly from this region, then `n % range` is
+                // uniform in [0, range)
+                let zone = unsigned_max - unsigned_max % range;
+
+                Range {
+                    low: low,
+                    range: range as $ty,
+                    accept_zone: zone as $ty
+                }
+            }
+            #[inline]
+            fn sample_range<R: Rng>(r: &Range<$ty>, rng: &mut R) -> $ty {
+                loop {
+                    // rejection sample
+                    let v = rng.gen::<$unsigned>();
+                    // until we find something that fits into the
+                    // region which r.range evenly divides (this will
+                    // be uniformly distributed)
+                    if v < r.accept_zone as $unsigned {
+                        // and return it, with some adjustments
+                        return r.low + (v % r.range as $unsigned) as $ty;
+                    }
+                }
+            }
+        }
+    }
+}
+
+integer_impl! { i8, u8 }
+integer_impl! { i16, u16 }
+integer_impl! { i32, u32 }
+integer_impl! { i64, u64 }
+integer_impl! { int, uint }
+integer_impl! { u8, u8 }
+integer_impl! { u16, u16 }
+integer_impl! { u32, u32 }
+integer_impl! { u64, u64 }
+integer_impl! { uint, uint }
+
+macro_rules! float_impl {
+    ($ty:ty) => {
+        impl SampleRange for $ty {
+            fn construct_range(low: $ty, high: $ty) -> Range<$ty> {
+                Range {
+                    low: low,
+                    range: high - low,
+                    accept_zone: 0.0 // unused
+                }
+            }
+            fn sample_range<R: Rng>(r: &Range<$ty>, rng: &mut R) -> $ty {
+                r.low + r.range * rng.gen()
+            }
+        }
+    }
+}
+
+float_impl! { f32 }
+float_impl! { f64 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand::*;
+    use num::Bounded;
+    use iter::range;
+    use option::{Some, None};
+    use vec::ImmutableVector;
+
+    #[should_fail]
+    #[test]
+    fn test_range_bad_limits_equal() {
+        Range::new(10, 10);
+    }
+    #[should_fail]
+    #[test]
+    fn test_range_bad_limits_flipped() {
+        Range::new(10, 5);
+    }
+
+    #[test]
+    fn test_integers() {
+        let rng = task_rng();
+        macro_rules! t (
+            ($($ty:ty),*) => {{
+                $(
+                   let v: &[($ty, $ty)] = [(0, 10),
+                                           (10, 127),
+                                           (Bounded::min_value(), Bounded::max_value())];
+                   for &(low, high) in v.iter() {
+                        let mut sampler: Range<$ty> = Range::new(low, high);
+                        for _ in range(0, 1000) {
+                            let v = sampler.sample(rng);
+                            assert!(low <= v && v < high);
+                            let v = sampler.ind_sample(rng);
+                            assert!(low <= v && v < high);
+                        }
+                    }
+                 )*
+            }}
+        );
+        t!(i8, i16, i32, i64, int,
+           u8, u16, u32, u64, uint)
+    }
+
+    #[test]
+    fn test_floats() {
+        let rng = task_rng();
+        macro_rules! t (
+            ($($ty:ty),*) => {{
+                $(
+                   let v: &[($ty, $ty)] = [(0.0, 100.0),
+                                           (-1e35, -1e25),
+                                           (1e-35, 1e-25),
+                                           (-1e35, 1e35)];
+                   for &(low, high) in v.iter() {
+                        let mut sampler: Range<$ty> = Range::new(low, high);
+                        for _ in range(0, 1000) {
+                            let v = sampler.sample(rng);
+                            assert!(low <= v && v < high);
+                            let v = sampler.ind_sample(rng);
+                            assert!(low <= v && v < high);
+                        }
+                    }
+                 )*
+            }}
+        );
+
+        t!(f32, f64)
+    }
+
+}
diff --git a/src/libstd/rt/comm.rs b/src/libstd/rt/comm.rs
index 4eae8bdc9a8..967dedd94a6 100644
--- a/src/libstd/rt/comm.rs
+++ b/src/libstd/rt/comm.rs
@@ -1117,7 +1117,7 @@ mod test {
             let total = stress_factor() + 10;
             let mut rng = rand::rng();
             do total.times {
-                let msgs = rng.gen_integer_range(0u, 10);
+                let msgs = rng.gen_range(0u, 10);
                 let pipe_clone = pipe.clone();
                 let end_chan_clone = end_chan.clone();
                 do spawntask_random {
diff --git a/src/libstd/rt/sched.rs b/src/libstd/rt/sched.rs
index 336d2518e43..c1090d36010 100644
--- a/src/libstd/rt/sched.rs
+++ b/src/libstd/rt/sched.rs
@@ -431,7 +431,7 @@ impl Scheduler {
     fn try_steals(&mut self) -> Option<~Task> {
         let work_queues = &mut self.work_queues;
         let len = work_queues.len();
-        let start_index = self.rng.gen_integer_range(0, len);
+        let start_index = self.rng.gen_range(0, len);
         for index in range(0, len).map(|i| (i + start_index) % len) {
             match work_queues[index].steal() {
                 Some(task) => {
diff --git a/src/test/bench/core-std.rs b/src/test/bench/core-std.rs
index f549f747ef7..dbd1edffe78 100644
--- a/src/test/bench/core-std.rs
+++ b/src/test/bench/core-std.rs
@@ -90,7 +90,7 @@ fn vec_plus() {
     let mut v = ~[];
     let mut i = 0;
     while i < 1500 {
-        let rv = vec::from_elem(r.gen_integer_range(0u, i + 1), i);
+        let rv = vec::from_elem(r.gen_range(0u, i + 1), i);
         if r.gen() {
             v.push_all_move(rv);
         } else {
@@ -106,7 +106,7 @@ fn vec_append() {
     let mut v = ~[];
     let mut i = 0;
     while i < 1500 {
-        let rv = vec::from_elem(r.gen_integer_range(0u, i + 1), i);
+        let rv = vec::from_elem(r.gen_range(0u, i + 1), i);
         if r.gen() {
             v = vec::append(v, rv);
         }
@@ -122,7 +122,7 @@ fn vec_push_all() {
 
     let mut v = ~[];
     for i in range(0u, 1500) {
-        let mut rv = vec::from_elem(r.gen_integer_range(0u, i + 1), i);
+        let mut rv = vec::from_elem(r.gen_range(0u, i + 1), i);
         if r.gen() {
             v.push_all(rv);
         }