
Commit fc38f21

levskaya authored and Flax Authors committed
Introduce a faster nn.scan mode that avoids extra jax retracing.

This adds a new keyword option to linen nn.scan, `check_constancy_invariants`, which defaults to True and preserves the existing behavior. Setting it to False skips an extra jax trace that is otherwise used to hoist scan-loop constants out of the loop and to check that broadcast variables and body-function outputs marked constant have no data dependence on the loop body. Skipping this extra trace and static check can save considerable time when tracing and compiling larger models.

PiperOrigin-RevId: 705869200
1 parent 207966e commit fc38f21
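To illustrate the new option at the user-facing linen level, here is a minimal usage sketch (not part of this commit; the Block module, feature sizes, and shapes are invented for illustration, mirroring the pattern used in tests/linen/linen_transforms_test.py below):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
  features: int = 4

  @nn.compact
  def __call__(self, carry, x):
    # One scan step: update the carry, pass the scanned input through.
    carry = nn.Dense(self.features, use_bias=False)(carry)
    return carry, x

ScannedBlock = nn.scan(
    Block,
    variable_axes={'params': 0},   # separate params per scan step
    split_rngs={'params': False},
    length=4,
    # New flag from this commit: skip the extra constancy-check trace.
    check_constancy_invariants=False,
)

model = ScannedBlock()
xs = jnp.ones((4, 4))       # (length, features)
carry0 = jnp.zeros((4,))
params = model.init(jax.random.key(0), carry0, xs)
carry, ys = model.apply(params, carry0, xs)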

4 files changed: +134 −28 lines

flax/core/axes_scan.py

Lines changed: 72 additions & 27 deletions
@@ -13,16 +13,17 @@
 # limitations under the License.
 
 """Wrapper around jax.lax.scan with in_axes/out_axes API."""
+from collections.abc import Callable
 import functools
 from typing import Any, Optional
-from collections.abc import Callable
 
 import jax
-import jax.numpy as jnp
-import numpy as np
-from jax import core, lax
+from jax import core
+from jax import lax
 from jax.extend import linear_util as lu
 from jax.interpreters import partial_eval as pe
+import jax.numpy as jnp
+import numpy as np
 
 ScanAxis = Optional[int]
 
@@ -35,13 +36,14 @@ class _Broadcast:
 
 
 def scan(
-  fn: Callable[..., Any],
-  in_axes: Any,
-  out_axes: Any,
-  length: int | None = None,
-  reverse: bool = False,
-  unroll: int = 1,
-  _split_transpose: bool = False
+    fn: Callable[..., Any],
+    in_axes: Any,
+    out_axes: Any,
+    length: int | None = None,
+    reverse: bool = False,
+    unroll: int = 1,
+    _split_transpose: bool = False,
+    check_constancy_invariants: bool = True,
 ):
   """A wrapper around `jax.lax.scan` with in_axes/out_axes api.
 
@@ -78,6 +80,11 @@ def body_fn(b, c, x):
       iteration of a loop (default: 1).
     _split_transpose: An experimental feature to split the transpose of scan
       into a scan and a map, backed by an experimental Jax lax.scan() feature.
+    check_constancy_invariants: If true, the scan will verify that the
+      broadcast constants are true loop invariants, and further supports
+      broadcast function (non-carry) outputs. This requires an extra jax
+      tracing step however, so setting to false can reduce trace time on larger
+      models.
   Returns:
     the function that performs the scan of the form:
     (broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out).
@@ -114,39 +121,43 @@ def trans(x):
     return jax.tree_util.tree_map(trans, xs)
 
   def scan_fn(broadcast_in, init, *args):
+    # Requires one extra tracing operation to test invariants:
+    # Verifies that broadcast constants are true loop invariants, and further
+    # supports broadcast function (non-carry) outputs.
+
     xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args)
 
     def body_fn(c, xs, init_mode=False):
       # inject constants
       xs = jax.tree_util.tree_map(
-        lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs
+          lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs
       )
       broadcast_out, c, ys = fn(broadcast_in, c, *xs)
 
       if init_mode:
         ys = jax.tree_util.tree_map(
-          lambda ax, y: (y if ax is broadcast else ()), out_axes, ys
+            lambda ax, y: (y if ax is broadcast else ()), out_axes, ys
        )
         return broadcast_out, ys
       else:
         ys = jax.tree_util.tree_map(
-          lambda ax, y: (() if ax is broadcast else y), out_axes, ys
+            lambda ax, y: (() if ax is broadcast else y), out_axes, ys
        )
         return c, ys
 
     broadcast_body = functools.partial(body_fn, init_mode=True)
 
     carry_avals = jax.tree_util.tree_map(
-      lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), init
+        lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), init
     )
     scan_avals = jax.tree_util.tree_map(
-      lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), xs
+        lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), xs
     )
     input_avals = (carry_avals, scan_avals)
 
     in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
     f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
-      lu.wrap_init(broadcast_body), in_tree
+        lu.wrap_init(broadcast_body), in_tree
     )
     in_pvals = list(map(pe.PartialVal.unknown, in_avals))
     _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
@@ -155,29 +166,63 @@ def body_fn(c, xs, init_mode=False):
     for pv, const in out_pvals:
       if pv is not None:
         raise ValueError(
-          'broadcasted variable has a data dependency on the scan body.'
+            'broadcasted variable has a data dependency on the scan body.'
        )
       out_flat.append(const)
     broadcast_in, constants_out = jax.tree_util.tree_unflatten(
-      out_tree(), out_flat
+        out_tree(), out_flat
     )
 
     if jax.version.__version_info__ > (0, 4, 25):
       c, ys = lax.scan(
-        body_fn, init, xs, length=length, reverse=reverse, unroll=unroll,
-        _split_transpose=_split_transpose
+          body_fn, init, xs, length=length, reverse=reverse, unroll=unroll,
+          _split_transpose=_split_transpose
      )
     else:
       c, ys = lax.scan(
-        body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
+          body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
      )
     ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
     ys = jax.tree_util.tree_map(
-      lambda ax, const, y: (const if ax is broadcast else y),
-      out_axes,
-      constants_out,
-      ys,
+        lambda ax, const, y: (const if ax is broadcast else y),
+        out_axes,
+        constants_out,
+        ys,
     )
     return broadcast_in, c, ys
 
-  return scan_fn
+  def simple_scan_fn(broadcast_in, init, *args):
+    # Saves an extra tracing operation.
+    # No verification of constancy, and no support for non-carry broadcast
+    # function outputs.
+    xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args)
+
+    if broadcast in jax.tree_util.tree_leaves(out_axes):
+      raise ValueError(f"nn.scan run with check_constancy_invariants=False "
+                       f"does not support broadcast non-carry function "
+                       f"outputs. out_axes was given as {out_axes}")
+
+    def body_fn(c, xs):
+      # inject constants
+      xs = jax.tree_util.tree_map(
+          lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs
+      )
+      _, c, ys = fn(broadcast_in, c, *xs)
+      return c, ys
+
+    if jax.version.__version_info__ > (0, 4, 25):
+      c, ys = lax.scan(
+          body_fn, init, xs, length=length, reverse=reverse, unroll=unroll,
+          _split_transpose=_split_transpose
+      )
+    else:
+      c, ys = lax.scan(
+          body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
+      )
+    ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
+    return broadcast_in, c, ys
+
+  if check_constancy_invariants:
+    return scan_fn
+  else:
+    return simple_scan_fn
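To make the two code paths above concrete, here is a small sketch (not part of the commit; the body function, dtypes, and shapes are invented for illustration) that drives the core wrapper directly, once with the default invariant check and once with the new fast path:

import jax.numpy as jnp
from flax.core import axes_scan

def body(const, carry, x):
  # Signature: (broadcast_in, carry, *scan_inputs) ->
  #            (broadcast_out, carry_out, scan_out)
  new_carry = carry + const * x
  return const, new_carry, new_carry

# Default path: an extra trace verifies that the broadcast value is a true
# loop invariant, and broadcast (non-carry) outputs are supported via
# out_axes entries set to axes_scan.broadcast.
checked = axes_scan.scan(body, in_axes=(0,), out_axes=0)

# New fast path: skips the verification trace; out_axes containing
# axes_scan.broadcast is rejected with a ValueError instead.
fast = axes_scan.scan(
    body, in_axes=(0,), out_axes=0, check_constancy_invariants=False
)

const = jnp.asarray(2.0)
carry0 = jnp.zeros((3,))
xs = jnp.ones((5, 3))  # scanned over the leading axis

for run in (checked, fast):
  const_out, carry_out, ys = run(const, carry0, xs)
  assert carry_out.shape == (3,) and ys.shape == (5, 3)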

flax/core/lift.py

Lines changed: 8 additions & 1 deletion
@@ -879,6 +879,7 @@ def scan(
   _split_transpose: bool = False,
   data_transform: Callable[..., Any] | None = None,
   metadata_params: dict[Any, Any] = {},
+  check_constancy_invariants: bool = True,
 ) -> Callable[..., Any]:
   """A lifted version of ``jax.lax.scan``.
 
@@ -946,6 +947,11 @@ def body_fn(scope, c, x):
       intended for inline SPMD annotations.
     metadata_params: arguments dict passed to AxisMetadata instances in the
       variable tree.
+    check_constancy_invariants: If true, the scan will verify that the
+      broadcast constants are true loop invariants, and further supports
+      broadcast function (non-carry) outputs. This requires an extra jax
+      tracing step however, so setting to false can reduce trace time on larger
+      models.
 
   Returns:
     The scan function with the signature
@@ -1000,7 +1006,8 @@ def find_length(axis, x):
       length=length,
       reverse=reverse,
       unroll=unroll,
-      _split_transpose=_split_transpose
+      _split_transpose=_split_transpose,
+      check_constancy_invariants=check_constancy_invariants,
     )
     def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args):
       carry_vars, c = carry

flax/linen/transforms.py

Lines changed: 7 additions & 0 deletions
@@ -1165,6 +1165,7 @@ def scan(
   metadata_params: Mapping[Any, Any] = {},
   methods=None,
   _split_transpose: bool = False,
+  check_constancy_invariants: bool = True,
 ) -> Target:
   """A lifted version of ``jax.lax.scan``.
 
@@ -1304,6 +1305,11 @@ def scan(
     methods: If ``target`` is a ``Module``, the methods of ``Module`` to scan over.
     _split_transpose: An experimental feature to split the transpose of a scan
       into a scan and a map, backed by an experimental Jax lax.scan() feature.
+    check_constancy_invariants: If true, the scan will verify that the
+      broadcast constants are true loop invariants, and further supports
+      broadcast function (non-carry) outputs. This requires an extra jax
+      tracing step however, so setting to false can reduce trace time on larger
+      models.
 
   Returns:
     The scan function with the signature ``(module, carry, *xs) -> (carry,
@@ -1326,6 +1332,7 @@ def scan(
     data_transform=data_transform,
     metadata_params=metadata_params,
     methods=methods,
+    check_constancy_invariants=check_constancy_invariants,
   )
 
 
tests/linen/linen_transforms_test.py

Lines changed: 47 additions & 0 deletions
@@ -2715,6 +2715,53 @@ def __call__(self, x):
     params = foo.init(key, x)
     foo.apply(params, x)
 
+  @parameterized.named_parameters(
+      ('retracing scan', True), ('simple scan', False)
+  )
+  def test_jit_scan_retracing(self, retracing_scan: bool):
+    num_blocks = 4
+    num_patterns = 4
+    features = 4
+    trace_counts = [0, 0]
+
+    class Block(nn.Module):
+      def setup(self):
+        self.dense = nn.Dense(features, use_bias=False)
+      @nn.jit
+      def __call__(self, x):
+        nonlocal trace_counts
+        trace_counts[1] += 1
+        return self.dense(x)
+
+    class BlockSequence(nn.Module):
+      def setup(self):
+        self.blocks = [Block() for _ in range(num_blocks)]
+      @nn.jit
+      def __call__(self, carry, inputs):
+        nonlocal trace_counts
+        trace_counts[0] += 1
+        for block in self.blocks:
+          carry = block(carry)
+        return carry, inputs
+
+    class Transformer(nn.Module):
+      retracing_scan: bool = True
+      def setup(self):
+        self.scan = nn.scan(
+            BlockSequence,
+            variable_axes={'params': 0},
+            split_rngs={'params': False},
+            length=num_patterns,
+            check_constancy_invariants=retracing_scan,
+        )()
+      def __call__(self, inputs):
+        return self.scan(jnp.zeros_like(inputs), inputs)
+
+    model = Transformer(retracing_scan=retracing_scan)
+    _ = model.init(random.key(0), jnp.ones((num_patterns, features,)))
+    self.assertEqual(trace_counts[0], 2 if retracing_scan else 1)
+    self.assertEqual(trace_counts[1], 2 if retracing_scan else 1)
+
 
 if __name__ == '__main__':
   absltest.main()
