Skip to content

Commit b63c831

Browse files
committed
Add split and key to nnx.Rngs
1 parent 32d5374 commit b63c831

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

flax/nnx/rnglib.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,9 @@ def __getattr__(self, name: str):
415415
def __call__(self):
416416
return self.default()
417417

418+
def key(self):
419+
return self.default()
420+
418421
def __iter__(self) -> tp.Iterator[str]:
419422
for name, stream in vars(self).items():
420423
if isinstance(stream, RngStream):
@@ -433,6 +436,9 @@ def items(self):
433436
if isinstance(stream, RngStream):
434437
yield name, stream
435438

439+
def split(self, splits: int):
440+
return self.fork(split=splits)
441+
436442
def fork(
437443
self,
438444
/,

0 commit comments

Comments
 (0)