use std::ops::Range; /// Mersenne Twister implementation based on the C code from wikipedia /// https://en.wikipedia.org/wiki/Mersenne_Twister#C_code pub struct MersenneTwister { state: [u32; Self::N], state_index: usize, } impl MersenneTwister { // "number of recurrence" const N: usize = 624; // " middle word, an offset used in the recurrence relation" const M: usize = 397; // ? const F: u32 = 1812433253; // "word size (in number of bits)" const W: u32 = 32; // "separation point of one word, or the number of bits of the lower bitmask" const R: u32 = 31; const UMASK: u32 = 0xffffffff << Self::R; const LMASK: u32 = 0xffffffff >> (Self::W - Self::R); // "coefficients of the rational normal form twist matrix" const A: u32 = 0x9908b0df; // TGFSR(R) Tempering bitmasks const B: u32 = 0x9d2c5680; const C: u32 = 0xefc60000; // TGFSR(R) Tempering bit shifts const S: u32 = 7; const T: u32 = 15; // Mersenne Twister bit shifts const U: u32 = 11; const L: u32 = 18; fn initialize_state(state: &mut [u32], seed: u32) { state[0] = seed; let mut current_seed = seed; for idx in 1..Self::N { current_seed = Self::F.wrapping_mul(current_seed ^ (current_seed >> (Self::W - 2))) + idx as u32; state[idx] = current_seed; } } fn initialize_by_array(state: &mut [u32], key: &[u32]) { Self::initialize_state(state, 19650218); let mut i = 1; let mut j = 0; let k = if Self::N > key.len() { Self::N } else { key.len() }; for _ in 0..k { state[i] = (state[i] ^ (state[i - 1] ^ (state[i - 1] >> 30)).wrapping_mul(1664525)) + key[j] + j as u32; i += 1; j += 1; if i >= Self::N { state[0] = state[Self::N - 1]; i = 1; } if j >= key.len() { j = 0; } } for _ in 0..Self::N - 1 { state[i] = (state[i] ^ (state[i - 1] ^ (state[i - 1] >> 30)).wrapping_mul(1566083941)) - i as u32; i += 1; if i >= Self::N { state[0] = state[Self::N - 1]; i = 1; } } state[0] = 0x80000000; } #[allow(dead_code)] pub fn new(seed: u32) -> Self { let mut state = [032; Self::N]; Self::initialize_state(&mut state, seed); Self { state, state_index: 0, } } pub fn new_seed_array(key: &[u32]) -> Self { let mut state = [032; Self::N]; Self::initialize_by_array(&mut state, key); Self { state, state_index: 0, } } #[allow(non_snake_case)] pub fn random_u32(&mut self) -> u32 { let mut j = if self.state_index < Self::N - 1 { self.state_index + 1 } else { self.state_index - (Self::N - 1) }; let mut x = (self.state[self.state_index] & Self::UMASK) | (self.state[j] & Self::LMASK); let mut xA = x >> 1; if x & 0x00000001 > 0 { xA ^= Self::A; } j = if self.state_index < (Self::N - Self::M) { self.state_index + Self::M } else { self.state_index - (Self::N - Self::M) }; x = self.state[j] ^ xA; self.state[self.state_index] = x; self.state_index += 1; if self.state_index >= Self::N { self.state_index = 0; }; let mut y = x ^ (x >> Self::U); // tempering y = y ^ ((y << Self::S) & Self::B); y = y ^ ((y << Self::T) & Self::C); y ^ (y >> Self::L) } pub fn random_f64(&mut self) -> f64 { let rand = self.random_u32(); // [0,1) rand as f64 * (1.0 / (u32::MAX as u64 + 1) as f64) } pub fn random_range(&mut self, range: Range) -> u32 { (self.random_f64() * (range.end as f64 - range.start as f64) + range.start as f64) as u32 } pub fn random_range_f32(&mut self, range: Range) -> f32 { (self.random_f64() * (range.end as f64 - range.start as f64) + range.start as f64) as f32 } } #[cfg(test)] mod test { use std::borrow::BorrowMut; use super::MersenneTwister; #[test] fn mersenne_state_simple_seed_matches() { let mt = MersenneTwister::new(19650218); let numbers = load_MTwikipedia_state(); for (idx, number) in numbers.into_iter().enumerate() { let ours = mt.state[idx]; if ours != number { panic!("reference simple state was {number}, we got {ours}") } } } #[test] fn mersenne_state_array_matches() { let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; let mt = MersenneTwister::new_seed_array(&key); let numbers = load_MTwikipedia_array_state(); for (idx, number) in numbers.into_iter().enumerate() { let ours = mt.state[idx]; if ours != number { panic!("reference array state was {number}, we got {ours}") } } } #[test] fn mersenne_matches() { // Initial array let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; let mut mt = MersenneTwister::new_seed_array(&key); let numbers = load_mt19937ar_cok(); for number in numbers { let ours = mt.random_u32(); if ours != number { panic!("reference was {number}, we got {ours}") } } } #[test] fn mersenne_matches_real() { // Initial array let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; let mut mt = MersenneTwister::new_seed_array(&key); let numbers = load_mt19937ar_cok_real(); // The C program first generated 1000 random u32, so we do it too. for _ in 0..1000 { let _ = mt.random_u32(); } for number in numbers { let ours = mt.random_f64(); let string = format!("{ours:.8}"); if !string.eq(&number) { panic!("reference was '{number}', we got '{string}'") } } } fn load_mt19937ar_cok() -> Vec { let string = std::fs::read_to_string("test/mt19937ar-cok").unwrap(); let mut numbers = vec![]; for line in string.lines() { let mut chars = line.chars(); for _ in 0..5 { let number: String = chars.borrow_mut().take(10).collect(); let _ = chars.next(); // throwout the space let ulong: u32 = number.trim().parse().unwrap(); numbers.push(ulong); } } numbers } fn load_mt19937ar_cok_real() -> Vec { let string = std::fs::read_to_string("test/mt19937ar-cok-real").unwrap(); let mut numbers = vec![]; for line in string.lines() { let mut chars = line.chars(); for _ in 0..5 { let number: String = chars.borrow_mut().take(10).collect(); let _ = chars.next(); // throwout the space numbers.push(number); } } numbers } #[allow(non_snake_case)] fn load_MTwikipedia_array_state() -> Vec { let string = std::fs::read_to_string("test/MTwikipedia-array_state").unwrap(); let mut numbers = vec![]; for line in string.lines() { numbers.push(line.trim().parse().unwrap()); } numbers } #[allow(non_snake_case)] fn load_MTwikipedia_state() -> Vec { let string = std::fs::read_to_string("test/MTwikipedia-state").unwrap(); let mut numbers = vec![]; for line in string.lines() { numbers.push(line.trim().parse().unwrap()); } numbers } }