14
14
15
15
"""RNN modules for Flax."""
16
16
17
- from typing import (
18
- Any ,
19
- TypeVar
20
- )
17
+ from typing import Any , TypeVar
18
+ from collections .abc import Mapping
21
19
from collections .abc import Callable
22
20
from functools import partial
23
21
from typing_extensions import Protocol
27
25
import jax .numpy as jnp
28
26
29
27
from flax import nnx
30
- from flax .nnx import rnglib
28
+ from flax .nnx import filterlib , rnglib
31
29
from flax .nnx .module import Module
32
30
from flax .nnx .nn import initializers
33
31
from flax .nnx .nn .linear import Linear
34
32
from flax .nnx .nn .activations import sigmoid
35
33
from flax .nnx .nn .activations import tanh
36
- from flax .nnx .transforms .iteration import Carry
34
+ from flax .nnx .transforms .iteration import Carry , StateAxes
37
35
from flax .typing import (
38
36
Dtype ,
39
37
Initializer ,
@@ -593,15 +591,19 @@ class RNN(Module):
593
591
using :func:`flax.nnx.scan`.
594
592
"""
595
593
594
+ state_axes : Mapping [str , int | type [Carry ] | None ]
595
+
596
596
def __init__ (
597
- self ,
598
- cell : RNNCellBase ,
599
- time_major : bool = False ,
600
- return_carry : bool = False ,
601
- reverse : bool = False ,
602
- keep_order : bool = False ,
603
- unroll : int = 1 ,
604
- rngs : rnglib .Rngs | None = None ,
597
+ self ,
598
+ cell : RNNCellBase ,
599
+ time_major : bool = False ,
600
+ return_carry : bool = False ,
601
+ reverse : bool = False ,
602
+ keep_order : bool = False ,
603
+ unroll : int = 1 ,
604
+ rngs : rnglib .Rngs | None = None ,
605
+ state_axes : Mapping [str , int | type [Carry ] | None ] | None = None ,
606
+ broadcast_rngs : filterlib .Filter = None ,
605
607
):
606
608
self .cell = cell
607
609
self .time_major = time_major
@@ -612,19 +614,21 @@ def __init__(
612
614
if rngs is None :
613
615
rngs = rnglib .Rngs (0 )
614
616
self .rngs = rngs
617
+ self .state_axes = state_axes or {...: Carry } # type: ignore
618
+ self .broadcast_rngs = broadcast_rngs
615
619
616
620
def __call__ (
617
- self ,
618
- inputs : Array ,
619
- * ,
620
- initial_carry : Carry | None = None ,
621
- seq_lengths : Array | None = None ,
622
- return_carry : bool | None = None ,
623
- time_major : bool | None = None ,
624
- reverse : bool | None = None ,
625
- keep_order : bool | None = None ,
626
- rngs : rnglib .Rngs | None = None ,
627
- ):
621
+ self ,
622
+ inputs : Array ,
623
+ * ,
624
+ initial_carry : Carry | None = None ,
625
+ seq_lengths : Array | None = None ,
626
+ return_carry : bool | None = None ,
627
+ time_major : bool | None = None ,
628
+ reverse : bool | None = None ,
629
+ keep_order : bool | None = None ,
630
+ rngs : rnglib .Rngs | None = None ,
631
+ ):
628
632
if return_carry is None :
629
633
return_carry = self .return_carry
630
634
if time_major is None :
@@ -670,20 +674,26 @@ def __call__(
670
674
)
671
675
672
676
slice_carry = seq_lengths is not None and return_carry
673
-
674
- def scan_fn (cell : RNNCellBase , carry : Carry , x : Array ) -> tuple [Carry , Array ] | tuple [Carry , tuple [Carry , Array ]]:
677
+ broadcast_rngs = nnx .All (nnx .RngState , self .broadcast_rngs )
678
+ state_axes = StateAxes ({broadcast_rngs : None , ** self .state_axes }) # type: ignore
679
+
680
+ # we use split_rngs with splits=1 and squeeze=True to get unique rngs
681
+ # every time RNN is called
682
+ @nnx .split_rngs (splits = 1 , only = self .broadcast_rngs , squeeze = True )
683
+ @nnx .scan (
684
+ in_axes = (state_axes , Carry , time_axis ),
685
+ out_axes = (Carry , (0 , time_axis )) if slice_carry else (Carry , time_axis ),
686
+ unroll = self .unroll ,
687
+ )
688
+ def scan_fn (
689
+ cell : RNNCellBase , carry : Carry , x : Array
690
+ ) -> tuple [Carry , Array ] | tuple [Carry , tuple [Carry , Array ]]:
675
691
carry , y = cell (carry , x )
676
692
if slice_carry :
677
693
return carry , (carry , y )
678
694
return carry , y
679
- state_axes = nnx .StateAxes ({...: Carry }) # type: ignore[arg-type]
680
- scan = nnx .scan (
681
- scan_fn ,
682
- in_axes = (state_axes , Carry , time_axis ),
683
- out_axes = (Carry , (0 , time_axis )) if slice_carry else (Carry , time_axis ),
684
- unroll = self .unroll ,
685
- )
686
- scan_output = scan (self .cell , carry , inputs )
695
+
696
+ scan_output = scan_fn (self .cell , carry , inputs )
687
697
688
698
# Next we select the final carry. If a segmentation mask was provided and
689
699
# return_carry is True we slice the carry history and select the last valid
0 commit comments