about summary refs log tree commit diff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs288
1 files changed, 288 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..d2e4b2b
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,288 @@
+use std::ops::Range;
+
+/// Mersenne Twister implementation based on the C code from wikipedia
+/// https://en.wikipedia.org/wiki/Mersenne_Twister#C_code
+pub struct MersenneTwister {
+    state: [u32; Self::N],
+    state_index: usize,
+}
+
+impl MersenneTwister {
+    // "number of recurrence"
+    const N: usize = 624;
+    // " middle word, an offset used in the recurrence relation"
+    const M: usize = 397;
+    // ?
+    const F: u32 = 1812433253;
+    // "word size (in number of bits)"
+    const W: u32 = 32;
+    // "separation point of one word, or the number of bits of the lower bitmask"
+    const R: u32 = 31;
+    const UMASK: u32 = 0xffffffff << Self::R;
+    const LMASK: u32 = 0xffffffff >> (Self::W - Self::R);
+    // "coefficients of the rational normal form twist matrix"
+    const A: u32 = 0x9908b0df;
+
+    // TGFSR(R) Tempering bitmasks
+    const B: u32 = 0x9d2c5680;
+    const C: u32 = 0xefc60000;
+
+    // TGFSR(R) Tempering bit shifts
+    const S: u32 = 7;
+    const T: u32 = 15;
+
+    // Mersenne Twister bit shifts
+    const U: u32 = 11;
+    const L: u32 = 18;
+
+    fn initialize_state(state: &mut [u32], seed: u32) {
+        state[0] = seed;
+
+        let mut current_seed = seed;
+        for idx in 1..Self::N {
+            current_seed =
+                Self::F.wrapping_mul(current_seed ^ (current_seed >> (Self::W - 2))) + idx as u32;
+            state[idx] = current_seed;
+        }
+    }
+
+    fn initialize_by_array(state: &mut [u32], key: &[u32]) {
+        Self::initialize_state(state, 19650218);
+        let mut i = 1;
+        let mut j = 0;
+
+        let k = if Self::N > key.len() {
+            Self::N
+        } else {
+            key.len()
+        };
+
+        for _ in 0..k {
+            state[i] = (state[i] ^ (state[i - 1] ^ (state[i - 1] >> 30)).wrapping_mul(1664525))
+                + key[j]
+                + j as u32;
+
+            i += 1;
+            j += 1;
+
+            if i >= Self::N {
+                state[0] = state[Self::N - 1];
+                i = 1;
+            }
+
+            if j >= key.len() {
+                j = 0;
+            }
+        }
+
+        for _ in 0..Self::N - 1 {
+            state[i] = (state[i] ^ (state[i - 1] ^ (state[i - 1] >> 30)).wrapping_mul(1566083941))
+                - i as u32;
+
+            i += 1;
+
+            if i >= Self::N {
+                state[0] = state[Self::N - 1];
+                i = 1;
+            }
+        }
+
+        state[0] = 0x80000000;
+    }
+
+    #[allow(dead_code)]
+    pub fn new(seed: u32) -> Self {
+        let mut state = [032; Self::N];
+        Self::initialize_state(&mut state, seed);
+
+        Self {
+            state,
+            state_index: 0,
+        }
+    }
+
+    pub fn new_seed_array(key: &[u32]) -> Self {
+        let mut state = [032; Self::N];
+        Self::initialize_by_array(&mut state, key);
+
+        Self {
+            state,
+            state_index: 0,
+        }
+    }
+
+    #[allow(non_snake_case)]
+    pub fn random_u32(&mut self) -> u32 {
+        let mut j = if self.state_index < Self::N - 1 {
+            self.state_index + 1
+        } else {
+            self.state_index - (Self::N - 1)
+        };
+
+        let mut x = (self.state[self.state_index] & Self::UMASK) | (self.state[j] & Self::LMASK);
+        let mut xA = x >> 1;
+
+        if x & 0x00000001 > 0 {
+            xA ^= Self::A;
+        }
+
+        j = if self.state_index < (Self::N - Self::M) {
+            self.state_index + Self::M
+        } else {
+            self.state_index - (Self::N - Self::M)
+        };
+
+        x = self.state[j] ^ xA;
+        self.state[self.state_index] = x;
+        self.state_index += 1;
+
+        if self.state_index >= Self::N {
+            self.state_index = 0;
+        };
+
+        let mut y = x ^ (x >> Self::U); // tempering
+        y = y ^ ((y << Self::S) & Self::B);
+        y = y ^ ((y << Self::T) & Self::C);
+
+        y ^ (y >> Self::L)
+    }
+
+    pub fn random_f64(&mut self) -> f64 {
+        let rand = self.random_u32();
+
+        // [0,1)
+        rand as f64 * (1.0 / (u32::MAX as u64 + 1) as f64)
+    }
+
+    pub fn random_range(&mut self, range: Range<u32>) -> u32 {
+        (self.random_f64() * (range.end as f64 - range.start as f64) + range.start as f64) as u32
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::borrow::BorrowMut;
+
+    use super::MersenneTwister;
+
+    #[test]
+    fn mersenne_state_simple_seed_matches() {
+        let mt = MersenneTwister::new(19650218);
+        let numbers = load_MTwikipedia_state();
+
+        for (idx, number) in numbers.into_iter().enumerate() {
+            let ours = mt.state[idx];
+            if ours != number {
+                panic!("reference simple state was {number}, we got {ours}")
+            }
+        }
+    }
+
+    #[test]
+    fn mersenne_state_array_matches() {
+        let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456];
+        let mt = MersenneTwister::new_seed_array(&key);
+        let numbers = load_MTwikipedia_array_state();
+
+        for (idx, number) in numbers.into_iter().enumerate() {
+            let ours = mt.state[idx];
+            if ours != number {
+                panic!("reference array state was {number}, we got {ours}")
+            }
+        }
+    }
+
+    #[test]
+    fn mersenne_matches() {
+        // Initial array
+        let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456];
+        let mut mt = MersenneTwister::new_seed_array(&key);
+        let numbers = load_mt19937ar_cok();
+
+        for number in numbers {
+            let ours = mt.random_u32();
+            if ours != number {
+                panic!("reference was {number}, we got {ours}")
+            }
+        }
+    }
+
+    #[test]
+    fn mersenne_matches_real() {
+        // Initial array
+        let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456];
+        let mut mt = MersenneTwister::new_seed_array(&key);
+        let numbers = load_mt19937ar_cok_real();
+
+        // The C program first generated 1000 random u32, so we do it too.
+        for _ in 0..1000 {
+            let _ = mt.random_u32();
+        }
+
+        for number in numbers {
+            let ours = mt.random_f64();
+            let string = format!("{ours:.8}");
+            if !string.eq(&number) {
+                panic!("reference was '{number}', we got '{string}'")
+            }
+        }
+    }
+
+    fn load_mt19937ar_cok() -> Vec<u32> {
+        let string = std::fs::read_to_string("test/mt19937ar-cok").unwrap();
+        let mut numbers = vec![];
+
+        for line in string.lines() {
+            let mut chars = line.chars();
+            for _ in 0..5 {
+                let number: String = chars.borrow_mut().take(10).collect();
+                let _ = chars.next(); // throwout the space
+
+                let ulong: u32 = number.trim().parse().unwrap();
+                numbers.push(ulong);
+            }
+        }
+
+        numbers
+    }
+
+    fn load_mt19937ar_cok_real() -> Vec<String> {
+        let string = std::fs::read_to_string("test/mt19937ar-cok-real").unwrap();
+        let mut numbers = vec![];
+
+        for line in string.lines() {
+            let mut chars = line.chars();
+            for _ in 0..5 {
+                let number: String = chars.borrow_mut().take(10).collect();
+                let _ = chars.next(); // throwout the space
+                numbers.push(number);
+            }
+        }
+
+        numbers
+    }
+
+    #[allow(non_snake_case)]
+    fn load_MTwikipedia_array_state() -> Vec<u32> {
+        let string = std::fs::read_to_string("test/MTwikipedia-array_state").unwrap();
+        let mut numbers = vec![];
+
+        for line in string.lines() {
+            numbers.push(line.trim().parse().unwrap());
+        }
+
+        numbers
+    }
+
+    #[allow(non_snake_case)]
+    fn load_MTwikipedia_state() -> Vec<u32> {
+        let string = std::fs::read_to_string("test/MTwikipedia-state").unwrap();
+        let mut numbers = vec![];
+
+        for line in string.lines() {
+            numbers.push(line.trim().parse().unwrap());
+        }
+
+        numbers
+    }
+}