Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MtPython RNG #202

Draft
wants to merge 2 commits into
base: trunk
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,11 @@ use core::fmt;

pub use crate::mt::Mt19937GenRand32;
pub use crate::mt64::Mt19937GenRand64;
pub use crate::python::Mt19937GenRandPython;

mod mt;
mod mt64;
mod python;
#[cfg(test)]
mod vectors;

Expand All @@ -124,6 +126,9 @@ pub type Mt = Mt19937GenRand32;
/// A type alias for [`Mt19937GenRand64`], 64-bit Mersenne Twister.
pub type Mt64 = Mt19937GenRand64;

/// A type alias for [`Mt19937GenRandPython`], Python-compatible Mersenne Twiser.
pub type MtPython = Mt19937GenRandPython;

/// Error returned from fallible Mersenne Twister recovery constructors.
///
/// When the `std` feature is enabled, this type implements `std::error::Error`.
Expand Down Expand Up @@ -182,7 +187,7 @@ mod tests {
RecoverRngError::TooManySamples(0),
RecoverRngError::TooManySamples(987),
];
for tc in test_cases {
for tc in &test_cases {
let mut buf = String::new();
write!(&mut buf, "{}", tc).unwrap();
assert!(!buf.is_empty());
Expand Down
183 changes: 183 additions & 0 deletions src/python.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
// src/python.rs
//
// Copyright (c) 2015,2017 rust-mersenne-twister developers
// Copyright (c) 2020 Ryan Lopopolo <rjl@hyperbo.la>
// Copyright (c) 2023 Ignacy Sewastianowicz
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE> or <http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT> or <http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.

use core::fmt::Debug;

use crate::Mt19937GenRand32;

/// CPython compatible implementation of the Mersenne Twister pseudorandom
/// number generator.
///
/// It is esentially a [Mt19937GenRand32] with slight modifications to match
/// behavior of the CPython's `random` module.
///
/// # Size
///
/// `Mt19937GenRandPython` requires approximately 2.5 kilobytes of internal state.
///
/// You may wish to store an `Mt19937GenRandPython` on the heap in a [`Box`] to make
/// it easier to embed in another struct.
///
/// `Mt19937GenRandPython` is also the same size as
/// [`Mt19937GenRand32`](crate::Mt19937GenRand32) and
/// [`Mt19937GenRand64`](crate::Mt19937GenRand64)
///
/// ```
/// # use core::mem;
/// # use rand_mt::{Mt19937GenRand32, Mt19937GenRand64, Mt19937GenRandPython};
/// assert_eq!(2504, mem::size_of::<Mt19937GenRandPython>());
/// assert_eq!(mem::size_of::<Mt19937GenRand64>(), mem::size_of::<Mt19937GenRand32>());
/// assert_eq!(mem::size_of::<Mt19937GenRand64>(), mem::size_of::<Mt19937GenRandPython>());
/// ```
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct Mt19937GenRandPython {
inner: Mt19937GenRand32,
}

impl Mt19937GenRandPython {
/// Create a new Mersenne Twister random number generator using the given
/// seed.
///
/// This is equivalent to [random.seed](https://github.com/python/cpython/blob/3.11/Modules/_randommodule.c#L275) in CPython.
///
/// # Examples: TODO
///
/// ## Constructing with a `u32` seed
///
/// ```
/// # use rand_mt::Mt19937GenRand32;
/// let seed = 123_456_789_u32;
/// let mt1 = Mt19937GenRand32::new(seed);
/// let mt2 = Mt19937GenRand32::from(seed.to_le_bytes());
/// assert_eq!(mt1, mt2);
/// ```
#[inline]
#[must_use]
pub fn new(seed: u32) -> Self {
let key: [u32; 1] = [seed];
Self::new_with_key(key.iter().copied())
}

/// Create a new Mersenne Twister random number generator using the given
/// key.
///
/// Key can have any length.
#[inline]
#[must_use]
pub fn new_with_key<I>(key: I) -> Self
where
I: IntoIterator<Item = u32>,
I::IntoIter: Clone,
{
Self {
inner: Mt19937GenRand32::new_with_key(key),
}
}

/// Generate next 'f64' output.
///
/// This function will generate a random number on 0..1
/// with 53 bit resolution
///
/// It is compatible with CPython's `random.random`
///
/// CPython implementation: <https://github.com/python/cpython/blob/3.11/Modules/_randommodule.c#L181>
///
/// # Examples: TODO
///
/// ```
/// # use rand_mt::Mt19937GenRandPython;
/// let mut mt = Mt19937GenRandPython::new(1);
/// assert_ne!(mt.next_f64(), mt.next_f64());
/// ```
#[inline]
pub fn next_f64(&mut self) -> f64 {
let a = self.inner.next_u32() >> 5;
let b = self.inner.next_u32() >> 6;

(a as f64 * 67108864.0 + b as f64) * (1.0 / 9007199254740992.0)
}

/// Get n random bytes converted into an integer.
///
/// This method is compatible with CPython's `random.getrandbits` for n <= 128
///
/// # Examples: TODO
///
/// # Panics
///
/// The function panics if n is bigger than 128
///
pub fn getrandbits(&mut self, n: usize) -> u128 {
if n > 128 {
panic!(
"Can't generate higher integer than 128 bits. {} bits were given.",
n
);
} else if n <= 32 {
return (self.inner.next_u32() >> (32 - n)) as u128;
} else if n == 0 {
return 0;
}

let mut result: u128 = 0;
let mut shift: u128 = 0;

let mut j = n as i32;
let words = (n - 1) / 32 + 1;
for _ in 0..words {
let mut r = self.inner.next_u32();
if j < 32 {
r >>= 32 - j;
}

result |= (r as u128) << shift;

shift += 32;
if shift >= 128 {
break;
}

j -= 32;
}

return result;
}

/// Fill a buffer with bytes generated from the RNG.
///
/// This method generates random `u32`s (the native output unit of the RNG)
/// until `dest` is filled.
///
/// It is compatible with CPython's `random.randbytes`
///
/// # Examples: TODO
///
pub fn fill_bytes(&mut self, dest: &mut [u8]) {
let words = dest.len() + 15 / 16;
for i in 0..words {
let start = i * 16;
let end = usize::min(start + 16, dest.len());

let n = end as isize - start as isize;
if n <= 0 {
break;
}
let n = n as usize;

let bytes = self.getrandbits(n * 8).to_le_bytes();
for j in start..end {
dest[j] = bytes[j - start]
}
}
}
}
Empty file added src/python/rand.rs
Empty file.
80 changes: 80 additions & 0 deletions tests/python_reproducibility.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// tests/python_reproducibility.rs

use rand_mt::MtPython;

// ```console
// $ python -c " \
// import random; \
// random.seed(1); \
// print(bytearray(random.randbytes(1))[0])"
// 34
// ```
// #[test]
// fn spec_bytes() {
// let mut rng = MtPython::new(1);
// let mut buf = [0; 1];
// rng.fill_bytes(&mut buf);
// println!("{:?}", buf);
// }

// ```python
// import random
// random.seed(1)
// n = random.getrandbits(128)
// ```
#[test]
fn getrandbits() {
let mut rng = MtPython::new(1);
let n = rng.getrandbits(128);

assert_eq!(n, 272996653310673477252411125948039410165);
}

// ```python
// import random
// random.seed(1)
// a = list(random.randbytes(3))
// b = list(random.randbytes(32))
// c = list(random.randbytes(64))
// d = list(random.randbytes(128))
// ```
#[test]
fn fill_bytes() {
let mut rng = MtPython::new(1);

macro_rules! test_array {
($comp: expr) => {
let mut buf = [0u8; $comp.len()];
rng.fill_bytes(&mut buf);
assert_eq!(buf, $comp);
};
}

// a
test_array!([177, 101, 34]);

// b
test_array!([
74, 88, 183, 145, 223, 106, 241, 216, 48, 62, 97, 205, 196, 187, 134, 195, 209, 196, 39,
16, 60, 52, 76, 65, 137, 235, 47, 30, 123, 213, 212, 126,
]);

// c
test_array!([
68, 111, 206, 194, 163, 216, 17, 115, 97, 16, 229, 120, 27, 204, 206, 166, 150, 118, 46,
97, 22, 198, 233, 201, 45, 153, 191, 53, 140, 46, 7, 24, 130, 44, 228, 124, 168, 199, 65,
7, 230, 108, 176, 228, 178, 179, 244, 213, 141, 130, 202, 99, 134, 210, 201, 110, 118, 14,
129, 155, 133, 201, 36, 195,
]);

// d
test_array!([
89, 113, 100, 196, 166, 5, 138, 0, 88, 26, 34, 178, 45, 229, 4, 114, 67, 61, 46, 68, 254,
216, 182, 184, 53, 126, 68, 205, 49, 41, 144, 58, 193, 212, 85, 151, 162, 66, 253, 241, 31,
143, 43, 26, 57, 243, 195, 230, 147, 17, 67, 81, 220, 190, 212, 7, 227, 230, 182, 5, 185,
158, 131, 6, 221, 167, 72, 166, 30, 2, 154, 138, 63, 65, 91, 2, 74, 20, 108, 240, 217, 138,
152, 225, 207, 153, 150, 97, 249, 103, 189, 175, 223, 14, 115, 55, 66, 12, 19, 248, 245,
212, 15, 108, 224, 121, 209, 185, 135, 55, 111, 7, 188, 184, 18, 135, 253, 200, 192, 56,
143, 232, 129, 195, 160, 102, 25, 112,
]);
}