-
Notifications
You must be signed in to change notification settings - Fork 109
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add Mersenne Twister PRNG (#984)
- Loading branch information
Showing
4 changed files
with
245 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
import Batteries.Data.Random.MersenneTwister |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
/- | ||
Copyright (c) 2024 François G. Dorais. All rights reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: François G. Dorais | ||
-/ | ||
import Batteries.Data.Vector | ||
|
||
/-! # Mersenne Twister | ||
Generic implementation for the Mersenne Twister pseudorandom number generator. | ||
All choices of parameters from Matsumoto and Nishimura (1998) are supported, along with later | ||
refinements. Parameters for the standard 32-bit MT19937 and 64-bit MT19937-64 algorithms are | ||
provided. Both `RandomGen` and `Stream` interfaces are provided. | ||
Use `mt19937.init seed` to create an MT19937 PRNG with a 32-bit seed value; use | ||
`mt19937_64.init seed` to create an MT19937-64 PRNG with a 64-bit seed value. If the seed is | ||
omitted, the configuration's default seed (`initSeed`) is used. | ||
Sample usage: | ||
``` | ||
import Batteries.Data.Random.MersenneTwister | ||
open Batteries.Random.MersenneTwister | ||
def mtgen := mt19937.init -- default seed 4357 | ||
#eval (Stream.take mtgen 5).fst -- [874448474, 2424656266, 2174085406, 1265871120, 3155244894] | ||
``` | ||
### References: | ||
- Matsumoto, Makoto and Nishimura, Takuji (1998), | ||
[**Mersenne twister: A 623-dimensionally equidistributed uniform pseudo-random number generator**](https://doi.org/10.1145/272991.272995), | ||
ACM Trans. Model. Comput. Simul. 8, No. 1, 3-30. | ||
[ZBL0917.65005](https://zbmath.org/?q=an:0917.65005). | ||
- Nishimura, Takuji (2000), | ||
[**Tables of 64-bit Mersenne twisters**](https://doi.org/10.1145/369534.369540), | ||
ACM Trans. Model. Comput. Simul. 10, No. 4, 348-357. | ||
[ZBL1390.65014](https://zbmath.org/?q=an:1390.65014). | ||
-/ | ||
|
||
namespace Batteries.Random.MersenneTwister | ||
|
||
/-- | ||
Mersenne Twister configuration. | ||
Letters in parentheses correspond to variable names used by Matsumoto and Nishimura (1998) and | ||
Nishimura (2000). | ||
Later fields depend on `wordSize` and `stateSize`, so those two must be fixed first. | ||
-/ | ||
structure Config where | ||
/-- Word size (`w`): bit width of every state word and of every output value. -/ | ||
wordSize : Nat | ||
/-- Degree of recurrence (`n`): number of words held in the generator state. -/ | ||
stateSize : Nat | ||
/-- Middle word (`m`): offset added to the current index (as a `Fin stateSize`, so it wraps | ||
around) to select the word mixed in by each twist step. -/ | ||
shiftSize : Fin stateSize | ||
/-- Twist value (`r`): number of low-order bits selected by the lower mask; the upper mask | ||
selects the remaining high-order bits. -/ | ||
maskBits : Fin wordSize | ||
/-- Coefficients of the twist matrix (`a`): XORed into the shifted word during a twist step | ||
when the lowest bit of the unshifted word is set. -/ | ||
xorMask : BitVec wordSize | ||
/-- Tempering shift parameters (`u`, `s`, `t`, `l`), in that order; tempering uses `u` and `l` | ||
as right shifts and `s` and `t` as left shifts. -/ | ||
temperingShifts : Nat × Nat × Nat × Nat | ||
/-- Tempering mask parameters (`d`, `b`, `c`), in that order; they mask the words shifted by | ||
`u`, `s` and `t` respectively. -/ | ||
temperingMasks : BitVec wordSize × BitVec wordSize × BitVec wordSize | ||
/-- Initialization multiplier (`f`): used to derive each initial state word from the | ||
previous one. -/ | ||
initMult : BitVec wordSize | ||
/-- Default initialization seed value, used when `Config.init` is called without a seed. -/ | ||
initSeed : BitVec wordSize | ||
|
||
/-- Upper mask: all ones shifted left by `maskBits`, which clears the low `maskBits` bits. -/
private abbrev Config.uMask (cfg : Config) : BitVec cfg.wordSize :=
  (BitVec.allOnes cfg.wordSize) <<< cfg.maskBits.1
|
||
/-- Lower mask: all ones shifted right by `wordSize - maskBits`, which keeps the low `maskBits`
bits set. -/
private abbrev Config.lMask (cfg : Config) : BitVec cfg.wordSize :=
  (BitVec.allOnes cfg.wordSize) >>> (cfg.wordSize - cfg.maskBits.1)
|
||
/-- The word size is positive, because `maskBits` is an element of `Fin wordSize`. -/
@[simp] theorem Config.zero_lt_wordSize (cfg : Config) : 0 < cfg.wordSize :=
  Nat.lt_of_le_of_lt (Nat.zero_le _) cfg.maskBits.isLt
|
||
/-- The state size is positive, because `shiftSize` is an element of `Fin stateSize`. -/
@[simp] theorem Config.zero_lt_stateSize (cfg : Config) : 0 < cfg.stateSize :=
  Nat.lt_of_le_of_lt (Nat.zero_le _) cfg.shiftSize.isLt
|
||
/-- | ||
Mersenne Twister State. | ||
Holds the words of the generator together with the position that the next twist step rewrites. | ||
-/ | ||
structure State (cfg : Config) where | ||
/-- Data for current state: `cfg.stateSize` words of `cfg.wordSize` bits each. -/ | ||
data : Vector (BitVec cfg.wordSize) cfg.stateSize | ||
/-- Current data index: the position of the word overwritten by the next `State.twist`. -/ | ||
index : Fin cfg.stateSize | ||
|
||
/-- | ||
Mersenne Twister initialization given an optional seed. | ||
Builds a state of `cfg.stateSize` words whose first word is `seed` and whose index is `0`. | ||
When `seed` is omitted, `cfg.initSeed` is used. | ||
-/ | ||
@[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config) | ||
(seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg := | ||
-- Start from an empty array with capacity `stateSize`; the index `0` is valid by `zero_lt_stateSize`. | ||
⟨loop seed (.mkEmpty cfg.stateSize) (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩ | ||
where | ||
/-- | ||
Inner loop for Mersenne Twister initialization. | ||
Pushes the current word `w` onto `v` and derives the next word from it, stopping once `v` holds | ||
`cfg.stateSize` words. The hypothesis `h` records that `v` has not outgrown the state size. | ||
-/ | ||
loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) := | ||
if heq : v.size = cfg.stateSize then ⟨v, heq⟩ else | ||
let v := v.push w | ||
-- `v.size` (after the push) is the position of the word being generated. | ||
-- NOTE(review): `>>>` binds tighter than `-`, so the shift below parses as | ||
-- `(w >>> cfg.wordSize) - 2`. The reference initialization recurrence shifts by `w - 2` bits; | ||
-- confirm whether `w >>> (cfg.wordSize - 2)` was intended. Changing it alters every generated | ||
-- value, so any recorded expected outputs would have to be regenerated. | ||
let w := cfg.initMult * (w ^^^ (w >>> cfg.wordSize - 2)) + v.size | ||
loop w v (by simp only [v, Array.size_push]; omega) | ||
|
||
/--
Apply one step of the twisting transformation: rewrite the word at the current index and advance
the index by one, wrapping around to `0` past the last word.
-/
@[specialize cfg] protected def State.twist (state : State cfg) : State cfg :=
  let cur := state.index
  -- Successor of the current index, wrapping around at `stateSize`.
  let nxt : Fin cfg.stateSize :=
    if h : cur.val + 1 < cfg.stateSize then ⟨cur.val + 1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩
  -- High-order bits of the current word joined with the low-order bits of its successor.
  let mixed := (state.data[cur] &&& cfg.uMask) ||| (state.data[nxt] &&& cfg.lMask)
  -- Shift right by one and fold in the twist coefficients when the lowest bit was set.
  let shifted := mixed >>> 1
  let twisted := bif mixed[0] then shifted ^^^ cfg.xorMask else shifted
  -- `cur + cfg.shiftSize` is addition in `Fin cfg.stateSize`, so it wraps around.
  ⟨state.data.set cur (state.data[cur + cfg.shiftSize] ^^^ twisted), nxt⟩
|
||
/--
Advance the state by `steps` twist steps without producing output values. A single step is taken
when `steps` is omitted.
-/
-- TODO: optimize to `O(log(steps))` using the minimal polynomial
protected def State.update (state : State cfg) (steps : Nat := 1) : State cfg :=
  match steps with
  | 0 => state
  | remaining + 1 => state.twist.update remaining
|
||
/--
Mersenne Twister iteration: twist the state once, then temper the word that the twist just wrote.
Returns the tempered word together with the advanced state.
-/
@[specialize cfg] protected def State.next (state : State cfg) :
    BitVec cfg.wordSize × State cfg :=
  let advanced := state.twist
  -- The twist rewrote the word at the *old* index, so read it back from there.
  (temper advanced.data[state.index], advanced)
where
  /-- Tempering step for Mersenne Twister: four successive shift-and-XOR passes, the first three
  of which are masked. -/
  @[inline] temper (x : BitVec cfg.wordSize) :=
    let (u, s, t, l) := cfg.temperingShifts
    let (d, b, c) := cfg.temperingMasks
    let x := x ^^^ ((x >>> u) &&& d)
    let x := x ^^^ ((x <<< s) &&& b)
    let x := x ^^^ ((x <<< t) &&& c)
    x ^^^ (x >>> l)
|
||
/-- `Stream` interface for the generator: every call yields a value, so the stream never ends. -/
instance (cfg) : Stream (State cfg) (BitVec cfg.wordSize) where
  next? state := some state.next
|
||
/-- Configuration for the 32-bit Mersenne Twister (MT19937). -/
def mt19937 : Config := {
  wordSize := 32,
  stateSize := 624,
  shiftSize := 397,
  maskBits := 31,
  xorMask := 0x9908b0df,
  temperingShifts := (11, 7, 15, 18),
  temperingMasks := (0xffffffff, 0x9d2c5680, 0xefc60000),
  initMult := 1812433253,
  initSeed := 4357
}
|
||
/-- Configuration for the 64-bit Mersenne Twister (MT19937-64). -/
def mt19937_64 : Config := {
  wordSize := 64,
  stateSize := 312,
  shiftSize := 156,
  maskBits := 31,
  xorMask := 0xb5026f5aa96619e9,
  temperingShifts := (29, 17, 37, 43),
  temperingMasks := (0x5555555555555555, 0x71d67fffeda60000, 0xfff7eee000000000),
  initMult := 6364136223846793005,
  initSeed := 19650218
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import Batteries.Data.Random.MersenneTwister | ||
import Batteries.Data.Stream | ||
|
||
open Batteries.Random.MersenneTwister | ||
|
||
-- Regression check: the first five values produced by MT19937 with its default seed. | ||
#guard (Stream.take mt19937.init 5).1 == [874448474, 2424656266, 2174085406, 1265871120, 3155244894] | ||
|
||
/- Sample output was generated using `numpy`'s implementation of MT19937: | ||
```python | ||
from numpy import array, uint32 | ||
from numpy.random import MT19937 | ||
mt = MT19937() | ||
mt.state = { | ||
'bit_generator' : 'MT19937', | ||
'state' : { | ||
'pos' : 624, | ||
'key' : array([ | ||
4357, 1673174024, 1301878288, 1129097449, 2180885271, 2495295730, 3729202114, 3451529139, 2624228201, 696045212, | ||
2296245684, 4097888573, 2110311931, 1672374534, 381896678, 2887874951, 3859861197, 420983856, 1691952728, 4233606289, | ||
1707944415, 3515687962, 4265198858, 1433261659, 1131854641, 228846788, 3811811324, 873525989, 588291779, 2854617646, | ||
948269870, 3798261295, 3422826645, 340138072, 3671734944, 3961007161, 2839350439, 3264455490, 310719058, 2570596611, | ||
3750039289, 648992492, 3816674884, 2210726029, 371217291, 196912982, 3046892150, 470118103, 1302935133, 362465408, | ||
1360220904, 2946174945, 1630294895, 3570642538, 1798333338, 1196832683, 226789057, 2740096276, 1062441100, 1875507765, | ||
2599873619, 1037523070, 4029519294, 3231722367, 2232344613, 3458909352, 2906353456, 3064815497, 3166305847, | ||
3658630546, 3632421090, 885320275, 1621369481, 1258557244, 2827734740, 3209486301, 131295515, 2191201702, 44141830, | ||
1183978535, 4202966509, 801836240, 2303299448, 333191985, 4114943231, 1490315450, 453120554, 759253243, 1381163601, | ||
3455606116, 1027445020, 1144697221, 3040135651, 4176273102, 798935118, 49817807, 2492997557, 3171983608, 2742334400, | ||
1282687705, 1047297991, 3697219554, 1400278898, 3276297123, 843040281, 354711436, 4156544868, 2873126701, 3990490795, | ||
3966874614, 1376536470, 4189022583, 2283386237, 3645931808, 1312021512, 679663233, 3054458511, 1152865034, 1927729338, | ||
538380875, 374984161, 2453495220, 514433452, 1271601365, 3737270131, 630101278, 1292962526, 2908018207, 1209528133, | ||
413117768, 3762161744, 2194986537, 1414304087, 379722290, 2862208514, 3551161587, 3402627497, 2411204572, 3033657332, | ||
4161252989, 2267825211, 963150406, 2081690150, 4014304967, 1977732365, 2412979568, 613038232, 418857425, 3682807839, | ||
3416550746, 3692470090, 2764012443, 3255912817, 2160692740, 3914318396, 3437441061, 2828481795, 3655629678, 582770030, | ||
2946380655, 3506851541, 612362648, 3394202848, 1530337657, 3360830183, 570641538, 153365650, 1624454723, 80526649, | ||
1365694508, 2272925828, 34250189, 3066169803, 631734422, 3706776758, 3443270679, 659846301, 3707435456, 3573851432, | ||
1017208097, 1100519855, 1824765866, 3284762074, 2887949547, 569464065, 3057970772, 1726477004, 3119183733, 3349922451, | ||
4162228670, 249085950, 3854319807, 1155219045, 811161064, 207675760, 50531529, 141911159, 3819613906, 2655884066, | ||
3517624211, 514724041, 2094583932, 3681571092, 3518053661, 2207473499, 961982182, 1423628102, 628853095, 3823741997, | ||
1450180112, 1817911736, 384378993, 1749521215, 4080873978, 2604100714, 2468900411, 1718743185, 3679944356, 623522652, | ||
2974445253, 351789091, 776787982, 4087231118, 395771407, 2634989045, 2547249720, 2502583808, 3550523417, 648947207, | ||
2361409826, 2639137202, 4179155171, 3136025689, 3233151180, 3765213604, 459508845, 412632299, 3365801270, 1208603094, | ||
1978375863, 3608769469, 2648322656, 994422344, 1463198657, 1938300111, 1983437898, 3617090298, 582545291, 604707873, | ||
615071476, 1976468460, 4251555349, 2373160371, 4138683998, 927249694, 4178996063, 3071856005, 3264724616, 2539911824, | ||
1383596905, 3639900055, 2590770034, 1029541954, 369472051, 3757991913, 1470517532, 2317808180, 1065978813, 3301489275, | ||
4087716742, 2662718566, 678716423, 274451277, 1625396912, 3598469848, 3639725841, 726808159, 1490990746, 4062476682, | ||
2411471067, 1395972017, 1390554948, 1854727292, 2494590309, 1377225539, 2540041390, 3288614830, 706906287, 1416719637, | ||
609008344, 2311429920, 821102265, 2034260263, 3587569090, 3115591378, 3545840515, 4166871929, 139581804, 2421643972, | ||
1250638605, 4212965387, 2794805718, 3306616566, 2466109783, 2200482525, 1496197888, 381089640, 2743249505, 4221427695, | ||
1247199466, 1746114586, 2065302059, 1348936513, 2997505940, 3911013644, 428274869, 2816055507, 580438782, 135588414, | ||
916674047, 445684901, 1016784680, 654791600, 1282652681, 92916407, 1411782674, 1367985506, 1207661779, 3531669257, | ||
627085756, 1857409876, 4107311709, 1384928667, 2576697382, 2875531654, 4151312039, 116927085, 1281879888, 414036984, | ||
3931190705, 4100135295, 1170799418, 3130902186, 4055536507, 3692691153, 480878564, 2201474460, 3663014917, 4155766371, | ||
1987039566, 4121861326, 2525025103, 2465094709, 2536129400, 1843468352, 2926058841, 533253191, 1988389474, 1209435122, | ||
4141112867, 2699109017, 2373614092, 1694129124, 2730600877, 2249161515, 1355638390, 3319290902, 2209534967, | ||
1463955965, 204923808, 1025015944, 214266113, 3382305551, 2455594378, 1861944634, 1820710091, 449145441, 4119339060, | ||
2660525612, 3515028309, 3466454003, 1024657310, 50945886, 2913140895, 721595333, 3416444872, 2701847760, 2352361641, | ||
234184151, 3927502002, 3834792578, 3469473651, 4193637929, 2873594460, 1994191988, 1690724605, 1956524219, 476427462, | ||
212379302, 1370380615, 327076237, 1984104432, 682581272, 2521259089, 3543809183, 3275489242, 241390538, 3496199707, | ||
2497799665, 770560132, 1626015420, 2776148645, 3717161347, 3970592238, 710750702, 3421625839, 876972885, 2108460056, | ||
1195168096, 1195766777, 3121053543, 2819333890, 1916084498, 717897923, 3627489721, 1970264748, 1813355780, 4148615245, | ||
556824139, 411448086, 4228776246, 1732939415, 3206934813, 1949588544, 3291105704, 1044314017, 222045743, 3079457322, | ||
638497370, 1849452395, 921039233, 1115861204, 3019093836, 2828923381, 4185943827, 3344827454, 3923907710, 760572735, | ||
3828284133, 1559197800, 724485616, 1828677449, 2985767159, 4119101778, 1077348258, 3518446099, 2585587017, 1855673084, | ||
3495712148, 3265984413, 2998815707, 760668518, 2487249862, 3060757479, 3249514669, 4222804112, 1010910776, 3893641969, | ||
395812799, 2591540346, 1194664170, 49789115, 1363873041, 1005502756, 1164343260, 3646613829, 459869347, 3679832718, | ||
1137706766, 4189431951, 1412889205, 622040248, 1536739968, 3066727065, 666661511, 1672188834, 2714762802, 4135248739, | ||
35606745, 2775710540, 4083752484, 3680159469, 1950331243, 251641782, 1501029974, 486869303, 1720971325, 241603808, | ||
28070600, 2737782337, 910469455, 3810848458, 118398842, 3078470155, 2559096993, 2933522804, 2264615020, 3793195157, | ||
1614887475, 45727966, 3193899422, 1157273055, 2178255365, 2646663432, 724754192, 168779241, 4048503831, 3483948530, | ||
3996648642, 939343027, 917914729, 3030111132, 3908302516, 29247037, 3568084731, 1034472966, 1408004326, 1693666951, | ||
3712665549, 3120003376, 3374542680, 2868373905, 1362838239, 1421625626, 4275252746, 548825947, 622261297, 3152835012, | ||
2926192892, 423356389, 151058371, 3820087086, 1673993262, 252457775, 1317185941, 2594135384, 817169312, 2016796985, | ||
2292688295, 1654933570, 2158435154, 2703640067, 3260663801, 3267419116, 2293555012, 2721936781, 1727868043, 91884630, | ||
265685878, 1143096279, 961294173, 403541376, 2338233320, 1725318369, 4101205103, 4268086122, 3418016922, 1065995435, | ||
1936572353, 265163284, 3043694988, 2167402293, 2057323859, 4033232254, 3258990270, 1137868927, 2142656805, 4216785320, | ||
1188509744, 1051071625, 196974391, 2445666962, 3092595170, 2833121107, 2474761097, 2190021692, 1852037076, 3577763037, | ||
3794354715, 2124118694, 2641147398, 1551493415, 1913661165, 1313919440, 2232801400, 1781682225, 1340417535, 994676154, | ||
251493162, 2162155003, 1678056273, 3810976356, 1505106460, 3361449605, 1041703651, 1727972302, 3959583054, 3140845007, | ||
3202914485, 2878334456, 2354150592, 3334993881, 1015617735, 506838242, 4168775794, 839674019, 4238769945, 849116300, | ||
4189642852, 1596908589, 556328875, 2369067254, 2431152278, 1004682871], dtype=uint32)}} | ||
print(mt.random_raw(5)) | ||
``` | ||
-/ |