diff options
Diffstat (limited to 'src/lib.rs')
-rw-r--r-- | src/lib.rs | 288 |
1 files changed, 288 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..d2e4b2b --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,288 @@ +use std::ops::Range; + +/// Mersenne Twister implementation based on the C code from wikipedia +/// https://en.wikipedia.org/wiki/Mersenne_Twister#C_code +pub struct MersenneTwister { + state: [u32; Self::N], + state_index: usize, +} + +impl MersenneTwister { + // "number of recurrence" + const N: usize = 624; + // " middle word, an offset used in the recurrence relation" + const M: usize = 397; + // ? + const F: u32 = 1812433253; + // "word size (in number of bits)" + const W: u32 = 32; + // "separation point of one word, or the number of bits of the lower bitmask" + const R: u32 = 31; + const UMASK: u32 = 0xffffffff << Self::R; + const LMASK: u32 = 0xffffffff >> (Self::W - Self::R); + // "coefficients of the rational normal form twist matrix" + const A: u32 = 0x9908b0df; + + // TGFSR(R) Tempering bitmasks + const B: u32 = 0x9d2c5680; + const C: u32 = 0xefc60000; + + // TGFSR(R) Tempering bit shifts + const S: u32 = 7; + const T: u32 = 15; + + // Mersenne Twister bit shifts + const U: u32 = 11; + const L: u32 = 18; + + fn initialize_state(state: &mut [u32], seed: u32) { + state[0] = seed; + + let mut current_seed = seed; + for idx in 1..Self::N { + current_seed = + Self::F.wrapping_mul(current_seed ^ (current_seed >> (Self::W - 2))) + idx as u32; + state[idx] = current_seed; + } + } + + fn initialize_by_array(state: &mut [u32], key: &[u32]) { + Self::initialize_state(state, 19650218); + let mut i = 1; + let mut j = 0; + + let k = if Self::N > key.len() { + Self::N + } else { + key.len() + }; + + for _ in 0..k { + state[i] = (state[i] ^ (state[i - 1] ^ (state[i - 1] >> 30)).wrapping_mul(1664525)) + + key[j] + + j as u32; + + i += 1; + j += 1; + + if i >= Self::N { + state[0] = state[Self::N - 1]; + i = 1; + } + + if j >= key.len() { + j = 0; + } + } + + for _ in 0..Self::N - 1 { + state[i] = (state[i] ^ (state[i - 1] ^ (state[i - 1] >> 30)).wrapping_mul(1566083941)) + - i as u32; + + i += 1; + + if i >= Self::N { + state[0] = state[Self::N - 1]; + i = 1; + } + } + + state[0] = 0x80000000; + } + + #[allow(dead_code)] + pub fn new(seed: u32) -> Self { + let mut state = [032; Self::N]; + Self::initialize_state(&mut state, seed); + + Self { + state, + state_index: 0, + } + } + + pub fn new_seed_array(key: &[u32]) -> Self { + let mut state = [032; Self::N]; + Self::initialize_by_array(&mut state, key); + + Self { + state, + state_index: 0, + } + } + + #[allow(non_snake_case)] + pub fn random_u32(&mut self) -> u32 { + let mut j = if self.state_index < Self::N - 1 { + self.state_index + 1 + } else { + self.state_index - (Self::N - 1) + }; + + let mut x = (self.state[self.state_index] & Self::UMASK) | (self.state[j] & Self::LMASK); + let mut xA = x >> 1; + + if x & 0x00000001 > 0 { + xA ^= Self::A; + } + + j = if self.state_index < (Self::N - Self::M) { + self.state_index + Self::M + } else { + self.state_index - (Self::N - Self::M) + }; + + x = self.state[j] ^ xA; + self.state[self.state_index] = x; + self.state_index += 1; + + if self.state_index >= Self::N { + self.state_index = 0; + }; + + let mut y = x ^ (x >> Self::U); // tempering + y = y ^ ((y << Self::S) & Self::B); + y = y ^ ((y << Self::T) & Self::C); + + y ^ (y >> Self::L) + } + + pub fn random_f64(&mut self) -> f64 { + let rand = self.random_u32(); + + // [0,1) + rand as f64 * (1.0 / (u32::MAX as u64 + 1) as f64) + } + + pub fn random_range(&mut self, range: Range<u32>) -> u32 { + (self.random_f64() * (range.end as f64 - range.start as f64) + range.start as f64) as u32 + } +} + +#[cfg(test)] +mod test { + use std::borrow::BorrowMut; + + use super::MersenneTwister; + + #[test] + fn mersenne_state_simple_seed_matches() { + let mt = MersenneTwister::new(19650218); + let numbers = load_MTwikipedia_state(); + + for (idx, number) in numbers.into_iter().enumerate() { + let ours = mt.state[idx]; + if ours != number { + panic!("reference simple state was {number}, we got {ours}") + } + } + } + + #[test] + fn mersenne_state_array_matches() { + let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; + let mt = MersenneTwister::new_seed_array(&key); + let numbers = load_MTwikipedia_array_state(); + + for (idx, number) in numbers.into_iter().enumerate() { + let ours = mt.state[idx]; + if ours != number { + panic!("reference array state was {number}, we got {ours}") + } + } + } + + #[test] + fn mersenne_matches() { + // Initial array + let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; + let mut mt = MersenneTwister::new_seed_array(&key); + let numbers = load_mt19937ar_cok(); + + for number in numbers { + let ours = mt.random_u32(); + if ours != number { + panic!("reference was {number}, we got {ours}") + } + } + } + + #[test] + fn mersenne_matches_real() { + // Initial array + let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; + let mut mt = MersenneTwister::new_seed_array(&key); + let numbers = load_mt19937ar_cok_real(); + + // The C program first generated 1000 random u32, so we do it too. + for _ in 0..1000 { + let _ = mt.random_u32(); + } + + for number in numbers { + let ours = mt.random_f64(); + let string = format!("{ours:.8}"); + if !string.eq(&number) { + panic!("reference was '{number}', we got '{string}'") + } + } + } + + fn load_mt19937ar_cok() -> Vec<u32> { + let string = std::fs::read_to_string("test/mt19937ar-cok").unwrap(); + let mut numbers = vec![]; + + for line in string.lines() { + let mut chars = line.chars(); + for _ in 0..5 { + let number: String = chars.borrow_mut().take(10).collect(); + let _ = chars.next(); // throwout the space + + let ulong: u32 = number.trim().parse().unwrap(); + numbers.push(ulong); + } + } + + numbers + } + + fn load_mt19937ar_cok_real() -> Vec<String> { + let string = std::fs::read_to_string("test/mt19937ar-cok-real").unwrap(); + let mut numbers = vec![]; + + for line in string.lines() { + let mut chars = line.chars(); + for _ in 0..5 { + let number: String = chars.borrow_mut().take(10).collect(); + let _ = chars.next(); // throwout the space + numbers.push(number); + } + } + + numbers + } + + #[allow(non_snake_case)] + fn load_MTwikipedia_array_state() -> Vec<u32> { + let string = std::fs::read_to_string("test/MTwikipedia-array_state").unwrap(); + let mut numbers = vec![]; + + for line in string.lines() { + numbers.push(line.trim().parse().unwrap()); + } + + numbers + } + + #[allow(non_snake_case)] + fn load_MTwikipedia_state() -> Vec<u32> { + let string = std::fs::read_to_string("test/MTwikipedia-state").unwrap(); + let mut numbers = vec![]; + + for line in string.lines() { + numbers.push(line.trim().parse().unwrap()); + } + + numbers + } +} |