From b8d3e60fe34b9af53d627456de1be791b4291f85 Mon Sep 17 00:00:00 2001 From: Genevieve Alfirevic Date: Sat, 8 Feb 2025 05:41:19 -0600 Subject: Darnit rustfmt; add random_range_f32 --- src/lib.rs | 553 +++++++++++++++++++++++++++++++------------------------------ 1 file changed, 278 insertions(+), 275 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d2e4b2b..25ed652 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,286 +3,289 @@ use std::ops::Range; /// Mersenne Twister implementation based on the C code from wikipedia /// https://en.wikipedia.org/wiki/Mersenne_Twister#C_code pub struct MersenneTwister { - state: [u32; Self::N], - state_index: usize, + state: [u32; Self::N], + state_index: usize, } impl MersenneTwister { - // "number of recurrence" - const N: usize = 624; - // " middle word, an offset used in the recurrence relation" - const M: usize = 397; - // ? - const F: u32 = 1812433253; - // "word size (in number of bits)" - const W: u32 = 32; - // "separation point of one word, or the number of bits of the lower bitmask" - const R: u32 = 31; - const UMASK: u32 = 0xffffffff << Self::R; - const LMASK: u32 = 0xffffffff >> (Self::W - Self::R); - // "coefficients of the rational normal form twist matrix" - const A: u32 = 0x9908b0df; - - // TGFSR(R) Tempering bitmasks - const B: u32 = 0x9d2c5680; - const C: u32 = 0xefc60000; - - // TGFSR(R) Tempering bit shifts - const S: u32 = 7; - const T: u32 = 15; - - // Mersenne Twister bit shifts - const U: u32 = 11; - const L: u32 = 18; - - fn initialize_state(state: &mut [u32], seed: u32) { - state[0] = seed; - - let mut current_seed = seed; - for idx in 1..Self::N { - current_seed = - Self::F.wrapping_mul(current_seed ^ (current_seed >> (Self::W - 2))) + idx as u32; - state[idx] = current_seed; - } - } - - fn initialize_by_array(state: &mut [u32], key: &[u32]) { - Self::initialize_state(state, 19650218); - let mut i = 1; - let mut j = 0; - - let k = if Self::N > key.len() { - Self::N - } else { - key.len() - }; - - for _ in 0..k { - state[i] = (state[i] ^ (state[i - 1] ^ (state[i - 1] >> 30)).wrapping_mul(1664525)) - + key[j] - + j as u32; - - i += 1; - j += 1; - - if i >= Self::N { - state[0] = state[Self::N - 1]; - i = 1; - } - - if j >= key.len() { - j = 0; - } - } - - for _ in 0..Self::N - 1 { - state[i] = (state[i] ^ (state[i - 1] ^ (state[i - 1] >> 30)).wrapping_mul(1566083941)) - - i as u32; - - i += 1; - - if i >= Self::N { - state[0] = state[Self::N - 1]; - i = 1; - } - } - - state[0] = 0x80000000; - } - - #[allow(dead_code)] - pub fn new(seed: u32) -> Self { - let mut state = [032; Self::N]; - Self::initialize_state(&mut state, seed); - - Self { - state, - state_index: 0, - } - } - - pub fn new_seed_array(key: &[u32]) -> Self { - let mut state = [032; Self::N]; - Self::initialize_by_array(&mut state, key); - - Self { - state, - state_index: 0, - } - } - - #[allow(non_snake_case)] - pub fn random_u32(&mut self) -> u32 { - let mut j = if self.state_index < Self::N - 1 { - self.state_index + 1 - } else { - self.state_index - (Self::N - 1) - }; - - let mut x = (self.state[self.state_index] & Self::UMASK) | (self.state[j] & Self::LMASK); - let mut xA = x >> 1; - - if x & 0x00000001 > 0 { - xA ^= Self::A; - } - - j = if self.state_index < (Self::N - Self::M) { - self.state_index + Self::M - } else { - self.state_index - (Self::N - Self::M) - }; - - x = self.state[j] ^ xA; - self.state[self.state_index] = x; - self.state_index += 1; - - if self.state_index >= Self::N { - self.state_index = 0; - }; - - let mut y = x ^ (x >> Self::U); // tempering - y = y ^ ((y << Self::S) & Self::B); - y = y ^ ((y << Self::T) & Self::C); - - y ^ (y >> Self::L) - } - - pub fn random_f64(&mut self) -> f64 { - let rand = self.random_u32(); - - // [0,1) - rand as f64 * (1.0 / (u32::MAX as u64 + 1) as f64) - } - - pub fn random_range(&mut self, range: Range) -> u32 { - (self.random_f64() * (range.end as f64 - range.start as f64) + range.start as f64) as u32 - } + // "number of recurrence" + const N: usize = 624; + // " middle word, an offset used in the recurrence relation" + const M: usize = 397; + // ? + const F: u32 = 1812433253; + // "word size (in number of bits)" + const W: u32 = 32; + // "separation point of one word, or the number of bits of the lower bitmask" + const R: u32 = 31; + const UMASK: u32 = 0xffffffff << Self::R; + const LMASK: u32 = 0xffffffff >> (Self::W - Self::R); + // "coefficients of the rational normal form twist matrix" + const A: u32 = 0x9908b0df; + + // TGFSR(R) Tempering bitmasks + const B: u32 = 0x9d2c5680; + const C: u32 = 0xefc60000; + + // TGFSR(R) Tempering bit shifts + const S: u32 = 7; + const T: u32 = 15; + + // Mersenne Twister bit shifts + const U: u32 = 11; + const L: u32 = 18; + + fn initialize_state(state: &mut [u32], seed: u32) { + state[0] = seed; + + let mut current_seed = seed; + for idx in 1..Self::N { + current_seed = + Self::F.wrapping_mul(current_seed ^ (current_seed >> (Self::W - 2))) + idx as u32; + state[idx] = current_seed; + } + } + + fn initialize_by_array(state: &mut [u32], key: &[u32]) { + Self::initialize_state(state, 19650218); + let mut i = 1; + let mut j = 0; + + let k = if Self::N > key.len() { + Self::N + } else { + key.len() + }; + + for _ in 0..k { + state[i] = (state[i] ^ (state[i - 1] ^ (state[i - 1] >> 30)).wrapping_mul(1664525)) + + key[j] + j as u32; + + i += 1; + j += 1; + + if i >= Self::N { + state[0] = state[Self::N - 1]; + i = 1; + } + + if j >= key.len() { + j = 0; + } + } + + for _ in 0..Self::N - 1 { + state[i] = (state[i] ^ (state[i - 1] ^ (state[i - 1] >> 30)).wrapping_mul(1566083941)) + - i as u32; + + i += 1; + + if i >= Self::N { + state[0] = state[Self::N - 1]; + i = 1; + } + } + + state[0] = 0x80000000; + } + + #[allow(dead_code)] + pub fn new(seed: u32) -> Self { + let mut state = [032; Self::N]; + Self::initialize_state(&mut state, seed); + + Self { + state, + state_index: 0, + } + } + + pub fn new_seed_array(key: &[u32]) -> Self { + let mut state = [032; Self::N]; + Self::initialize_by_array(&mut state, key); + + Self { + state, + state_index: 0, + } + } + + #[allow(non_snake_case)] + pub fn random_u32(&mut self) -> u32 { + let mut j = if self.state_index < Self::N - 1 { + self.state_index + 1 + } else { + self.state_index - (Self::N - 1) + }; + + let mut x = (self.state[self.state_index] & Self::UMASK) | (self.state[j] & Self::LMASK); + let mut xA = x >> 1; + + if x & 0x00000001 > 0 { + xA ^= Self::A; + } + + j = if self.state_index < (Self::N - Self::M) { + self.state_index + Self::M + } else { + self.state_index - (Self::N - Self::M) + }; + + x = self.state[j] ^ xA; + self.state[self.state_index] = x; + self.state_index += 1; + + if self.state_index >= Self::N { + self.state_index = 0; + }; + + let mut y = x ^ (x >> Self::U); // tempering + y = y ^ ((y << Self::S) & Self::B); + y = y ^ ((y << Self::T) & Self::C); + + y ^ (y >> Self::L) + } + + pub fn random_f64(&mut self) -> f64 { + let rand = self.random_u32(); + + // [0,1) + rand as f64 * (1.0 / (u32::MAX as u64 + 1) as f64) + } + + pub fn random_range(&mut self, range: Range) -> u32 { + (self.random_f64() * (range.end as f64 - range.start as f64) + range.start as f64) as u32 + } + + pub fn random_range_f32(&mut self, range: Range) -> f32 { + (self.random_f64() * (range.end as f64 - range.start as f64) + range.start as f64) as f32 + } } #[cfg(test)] mod test { - use std::borrow::BorrowMut; - - use super::MersenneTwister; - - #[test] - fn mersenne_state_simple_seed_matches() { - let mt = MersenneTwister::new(19650218); - let numbers = load_MTwikipedia_state(); - - for (idx, number) in numbers.into_iter().enumerate() { - let ours = mt.state[idx]; - if ours != number { - panic!("reference simple state was {number}, we got {ours}") - } - } - } - - #[test] - fn mersenne_state_array_matches() { - let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; - let mt = MersenneTwister::new_seed_array(&key); - let numbers = load_MTwikipedia_array_state(); - - for (idx, number) in numbers.into_iter().enumerate() { - let ours = mt.state[idx]; - if ours != number { - panic!("reference array state was {number}, we got {ours}") - } - } - } - - #[test] - fn mersenne_matches() { - // Initial array - let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; - let mut mt = MersenneTwister::new_seed_array(&key); - let numbers = load_mt19937ar_cok(); - - for number in numbers { - let ours = mt.random_u32(); - if ours != number { - panic!("reference was {number}, we got {ours}") - } - } - } - - #[test] - fn mersenne_matches_real() { - // Initial array - let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; - let mut mt = MersenneTwister::new_seed_array(&key); - let numbers = load_mt19937ar_cok_real(); - - // The C program first generated 1000 random u32, so we do it too. - for _ in 0..1000 { - let _ = mt.random_u32(); - } - - for number in numbers { - let ours = mt.random_f64(); - let string = format!("{ours:.8}"); - if !string.eq(&number) { - panic!("reference was '{number}', we got '{string}'") - } - } - } - - fn load_mt19937ar_cok() -> Vec { - let string = std::fs::read_to_string("test/mt19937ar-cok").unwrap(); - let mut numbers = vec![]; - - for line in string.lines() { - let mut chars = line.chars(); - for _ in 0..5 { - let number: String = chars.borrow_mut().take(10).collect(); - let _ = chars.next(); // throwout the space - - let ulong: u32 = number.trim().parse().unwrap(); - numbers.push(ulong); - } - } - - numbers - } - - fn load_mt19937ar_cok_real() -> Vec { - let string = std::fs::read_to_string("test/mt19937ar-cok-real").unwrap(); - let mut numbers = vec![]; - - for line in string.lines() { - let mut chars = line.chars(); - for _ in 0..5 { - let number: String = chars.borrow_mut().take(10).collect(); - let _ = chars.next(); // throwout the space - numbers.push(number); - } - } - - numbers - } - - #[allow(non_snake_case)] - fn load_MTwikipedia_array_state() -> Vec { - let string = std::fs::read_to_string("test/MTwikipedia-array_state").unwrap(); - let mut numbers = vec![]; - - for line in string.lines() { - numbers.push(line.trim().parse().unwrap()); - } - - numbers - } - - #[allow(non_snake_case)] - fn load_MTwikipedia_state() -> Vec { - let string = std::fs::read_to_string("test/MTwikipedia-state").unwrap(); - let mut numbers = vec![]; - - for line in string.lines() { - numbers.push(line.trim().parse().unwrap()); - } - - numbers - } + use std::borrow::BorrowMut; + + use super::MersenneTwister; + + #[test] + fn mersenne_state_simple_seed_matches() { + let mt = MersenneTwister::new(19650218); + let numbers = load_MTwikipedia_state(); + + for (idx, number) in numbers.into_iter().enumerate() { + let ours = mt.state[idx]; + if ours != number { + panic!("reference simple state was {number}, we got {ours}") + } + } + } + + #[test] + fn mersenne_state_array_matches() { + let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; + let mt = MersenneTwister::new_seed_array(&key); + let numbers = load_MTwikipedia_array_state(); + + for (idx, number) in numbers.into_iter().enumerate() { + let ours = mt.state[idx]; + if ours != number { + panic!("reference array state was {number}, we got {ours}") + } + } + } + + #[test] + fn mersenne_matches() { + // Initial array + let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; + let mut mt = MersenneTwister::new_seed_array(&key); + let numbers = load_mt19937ar_cok(); + + for number in numbers { + let ours = mt.random_u32(); + if ours != number { + panic!("reference was {number}, we got {ours}") + } + } + } + + #[test] + fn mersenne_matches_real() { + // Initial array + let key: [u32; 4] = [0x123, 0x234, 0x345, 0x456]; + let mut mt = MersenneTwister::new_seed_array(&key); + let numbers = load_mt19937ar_cok_real(); + + // The C program first generated 1000 random u32, so we do it too. + for _ in 0..1000 { + let _ = mt.random_u32(); + } + + for number in numbers { + let ours = mt.random_f64(); + let string = format!("{ours:.8}"); + if !string.eq(&number) { + panic!("reference was '{number}', we got '{string}'") + } + } + } + + fn load_mt19937ar_cok() -> Vec { + let string = std::fs::read_to_string("test/mt19937ar-cok").unwrap(); + let mut numbers = vec![]; + + for line in string.lines() { + let mut chars = line.chars(); + for _ in 0..5 { + let number: String = chars.borrow_mut().take(10).collect(); + let _ = chars.next(); // throwout the space + + let ulong: u32 = number.trim().parse().unwrap(); + numbers.push(ulong); + } + } + + numbers + } + + fn load_mt19937ar_cok_real() -> Vec { + let string = std::fs::read_to_string("test/mt19937ar-cok-real").unwrap(); + let mut numbers = vec![]; + + for line in string.lines() { + let mut chars = line.chars(); + for _ in 0..5 { + let number: String = chars.borrow_mut().take(10).collect(); + let _ = chars.next(); // throwout the space + numbers.push(number); + } + } + + numbers + } + + #[allow(non_snake_case)] + fn load_MTwikipedia_array_state() -> Vec { + let string = std::fs::read_to_string("test/MTwikipedia-array_state").unwrap(); + let mut numbers = vec![]; + + for line in string.lines() { + numbers.push(line.trim().parse().unwrap()); + } + + numbers + } + + #[allow(non_snake_case)] + fn load_MTwikipedia_state() -> Vec { + let string = std::fs::read_to_string("test/MTwikipedia-state").unwrap(); + let mut numbers = vec![]; + + for line in string.lines() { + numbers.push(line.trim().parse().unwrap()); + } + + numbers + } } -- cgit 1.4.1-3-g733a5