13
13
# limitations under the License.
14
14
15
15
"""Wrapper around jax.lax.scan with in_axes/out_axes API."""
16
+ from collections .abc import Callable
16
17
import functools
17
18
from typing import Any , Optional
18
- from collections .abc import Callable
19
19
20
20
import jax
21
- import jax .numpy as jnp
22
- import numpy as np
23
- from jax import core , lax
21
+ from jax import core
22
+ from jax import lax
24
23
from jax .extend import linear_util as lu
25
24
from jax .interpreters import partial_eval as pe
25
+ import jax .numpy as jnp
26
+ import numpy as np
26
27
27
28
ScanAxis = Optional [int ]
28
29
@@ -35,13 +36,14 @@ class _Broadcast:
35
36
36
37
37
38
def scan (
38
- fn : Callable [..., Any ],
39
- in_axes : Any ,
40
- out_axes : Any ,
41
- length : int | None = None ,
42
- reverse : bool = False ,
43
- unroll : int = 1 ,
44
- _split_transpose : bool = False
39
+ fn : Callable [..., Any ],
40
+ in_axes : Any ,
41
+ out_axes : Any ,
42
+ length : int | None = None ,
43
+ reverse : bool = False ,
44
+ unroll : int = 1 ,
45
+ _split_transpose : bool = False ,
46
+ check_constancy_invariants : bool = True ,
45
47
):
46
48
"""A wrapper around `jax.lax.scan` with in_axes/out_axes api.
47
49
@@ -78,6 +80,11 @@ def body_fn(b, c, x):
78
80
iteration of a loop (default: 1).
79
81
_split_transpose: An experimental feature to split the transpose of scan
80
82
into a scan and a map, backed by an experimental Jax lax.scan() feature.
83
+ check_constancy_invariants: If True, the scan verifies that the
84
+ broadcast constants are true loop invariants and additionally supports
85
+ broadcast (non-carry) function outputs. This requires an extra JAX
86
+ tracing step, so setting it to False can reduce trace time on larger
87
+ models.
81
88
Returns:
82
89
the function that performs the scan of the form:
83
90
(broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out).
@@ -114,39 +121,43 @@ def trans(x):
114
121
return jax .tree_util .tree_map (trans , xs )
115
122
116
123
def scan_fn (broadcast_in , init , * args ):
124
+ # Requires one extra tracing operation to test invariants:
125
+ # Verifies that broadcast constants are true loop invariants, and further
126
+ # supports broadcast function (non-carry) outputs.
127
+
117
128
xs = jax .tree_util .tree_map (transpose_to_front , in_axes , args )
118
129
119
130
def body_fn (c , xs , init_mode = False ):
120
131
# inject constants
121
132
xs = jax .tree_util .tree_map (
122
- lambda ax , arg , x : (arg if ax is broadcast else x ), in_axes , args , xs
133
+ lambda ax , arg , x : (arg if ax is broadcast else x ), in_axes , args , xs
123
134
)
124
135
broadcast_out , c , ys = fn (broadcast_in , c , * xs )
125
136
126
137
if init_mode :
127
138
ys = jax .tree_util .tree_map (
128
- lambda ax , y : (y if ax is broadcast else ()), out_axes , ys
139
+ lambda ax , y : (y if ax is broadcast else ()), out_axes , ys
129
140
)
130
141
return broadcast_out , ys
131
142
else :
132
143
ys = jax .tree_util .tree_map (
133
- lambda ax , y : (() if ax is broadcast else y ), out_axes , ys
144
+ lambda ax , y : (() if ax is broadcast else y ), out_axes , ys
134
145
)
135
146
return c , ys
136
147
137
148
broadcast_body = functools .partial (body_fn , init_mode = True )
138
149
139
150
carry_avals = jax .tree_util .tree_map (
140
- lambda x : core .ShapedArray (jnp .shape (x ), jnp .result_type (x )), init
151
+ lambda x : core .ShapedArray (jnp .shape (x ), jnp .result_type (x )), init
141
152
)
142
153
scan_avals = jax .tree_util .tree_map (
143
- lambda x : core .ShapedArray (jnp .shape (x )[1 :], jnp .result_type (x )), xs
154
+ lambda x : core .ShapedArray (jnp .shape (x )[1 :], jnp .result_type (x )), xs
144
155
)
145
156
input_avals = (carry_avals , scan_avals )
146
157
147
158
in_avals , in_tree = jax .tree_util .tree_flatten (input_avals )
148
159
f_flat , out_tree = jax .api_util .flatten_fun_nokwargs (
149
- lu .wrap_init (broadcast_body ), in_tree
160
+ lu .wrap_init (broadcast_body ), in_tree
150
161
)
151
162
in_pvals = list (map (pe .PartialVal .unknown , in_avals ))
152
163
_ , out_pvals , _ = pe .trace_to_jaxpr_nounits (f_flat , in_pvals )
@@ -155,29 +166,63 @@ def body_fn(c, xs, init_mode=False):
155
166
for pv , const in out_pvals :
156
167
if pv is not None :
157
168
raise ValueError (
158
- 'broadcasted variable has a data dependency on the scan body.'
169
+ 'broadcasted variable has a data dependency on the scan body.'
159
170
)
160
171
out_flat .append (const )
161
172
broadcast_in , constants_out = jax .tree_util .tree_unflatten (
162
- out_tree (), out_flat
173
+ out_tree (), out_flat
163
174
)
164
175
165
176
if jax .version .__version_info__ > (0 , 4 , 25 ):
166
177
c , ys = lax .scan (
167
- body_fn , init , xs , length = length , reverse = reverse , unroll = unroll ,
168
- _split_transpose = _split_transpose
178
+ body_fn , init , xs , length = length , reverse = reverse , unroll = unroll ,
179
+ _split_transpose = _split_transpose
169
180
)
170
181
else :
171
182
c , ys = lax .scan (
172
- body_fn , init , xs , length = length , reverse = reverse , unroll = unroll
183
+ body_fn , init , xs , length = length , reverse = reverse , unroll = unroll
173
184
)
174
185
ys = jax .tree_util .tree_map (transpose_from_front , out_axes , ys )
175
186
ys = jax .tree_util .tree_map (
176
- lambda ax , const , y : (const if ax is broadcast else y ),
177
- out_axes ,
178
- constants_out ,
179
- ys ,
187
+ lambda ax , const , y : (const if ax is broadcast else y ),
188
+ out_axes ,
189
+ constants_out ,
190
+ ys ,
180
191
)
181
192
return broadcast_in , c , ys
182
193
183
- return scan_fn
194
def simple_scan_fn(broadcast_in, init, *args):
  """Runs the scan without the extra constancy-checking trace.

  Saves one tracing operation. In exchange, nothing verifies that broadcast
  constants are loop invariants, and broadcast (non-carry) function outputs
  are not supported.

  Args:
    broadcast_in: value passed unchanged as the first argument of every call
      to `fn`, and returned as the first element of the result.
    init: initial carry handed to `lax.scan`.
    *args: inputs for the scan. Each one is passed through
      `transpose_to_front` with the matching `in_axes` entry; inside the loop
      body, arguments whose axis is `broadcast` are replaced by the original
      (untransposed) argument.

  Returns:
    A `(broadcast_in, carry_out, scan_out)` tuple, where `scan_out` has been
    passed through `transpose_from_front` with the matching `out_axes` entry.

  Raises:
    ValueError: if any leaf of `out_axes` is `broadcast`.
  """
  # Validate first: the check only depends on `out_axes`, so there is no
  # point transposing the inputs before rejecting the call. The sentinel is
  # compared by identity, the same way the loop bodies test it.
  if any(ax is broadcast for ax in jax.tree_util.tree_leaves(out_axes)):
    raise ValueError(
        'nn.scan run with check_constancy_invariants=False '
        'does not support broadcast non-carry function '
        f'outputs. out_axes was given as {out_axes}'
    )

  xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args)

  def body_fn(c, scanned):
    # inject constants: broadcast arguments bypass the per-step slices
    step_args = jax.tree_util.tree_map(
        lambda ax, arg, x: (arg if ax is broadcast else x),
        in_axes,
        args,
        scanned,
    )
    # The broadcast output of `fn` is discarded in this mode.
    _, c, ys = fn(broadcast_in, c, *step_args)
    return c, ys

  scan_kwargs = dict(length=length, reverse=reverse, unroll=unroll)
  # `_split_transpose` is only forwarded on jax versions newer than 0.4.25.
  if jax.version.__version_info__ > (0, 4, 25):
    scan_kwargs['_split_transpose'] = _split_transpose
  c, ys = lax.scan(body_fn, init, xs, **scan_kwargs)

  ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
  return broadcast_in, c, ys
224
+
225
+ if check_constancy_invariants :
226
+ return scan_fn
227
+ else :
228
+ return simple_scan_fn
0 commit comments