From b5c9d8e478db664d439996dc6569c8180801b0f4 Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Fri, 11 Oct 2024 22:44:07 -0400 Subject: [PATCH 1/8] feat: add Mersenne Twister pseudorandom generator --- Batteries.lean | 1 + Batteries/Data/Random.lean | 1 + Batteries/Data/Random/MersenneTwister.lean | 143 +++++++++++++++++++++ 3 files changed, 145 insertions(+) create mode 100644 Batteries/Data/Random.lean create mode 100644 Batteries/Data/Random/MersenneTwister.lean diff --git a/Batteries.lean b/Batteries.lean index f04a7a1f76..613fb429e4 100644 --- a/Batteries.lean +++ b/Batteries.lean @@ -29,6 +29,7 @@ import Batteries.Data.MLList import Batteries.Data.Nat import Batteries.Data.PairingHeap import Batteries.Data.RBMap +import Batteries.Data.Random import Batteries.Data.Range import Batteries.Data.Rat import Batteries.Data.String diff --git a/Batteries/Data/Random.lean b/Batteries/Data/Random.lean new file mode 100644 index 0000000000..cf1e720ee0 --- /dev/null +++ b/Batteries/Data/Random.lean @@ -0,0 +1 @@ +import Batteries.Data.Random.MersenneTwister diff --git a/Batteries/Data/Random/MersenneTwister.lean b/Batteries/Data/Random/MersenneTwister.lean new file mode 100644 index 0000000000..0425367a98 --- /dev/null +++ b/Batteries/Data/Random/MersenneTwister.lean @@ -0,0 +1,143 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: François G. Dorais +-/ +import Batteries.Data.Vector + +/-! # Mersenne Twister + +Reference implementation for the Mersenne Twister pseudorandom number generator. + +### References: + +- Matsumoto, Makoto and Nishimura, Takuji (1998), + [**Mersenne twister: A 623-dimensionally equidistributed uniform pseudo-random number generator**](https://doi.org/10.1145/272991.272995), + ACM Trans. Model. Comput. Simul. 8, No. 1, 3-30. + [ZBL0917.65005](https://zbmath.org/?q=an:0917.65005). + +- Nishimura, Takuji (2000), + [**Tables of 64-bit Mersenne twisters**](https://doi.org/10.1145/369534.369540), + ACM Trans. Model. Comput. Simul. 10, No. 4, 348-357. + [ZBL1390.65014](https://zbmath.org/?q=an:1390.65014). +-/ + +namespace Batteries.Random.MersenneTwister + +/-- +Mersenne Twister configuration. + +Letters in parentheses correspond to variable names used by Matsumoto and Nishimura (1998) and +Nishimura (2000). +-/ +structure Config where + /-- Word size (`w`). -/ + wordSize : Nat + /-- Degree of recurrence (`n`). -/ + stateSize : Nat + /-- Middle word (`m`). -/ + shiftSize : Fin stateSize + /-- Twist value (`r`). -/ + maskBits : Fin wordSize + /-- Coefficients of the twist matrix (`a`). -/ + xorMask : BitVec wordSize + /-- Tempering shift parameters (`u`, `s`, `t`, `l`). -/ + temperingShifts : Nat × Nat × Nat × Nat + /-- Tempering mask parameters (`d`, `b`, `c`). -/ + temperingMasks : BitVec wordSize × BitVec wordSize × BitVec wordSize + /-- Initialization multiplier (`f`). -/ + initMult : BitVec wordSize + /-- Default initialization seed value. -/ + initSeed : BitVec wordSize + +private abbrev Config.uMask (cfg : Config) : BitVec cfg.wordSize := + BitVec.allOnes cfg.wordSize <<< cfg.maskBits.val + +private abbrev Config.lMask (cfg : Config) : BitVec cfg.wordSize := + BitVec.allOnes cfg.wordSize >>> (cfg.wordSize - cfg.maskBits.val) + +@[simp] theorem Config.zero_lt_wordSize (cfg : Config) : 0 < cfg.wordSize := + Nat.zero_lt_of_lt cfg.maskBits.is_lt + +@[simp] theorem Config.zero_lt_stateSize (cfg : Config) : 0 < cfg.stateSize := + Nat.zero_lt_of_lt cfg.shiftSize.is_lt + +/-- Mersenne Twister State. -/ +structure State (cfg : Config) where + /-- Data for current state. -/ + data : Vector (BitVec cfg.wordSize) cfg.stateSize + /-- Current data index. -/ + index : Fin cfg.stateSize + +/-- Mersenne Twister initialization given an optional seed. -/ +@[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config) + (seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg := + ⟨loop seed #[] (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩ +where + /-- Inner loop for Mersenne Twister initalization. -/ + loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) := + if heq : v.size = cfg.stateSize then ⟨v, heq⟩ else + let v := v.push w + let w := cfg.initMult * (w ^^^ (w >>> cfg.wordSize - 2)) + v.size + loop w v (by simp only [v, Array.size_push]; omega) + +/-- Update the state by a number of generation steps (default 1). -/ +@[specialize cfg] protected def State.update (state : State cfg) (steps := 1) : State cfg := + loop state steps +where + /-- Inner loop for Mersenne Twister update. -/ + @[inline] loop (s : State cfg) (c : Nat) : State cfg := + if c = 0 then s else + let i := s.index + let i' : Fin cfg.stateSize := + if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩ + let y := s.data[i] &&& cfg.uMask ||| s.data[i'] &&& cfg.lMask + let x := s.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1 + loop ⟨s.data.set i x, i'⟩ (c-1) + +/-- Mersenne Twister iteration. -/ +@[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg := + let i := state.index + let s := state.update + (temper s.data[i], s) +where + /-- Tempering step for Mersenne Twister. -/ + @[inline] temper (x : BitVec cfg.wordSize) := + match cfg.temperingShifts, cfg.temperingMasks with + | (u, s, t, l), (d, b, c) => + let x := x ^^^ x >>> u &&& d + let x := x ^^^ x <<< s &&& b + let x := x ^^^ x <<< t &&& c + x ^^^ x >>> l + +instance (cfg) : RandomGen (State cfg) where + range _ := (0, 2 ^ cfg.wordSize - 1) + next s := match s.next with | (r, s) => (r.toNat, s) + split s := let (a, s) := s.next; (s, cfg.init a) + +instance (cfg) : Stream (State cfg) (BitVec cfg.wordSize) where + next? s := s.next + +/-- 32 bit Mersenne Twister (MT19937) configuration. -/ +def mt19937 : Config where + wordSize := 32 + stateSize := 624 + shiftSize := 397 + maskBits := 31 + xorMask := 0x9908b0df + temperingShifts := (11, 7, 15, 18) + temperingMasks := (0xffffffff, 0x9d2c5680, 0xefc60000) + initMult := 1812433253 + initSeed := 4357 + +/-- 64 bit Mersenne Twister (MT19937-64) configuration. -/ +def mt19937_64 : Config where + wordSize := 64 + stateSize := 312 + shiftSize := 156 + maskBits := 31 + xorMask := 0xb5026f5aa96619e9 + temperingShifts := (29, 17, 37, 43) + temperingMasks := (0x5555555555555555, 0x71d67fffeda60000, 0xfff7eee000000000) + initMult := 6364136223846793005 + initSeed := 19650218 From 616ae867278ba70cc2581b1563f0526b7db229ef Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Sat, 12 Oct 2024 19:50:59 -0400 Subject: [PATCH 2/8] refactor: state update step --- Batteries/Data/Random/MersenneTwister.lean | 27 +++++++++++----------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/Batteries/Data/Random/MersenneTwister.lean b/Batteries/Data/Random/MersenneTwister.lean index 0425367a98..b63c05b99e 100644 --- a/Batteries/Data/Random/MersenneTwister.lean +++ b/Batteries/Data/Random/MersenneTwister.lean @@ -72,7 +72,7 @@ structure State (cfg : Config) where /-- Mersenne Twister initialization given an optional seed. -/ @[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config) (seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg := - ⟨loop seed #[] (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩ + ⟨loop seed (.mkEmpty cfg.stateSize) (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩ where /-- Inner loop for Mersenne Twister initalization. -/ loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) := @@ -81,24 +81,23 @@ where let w := cfg.initMult * (w ^^^ (w >>> cfg.wordSize - 2)) + v.size loop w v (by simp only [v, Array.size_push]; omega) +/-- Apply the twisting transformation to the given state. -/ +@[specialize cfg] protected def State.twist (state : State cfg) : State cfg := + let i := state.index + let i' : Fin cfg.stateSize := + if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩ + let y := state.data[i] &&& cfg.uMask ||| state.data[i'] &&& cfg.lMask + let x := state.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1 + ⟨state.data.set i x, i'⟩ + /-- Update the state by a number of generation steps (default 1). -/ -@[specialize cfg] protected def State.update (state : State cfg) (steps := 1) : State cfg := - loop state steps -where - /-- Inner loop for Mersenne Twister update. -/ - @[inline] loop (s : State cfg) (c : Nat) : State cfg := - if c = 0 then s else - let i := s.index - let i' : Fin cfg.stateSize := - if h : i.val+1 < cfg.stateSize then ⟨i.val+1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩ - let y := s.data[i] &&& cfg.uMask ||| s.data[i'] &&& cfg.lMask - let x := s.data[i+cfg.shiftSize] ^^^ bif y[0] then y >>> 1 ^^^ cfg.xorMask else y >>> 1 - loop ⟨s.data.set i x, i'⟩ (c-1) +@[inline] protected def State.update (state : State cfg) (steps := 1) : State cfg := + if steps = 0 then state else state.twist.update (steps-1) /-- Mersenne Twister iteration. -/ @[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg := let i := state.index - let s := state.update + let s := state.twist (temper s.data[i], s) where /-- Tempering step for Mersenne Twister. -/ From 2ff5dea59c6c5c4228195330b047bfce98888d01 Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Sun, 13 Oct 2024 09:08:31 -0400 Subject: [PATCH 3/8] fix: use match --- Batteries/Data/Random/MersenneTwister.lean | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/Batteries/Data/Random/MersenneTwister.lean b/Batteries/Data/Random/MersenneTwister.lean index b63c05b99e..f66fe7b5cd 100644 --- a/Batteries/Data/Random/MersenneTwister.lean +++ b/Batteries/Data/Random/MersenneTwister.lean @@ -91,8 +91,10 @@ where ⟨state.data.set i x, i'⟩ /-- Update the state by a number of generation steps (default 1). -/ -@[inline] protected def State.update (state : State cfg) (steps := 1) : State cfg := - if steps = 0 then state else state.twist.update (steps-1) +-- TODO: optimize to `O(log(steps))` using the minimal polynomial +protected def State.update (state : State cfg) : (steps : Nat := 1) → State cfg + | 0 => state + | steps+1 => state.twist.update steps /-- Mersenne Twister iteration. -/ @[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg := @@ -112,7 +114,9 @@ where instance (cfg) : RandomGen (State cfg) where range _ := (0, 2 ^ cfg.wordSize - 1) next s := match s.next with | (r, s) => (r.toNat, s) - split s := let (a, s) := s.next; (s, cfg.init a) + split s := + -- TODO: use `(s, s.update (2 ^ 128))` once `update` is optimized. + let (a, s) := s.next; (s, cfg.init a) instance (cfg) : Stream (State cfg) (BitVec cfg.wordSize) where next? s := s.next From 4d65a4cd56ecc95af9a82fe0a25adbf5386b5ec8 Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Sun, 13 Oct 2024 12:12:50 -0400 Subject: [PATCH 4/8] chore: add test --- test/mersenne_twister.lean | 94 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 test/mersenne_twister.lean diff --git a/test/mersenne_twister.lean b/test/mersenne_twister.lean new file mode 100644 index 0000000000..133c9e488a --- /dev/null +++ b/test/mersenne_twister.lean @@ -0,0 +1,94 @@ +import Batteries.Data.Random.MersenneTwister + +open Batteries.Random.MersenneTwister + +/- TODO: move somewhere else... -/ +def Stream.take [Stream σ α] (s : σ) (n : Nat) : Array α × σ := + loop s (.mkEmpty n) (Nat.zero_le _) +where + @[inline] loop (s : σ) (acc : Array α) (h : acc.size ≤ n) := + if heq : acc.size = n then (acc, s) else + match Stream.next? s with + | none => (acc, s) + | some (v, s) => loop s (acc.push v) (by simp only [Array.size_push]; omega) + +#guard (Stream.take mt19937.init 5).1 == #[874448474, 2424656266, 2174085406, 1265871120, 3155244894] + +/- Sample output was generated using `numpy`'s implementation of MT19937: +```python +from numpy import array, uint32 +from numpy.random import MT19937 + +mt = MT19937() +mt.state = { + 'bit_generator' : 'MT19937', + 'state' : { + 'pos' : 624, + 'key' : array([ + 4357, 1673174024, 1301878288, 1129097449, 2180885271, 2495295730, 3729202114, 3451529139, 2624228201, 696045212, + 2296245684, 4097888573, 2110311931, 1672374534, 381896678, 2887874951, 3859861197, 420983856, 1691952728, 4233606289, + 1707944415, 3515687962, 4265198858, 1433261659, 1131854641, 228846788, 3811811324, 873525989, 588291779, 2854617646, + 948269870, 3798261295, 3422826645, 340138072, 3671734944, 3961007161, 2839350439, 3264455490, 310719058, 2570596611, + 3750039289, 648992492, 3816674884, 2210726029, 371217291, 196912982, 3046892150, 470118103, 1302935133, 362465408, + 1360220904, 2946174945, 1630294895, 3570642538, 1798333338, 1196832683, 226789057, 2740096276, 1062441100, 1875507765, + 2599873619, 1037523070, 4029519294, 3231722367, 2232344613, 3458909352, 2906353456, 3064815497, 3166305847, + 3658630546, 3632421090, 885320275, 1621369481, 1258557244, 2827734740, 3209486301, 131295515, 2191201702, 44141830, + 1183978535, 4202966509, 801836240, 2303299448, 333191985, 4114943231, 1490315450, 453120554, 759253243, 1381163601, + 3455606116, 1027445020, 1144697221, 3040135651, 4176273102, 798935118, 49817807, 2492997557, 3171983608, 2742334400, + 1282687705, 1047297991, 3697219554, 1400278898, 3276297123, 843040281, 354711436, 4156544868, 2873126701, 3990490795, + 3966874614, 1376536470, 4189022583, 2283386237, 3645931808, 1312021512, 679663233, 3054458511, 1152865034, 1927729338, + 538380875, 374984161, 2453495220, 514433452, 1271601365, 3737270131, 630101278, 1292962526, 2908018207, 1209528133, + 413117768, 3762161744, 2194986537, 1414304087, 379722290, 2862208514, 3551161587, 3402627497, 2411204572, 3033657332, + 4161252989, 2267825211, 963150406, 2081690150, 4014304967, 1977732365, 2412979568, 613038232, 418857425, 3682807839, + 3416550746, 3692470090, 2764012443, 3255912817, 2160692740, 3914318396, 3437441061, 2828481795, 3655629678, 582770030, + 2946380655, 3506851541, 612362648, 3394202848, 1530337657, 3360830183, 570641538, 153365650, 1624454723, 80526649, + 1365694508, 2272925828, 34250189, 3066169803, 631734422, 3706776758, 3443270679, 659846301, 3707435456, 3573851432, + 1017208097, 1100519855, 1824765866, 3284762074, 2887949547, 569464065, 3057970772, 1726477004, 3119183733, 3349922451, + 4162228670, 249085950, 3854319807, 1155219045, 811161064, 207675760, 50531529, 141911159, 3819613906, 2655884066, + 3517624211, 514724041, 2094583932, 3681571092, 3518053661, 2207473499, 961982182, 1423628102, 628853095, 3823741997, + 1450180112, 1817911736, 384378993, 1749521215, 4080873978, 2604100714, 2468900411, 1718743185, 3679944356, 623522652, + 2974445253, 351789091, 776787982, 4087231118, 395771407, 2634989045, 2547249720, 2502583808, 3550523417, 648947207, + 2361409826, 2639137202, 4179155171, 3136025689, 3233151180, 3765213604, 459508845, 412632299, 3365801270, 1208603094, + 1978375863, 3608769469, 2648322656, 994422344, 1463198657, 1938300111, 1983437898, 3617090298, 582545291, 604707873, + 615071476, 1976468460, 4251555349, 2373160371, 4138683998, 927249694, 4178996063, 3071856005, 3264724616, 2539911824, + 1383596905, 3639900055, 2590770034, 1029541954, 369472051, 3757991913, 1470517532, 2317808180, 1065978813, 3301489275, + 4087716742, 2662718566, 678716423, 274451277, 1625396912, 3598469848, 3639725841, 726808159, 1490990746, 4062476682, + 2411471067, 1395972017, 1390554948, 1854727292, 2494590309, 1377225539, 2540041390, 3288614830, 706906287, 1416719637, + 609008344, 2311429920, 821102265, 2034260263, 3587569090, 3115591378, 3545840515, 4166871929, 139581804, 2421643972, + 1250638605, 4212965387, 2794805718, 3306616566, 2466109783, 2200482525, 1496197888, 381089640, 2743249505, 4221427695, + 1247199466, 1746114586, 2065302059, 1348936513, 2997505940, 3911013644, 428274869, 2816055507, 580438782, 135588414, + 916674047, 445684901, 1016784680, 654791600, 1282652681, 92916407, 1411782674, 1367985506, 1207661779, 3531669257, + 627085756, 1857409876, 4107311709, 1384928667, 2576697382, 2875531654, 4151312039, 116927085, 1281879888, 414036984, + 3931190705, 4100135295, 1170799418, 3130902186, 4055536507, 3692691153, 480878564, 2201474460, 3663014917, 4155766371, + 1987039566, 4121861326, 2525025103, 2465094709, 2536129400, 1843468352, 2926058841, 533253191, 1988389474, 1209435122, + 4141112867, 2699109017, 2373614092, 1694129124, 2730600877, 2249161515, 1355638390, 3319290902, 2209534967, + 1463955965, 204923808, 1025015944, 214266113, 3382305551, 2455594378, 1861944634, 1820710091, 449145441, 4119339060, + 2660525612, 3515028309, 3466454003, 1024657310, 50945886, 2913140895, 721595333, 3416444872, 2701847760, 2352361641, + 234184151, 3927502002, 3834792578, 3469473651, 4193637929, 2873594460, 1994191988, 1690724605, 1956524219, 476427462, + 212379302, 1370380615, 327076237, 1984104432, 682581272, 2521259089, 3543809183, 3275489242, 241390538, 3496199707, + 2497799665, 770560132, 1626015420, 2776148645, 3717161347, 3970592238, 710750702, 3421625839, 876972885, 2108460056, + 1195168096, 1195766777, 3121053543, 2819333890, 1916084498, 717897923, 3627489721, 1970264748, 1813355780, 4148615245, + 556824139, 411448086, 4228776246, 1732939415, 3206934813, 1949588544, 3291105704, 1044314017, 222045743, 3079457322, + 638497370, 1849452395, 921039233, 1115861204, 3019093836, 2828923381, 4185943827, 3344827454, 3923907710, 760572735, + 3828284133, 1559197800, 724485616, 1828677449, 2985767159, 4119101778, 1077348258, 3518446099, 2585587017, 1855673084, + 3495712148, 3265984413, 2998815707, 760668518, 2487249862, 3060757479, 3249514669, 4222804112, 1010910776, 3893641969, + 395812799, 2591540346, 1194664170, 49789115, 1363873041, 1005502756, 1164343260, 3646613829, 459869347, 3679832718, + 1137706766, 4189431951, 1412889205, 622040248, 1536739968, 3066727065, 666661511, 1672188834, 2714762802, 4135248739, + 35606745, 2775710540, 4083752484, 3680159469, 1950331243, 251641782, 1501029974, 486869303, 1720971325, 241603808, + 28070600, 2737782337, 910469455, 3810848458, 118398842, 3078470155, 2559096993, 2933522804, 2264615020, 3793195157, + 1614887475, 45727966, 3193899422, 1157273055, 2178255365, 2646663432, 724754192, 168779241, 4048503831, 3483948530, + 3996648642, 939343027, 917914729, 3030111132, 3908302516, 29247037, 3568084731, 1034472966, 1408004326, 1693666951, + 3712665549, 3120003376, 3374542680, 2868373905, 1362838239, 1421625626, 4275252746, 548825947, 622261297, 3152835012, + 2926192892, 423356389, 151058371, 3820087086, 1673993262, 252457775, 1317185941, 2594135384, 817169312, 2016796985, + 2292688295, 1654933570, 2158435154, 2703640067, 3260663801, 3267419116, 2293555012, 2721936781, 1727868043, 91884630, + 265685878, 1143096279, 961294173, 403541376, 2338233320, 1725318369, 4101205103, 4268086122, 3418016922, 1065995435, + 1936572353, 265163284, 3043694988, 2167402293, 2057323859, 4033232254, 3258990270, 1137868927, 2142656805, 4216785320, + 1188509744, 1051071625, 196974391, 2445666962, 3092595170, 2833121107, 2474761097, 2190021692, 1852037076, 3577763037, + 3794354715, 2124118694, 2641147398, 1551493415, 1913661165, 1313919440, 2232801400, 1781682225, 1340417535, 994676154, + 251493162, 2162155003, 1678056273, 3810976356, 1505106460, 3361449605, 1041703651, 1727972302, 3959583054, 3140845007, + 3202914485, 2878334456, 2354150592, 3334993881, 1015617735, 506838242, 4168775794, 839674019, 4238769945, 849116300, + 4189642852, 1596908589, 556328875, 2369067254, 2431152278, 1004682871], dtype=uint32)}} + +print(mt.random_raw(5)) +``` +-/ From 389e79aacbf9549f66912b024dd83240e4df94e0 Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Sun, 13 Oct 2024 12:37:33 -0400 Subject: [PATCH 5/8] feat: add `drop` and `take` for streams --- Batteries.lean | 1 + Batteries/Data/Stream.lean | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+) create mode 100644 Batteries/Data/Stream.lean diff --git a/Batteries.lean b/Batteries.lean index f04a7a1f76..68b568c189 100644 --- a/Batteries.lean +++ b/Batteries.lean @@ -31,6 +31,7 @@ import Batteries.Data.PairingHeap import Batteries.Data.RBMap import Batteries.Data.Range import Batteries.Data.Rat +import Batteries.Data.Stream import Batteries.Data.String import Batteries.Data.Sum import Batteries.Data.UInt diff --git a/Batteries/Data/Stream.lean b/Batteries/Data/Stream.lean new file mode 100644 index 0000000000..3b0197a810 --- /dev/null +++ b/Batteries/Data/Stream.lean @@ -0,0 +1,23 @@ +/- +Copyright (c) 2024 François G. Dorais. All rights reserved. +Released under Apache 2. license as described in the file LICENSE. +Authors: François G. Dorais +-/ + +/-- Drop up to `n` values from the stream `s`. -/ +def Stream.drop [Stream σ α] (s : σ) : Nat → σ + | 0 => s + | n+1 => match Stream.next? s with + | none => s + | some (_, s) => drop s n + +/-- Read up to `n` values from the stream `s`. -/ +def Stream.take [Stream σ α] (s : σ) (n : Nat) : Array α × σ := + loop s (.mkEmpty n) n +where + /-- Inner loop for `Stream.take`. -/ + loop (s : σ) (acc : Array α) + | 0 => (acc, s) + | n+1 => match Stream.next? s with + | none => (acc, s) + | some (v, s) => loop s (acc.push v) n From f7e84dbba05c9ad3e2420763abebdcb6b1d336a3 Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Sun, 13 Oct 2024 23:15:18 -0400 Subject: [PATCH 6/8] chore: update test --- test/mersenne_twister.lean | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/test/mersenne_twister.lean b/test/mersenne_twister.lean index 133c9e488a..645171912d 100644 --- a/test/mersenne_twister.lean +++ b/test/mersenne_twister.lean @@ -1,17 +1,8 @@ import Batteries.Data.Random.MersenneTwister +import Batteries.Data.Stream open Batteries.Random.MersenneTwister -/- TODO: move somewhere else... -/ -def Stream.take [Stream σ α] (s : σ) (n : Nat) : Array α × σ := - loop s (.mkEmpty n) (Nat.zero_le _) -where - @[inline] loop (s : σ) (acc : Array α) (h : acc.size ≤ n) := - if heq : acc.size = n then (acc, s) else - match Stream.next? s with - | none => (acc, s) - | some (v, s) => loop s (acc.push v) (by simp only [Array.size_push]; omega) - #guard (Stream.take mt19937.init 5).1 == #[874448474, 2424656266, 2174085406, 1265871120, 3155244894] /- Sample output was generated using `numpy`'s implementation of MT19937: From d5e11d66ea35cddeb25e9b87b428066ab529f27b Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Sun, 13 Oct 2024 23:39:36 -0400 Subject: [PATCH 7/8] chore: improve docs --- Batteries/Data/Random/MersenneTwister.lean | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/Batteries/Data/Random/MersenneTwister.lean b/Batteries/Data/Random/MersenneTwister.lean index f66fe7b5cd..86d1c64a7f 100644 --- a/Batteries/Data/Random/MersenneTwister.lean +++ b/Batteries/Data/Random/MersenneTwister.lean @@ -7,7 +7,26 @@ import Batteries.Data.Vector /-! # Mersenne Twister -Reference implementation for the Mersenne Twister pseudorandom number generator. +Generic implementation for the Mersenne Twister pseudorandom number generator. + +All choices of parameters from Matsumoto and Nishimura (1998) are supported, along with later +refinements. Parameters for the standard 32-bit MT19937 and 64-bit MT19937-64 algorithms are +provided. Both `RandomGen` and `Stream` interfaces are provided. + +Use `mt19937.init seed` to create a MT19937 PRNG with a 32 bit seed value; use +`mt19937_64.init seed` to create a MT19937-64 PRNG with a 64 bit seed value. If omitted, default +seed choices will be used. + +Sample usage: +``` +import Batteries.Data.Random.MersenneTwister + +open Batteries.Random.MersenneTwister + +def mtgen := mt19937.init -- default seed 4357 + +#eval (Stream.take mtgen 5).fst -- #[874448474, 2424656266, 2174085406, 1265871120, 3155244894] +``` ### References: From bac9ad686eebf956834773bf5829d7a6fb79523c Mon Sep 17 00:00:00 2001 From: "F. G. Dorais" Date: Sun, 22 Dec 2024 10:27:58 -0500 Subject: [PATCH 8/8] fix: remove instance --- Batteries/Data/Random/MersenneTwister.lean | 7 ------- 1 file changed, 7 deletions(-) diff --git a/Batteries/Data/Random/MersenneTwister.lean b/Batteries/Data/Random/MersenneTwister.lean index 817d92eabe..70a6f86042 100644 --- a/Batteries/Data/Random/MersenneTwister.lean +++ b/Batteries/Data/Random/MersenneTwister.lean @@ -130,13 +130,6 @@ where let x := x ^^^ x <<< t &&& c x ^^^ x >>> l -instance (cfg) : RandomGen (State cfg) where - range _ := (0, 2 ^ cfg.wordSize - 1) - next s := match s.next with | (r, s) => (r.toNat, s) - split s := - -- TODO: use `(s, s.update (2 ^ 128))` once `update` is optimized. - let (a, s) := s.next; (s, cfg.init a) - instance (cfg) : Stream (State cfg) (BitVec cfg.wordSize) where next? s := s.next