Skip to content

Commit 5408c40

Browse files
author
Flax Authors
committed
Merge pull request #4407 from google:rnn-broadcast-rngs
PiperOrigin-RevId: 713694016
2 parents (e2134af + b9f016a); commit 5408c40

File tree

3 files changed: +662 / -538 lines changed

flax/nnx/nn/recurrent.py

Lines changed: 45 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,8 @@
1414

1515
"""RNN modules for Flax."""
1616

17-
from typing import (
18-
Any,
19-
TypeVar
20-
)
17+
from typing import Any, TypeVar
18+
from collections.abc import Mapping
2119
from collections.abc import Callable
2220
from functools import partial
2321
from typing_extensions import Protocol
@@ -27,13 +25,13 @@
2725
import jax.numpy as jnp
2826

2927
from flax import nnx
30-
from flax.nnx import rnglib
28+
from flax.nnx import filterlib, rnglib
3129
from flax.nnx.module import Module
3230
from flax.nnx.nn import initializers
3331
from flax.nnx.nn.linear import Linear
3432
from flax.nnx.nn.activations import sigmoid
3533
from flax.nnx.nn.activations import tanh
36-
from flax.nnx.transforms.iteration import Carry
34+
from flax.nnx.transforms.iteration import Carry, StateAxes
3735
from flax.typing import (
3836
Dtype,
3937
Initializer,
@@ -593,15 +591,19 @@ class RNN(Module):
593591
using :func:`flax.nnx.scan`.
594592
"""
595593

594+
state_axes: Mapping[str, int | type[Carry] | None]
595+
596596
def __init__(
597-
self,
598-
cell: RNNCellBase,
599-
time_major: bool = False,
600-
return_carry: bool = False,
601-
reverse: bool = False,
602-
keep_order: bool = False,
603-
unroll: int = 1,
604-
rngs: rnglib.Rngs | None = None,
597+
self,
598+
cell: RNNCellBase,
599+
time_major: bool = False,
600+
return_carry: bool = False,
601+
reverse: bool = False,
602+
keep_order: bool = False,
603+
unroll: int = 1,
604+
rngs: rnglib.Rngs | None = None,
605+
state_axes: Mapping[str, int | type[Carry] | None] | None = None,
606+
broadcast_rngs: filterlib.Filter = None,
605607
):
606608
self.cell = cell
607609
self.time_major = time_major
@@ -612,19 +614,21 @@ def __init__(
612614
if rngs is None:
613615
rngs = rnglib.Rngs(0)
614616
self.rngs = rngs
617+
self.state_axes = state_axes or {...: Carry} # type: ignore
618+
self.broadcast_rngs = broadcast_rngs
615619

616620
def __call__(
617-
self,
618-
inputs: Array,
619-
*,
620-
initial_carry: Carry | None = None,
621-
seq_lengths: Array | None = None,
622-
return_carry: bool | None = None,
623-
time_major: bool | None = None,
624-
reverse: bool | None = None,
625-
keep_order: bool | None = None,
626-
rngs: rnglib.Rngs | None = None,
627-
):
621+
self,
622+
inputs: Array,
623+
*,
624+
initial_carry: Carry | None = None,
625+
seq_lengths: Array | None = None,
626+
return_carry: bool | None = None,
627+
time_major: bool | None = None,
628+
reverse: bool | None = None,
629+
keep_order: bool | None = None,
630+
rngs: rnglib.Rngs | None = None,
631+
):
628632
if return_carry is None:
629633
return_carry = self.return_carry
630634
if time_major is None:
@@ -670,20 +674,26 @@ def __call__(
670674
)
671675

672676
slice_carry = seq_lengths is not None and return_carry
673-
674-
def scan_fn(cell: RNNCellBase, carry: Carry, x: Array) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]:
677+
broadcast_rngs = nnx.All(nnx.RngState, self.broadcast_rngs)
678+
state_axes = StateAxes({broadcast_rngs: None, **self.state_axes}) # type: ignore
679+
680+
# we use split_rngs with splits=1 and squeeze=True to get unique rngs
681+
# every time RNN is called
682+
@nnx.split_rngs(splits=1, only=self.broadcast_rngs, squeeze=True)
683+
@nnx.scan(
684+
in_axes=(state_axes, Carry, time_axis),
685+
out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis),
686+
unroll=self.unroll,
687+
)
688+
def scan_fn(
689+
cell: RNNCellBase, carry: Carry, x: Array
690+
) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]:
675691
carry, y = cell(carry, x)
676692
if slice_carry:
677693
return carry, (carry, y)
678694
return carry, y
679-
state_axes = nnx.StateAxes({...: Carry}) # type: ignore[arg-type]
680-
scan = nnx.scan(
681-
scan_fn,
682-
in_axes=(state_axes, Carry, time_axis),
683-
out_axes=(Carry, (0, time_axis)) if slice_carry else (Carry, time_axis),
684-
unroll=self.unroll,
685-
)
686-
scan_output = scan(self.cell, carry, inputs)
695+
696+
scan_output = scan_fn(self.cell, carry, inputs)
687697

688698
# Next we select the final carry. If a segmentation mask was provided and
689699
# return_carry is True we slice the carry history and select the last valid

flax/nnx/rnglib.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,19 +302,22 @@ def split_rngs(
302302
*,
303303
splits: int | tuple[int, ...],
304304
only: filterlib.Filter = ...,
305+
squeeze: bool = False,
305306
) -> SplitBackups: ...
306307
@tp.overload
307308
def split_rngs(
308309
*,
309310
splits: int | tuple[int, ...],
310311
only: filterlib.Filter = ...,
312+
squeeze: bool = False,
311313
) -> tp.Callable[[F], F]: ...
312314
def split_rngs(
313315
node: tp.Any = MISSING,
314316
/,
315317
*,
316318
splits: int | tuple[int, ...],
317319
only: filterlib.Filter = ...,
320+
squeeze: bool = False,
318321
) -> SplitBackups | tp.Callable[[F], F]:
319322
"""Splits the (nested) Rng states of the given node.
320323
@@ -412,13 +415,18 @@ def split_rngs(
412415
def split_rngs_decorator(f: F) -> F:
413416
@functools.wraps(f)
414417
def split_rngs_wrapper(*args, **kwargs):
415-
with split_rngs((args, kwargs), splits=splits, only=only):
418+
with split_rngs(
419+
(args, kwargs), splits=splits, only=only, squeeze=squeeze
420+
):
416421
return f(*args, **kwargs)
417422

418423
return tp.cast(F, split_rngs_wrapper)
419424

420425
return split_rngs_decorator # type: ignore[bad-return-type]
421426

427+
if squeeze and splits != 1:
428+
raise ValueError('squeeze=True is only supported for splits=1')
429+
422430
predicate = filterlib.to_predicate(only)
423431
backups: list[StreamBackup] = []
424432
for path, stream in graph.iter_graph(node):
@@ -429,8 +437,13 @@ def split_rngs_wrapper(*args, **kwargs):
429437
):
430438
key = stream()
431439
backups.append((stream, stream.key.value, stream.count.value))
432-
stream.key.value = jax.random.split(key, splits)
433-
if isinstance(splits, int):
440+
key = jax.random.split(key, splits)
441+
if squeeze:
442+
key = key[0]
443+
stream.key.value = key
444+
if squeeze:
445+
counts_shape = stream.count.shape
446+
elif isinstance(splits, int):
434447
counts_shape = (splits, *stream.count.shape)
435448
else:
436449
counts_shape = (*splits, *stream.count.shape)

0 commit comments

Comments
 (0)