
Commit 6cf5b7d

[nnx] fast jit

1 parent 53bde74 commit 6cf5b7d

File tree: 9 files changed, +309 −96 lines changed


benchmarks/nnx_simple_training.py
Lines changed: 37 additions & 20 deletions

@@ -25,7 +25,9 @@
 from absl import app

 FLAGS = flags.FLAGS
-flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in')
+flags.DEFINE_enum(
+  'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
+)
 flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
 flags.DEFINE_integer('batch_size', 32, 'Batch size')
 flags.DEFINE_integer('width', 32, 'Hidden layer size')
@@ -46,6 +48,13 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
   def __call__(self, x):
     return x @ self.w + self.b

+class Block(nnx.Module):
+  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
+    self.linear = Linear(din, dout, rngs=rngs)
+    self.bn = nnx.BatchNorm(dout, rngs=rngs)
+
+  def __call__(self, x):
+    return nnx.relu(self.bn(self.linear(x)))

 class Count(nnx.Variable):
   pass
@@ -54,11 +63,11 @@ class Count(nnx.Variable):
 class MLP(nnx.Module):
   def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
     self.count = Count(jnp.array(0))
-    self.linear_in = Linear(din, dhidden, rngs=rngs)
+    self.linear_in = Block(din, dhidden, rngs=rngs)
     self.intermediates = [
-      Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
+      Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
     ]
-    self.linear_out = Linear(dhidden, dout, rngs=rngs)
+    self.linear_out = Block(dhidden, dout, rngs=rngs)

   def __call__(self, x):
     self.count.value += 1
@@ -79,18 +88,14 @@ def main(argv):

   print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')

-  if mode not in ['nnx', 'jax']:
-    raise ValueError(f'Invalid mode: {mode}')
-
   X = np.linspace(0, 1, 100)[:, None]
   Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

-  model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
-  tx = optax.sgd(1e-3)
-  optimizer = nnx.Optimizer(model, tx)
-  t0 = time()
-
-  if mode == 'nnx':
+  if mode == 'nnx' or mode == 'all':
+    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+    tx = optax.sgd(1e-3)
+    optimizer = nnx.Optimizer(model, tx)
+    t0 = time()

     @nnx.jit
     def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch):
@@ -115,11 +120,22 @@ def test_step_nnx(model: MLP, batch):

       if step % 1000 == 0:
         logs = test_step_nnx(model, (X, Y))
-        print(f"step: {step}, loss: {logs['loss']}")

       if step >= total_steps - 1:
         break
-  else:
+
+    print('### NNX ###')
+    print(f"final loss: {logs['loss']}")
+    total_time = time() - t0
+    print('total time:', total_time)
+    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
+    print('times called:', model.count.value)
+
+  if mode == 'jax' or mode == 'all':
+    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+    tx = optax.sgd(1e-3)
+    optimizer = nnx.Optimizer(model, tx)
+    t0 = time()

     @jax.jit
     def train_step_jax(graphdef, state, batch):
@@ -151,17 +167,18 @@ def test_step_jax(graphdef, state, batch):

       if step % 1000 == 0:
         state, logs = test_step_jax(graphdef, state, (X, Y))
-        print(f"step: {step}, loss: {logs['loss']}")

       if step >= total_steps - 1:
         break

     model, optimizer = nnx.merge(graphdef, state)

-  total_time = time() - t0
-  print('total time:', total_time)
-  print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
-  print('times called:', model.count.value)
+    print('### JAX ###')
+    print(f"final loss: {logs['loss']}")
+    total_time = time() - t0
+    print('total time:', total_time)
+    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
+    print('times called:', model.count.value)


 if __name__ == '__main__':
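For reference, the `jax` branch above times the hand-rolled split/merge pattern that `nnx.jit` automates. Below is a minimal sketch of that pattern, assuming only public `nnx` APIs; this standalone `train_step` is illustrative, not the benchmark's exact code.

import jax
import optax
from flax import nnx

model = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(1e-3))
# split into a static graphdef and a pytree of arrays
graphdef, state = nnx.split((model, optimizer))

@jax.jit
def train_step(graphdef, state, batch):
  # rebuild the mutable objects inside the traced function
  model, optimizer = nnx.merge(graphdef, state)
  x, y = batch

  def loss_fn(model):
    return ((model(x) - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  # split again so the updated state crosses the jit boundary as arrays
  _, state = nnx.split((model, optimizer))
  return state, loss

With the flag change, both paths can be timed in a single run, e.g. `python benchmarks/nnx_simple_training.py --mode=all`.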

flax/nnx/extract.py
Lines changed: 9 additions & 5 deletions

@@ -254,7 +254,7 @@ class GraphDefState(struct.PyTreeNode):

 class NodeStates(struct.PyTreeNode):
   _graphdef: graph.GraphDef[tp.Any] | None
-  states: tuple[graph.GraphState, ...]
+  states: tuple[graph.GraphState | graph.GraphFlatState, ...]
   metadata: tp.Any = struct.field(pytree_node=False)

   @property
@@ -264,7 +264,7 @@ def graphdef(self) -> graph.GraphDef[tp.Any]:
     return self._graphdef

   @property
-  def state(self) -> graph.GraphState:
+  def state(self) -> graph.GraphState | graph.GraphFlatState:
     if len(self.states) != 1:
       raise ValueError(
         f'Expected exactly one GraphDefState, got {len(self.states)}'
@@ -275,15 +275,19 @@
   def from_split(
     cls,
     graphdef: graph.GraphDef[tp.Any],
-    state: graph.GraphState,
+    state: graph.GraphState | graph.GraphFlatState,
     /,
-    *states: graph.GraphState,
+    *states: graph.GraphState | graph.GraphFlatState,
     metadata: tp.Any = None,
   ):
     return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata)

   @classmethod
-  def from_states(cls, state: graph.GraphState, *states: graph.GraphState):
+  def from_states(
+    cls,
+    state: graph.GraphState | graph.GraphFlatState,
+    *states: graph.GraphState | graph.GraphFlatState,
+  ):
     return cls(_graphdef=None, states=(state, *states), metadata=None)

   @classmethod
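The widened annotations let `NodeStates`, the pytree wrapper nnx transforms use to move module state across JAX boundaries, carry the new flat representation as well as nested `State` objects. A rough sketch of the two accepted types; these are internal APIs and the example only illustrates the typing change:

from flax import nnx
from flax.nnx import extract, graph

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# nested State, as before
graphdef, state = nnx.split(model)
ns_nested = extract.NodeStates.from_split(graphdef, state)

# a flat (path, leaf) state is now accepted too
graphdef, flat_state = graph.flatten(model)
ns_flat = extract.NodeStates.from_split(graphdef, flat_state)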

flax/nnx/graph.py
Lines changed: 100 additions & 15 deletions

@@ -30,7 +30,7 @@
   CallableProxy,
   DelayedAccessor,
 )
-from flax.nnx.statelib import State
+from flax.nnx.statelib import FlatState, State
 from flax.nnx import variablelib
 from flax.nnx.variablelib import Variable, VariableState
 from flax.typing import Key, PathParts, is_key_like
@@ -53,6 +53,7 @@
 StateLeaf = VariableState[tp.Any]
 NodeLeaf = Variable[tp.Any]
 GraphState = State[Key, StateLeaf]
+GraphFlatState = FlatState[StateLeaf]


 def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
@@ -377,7 +378,9 @@ def _apply(
       module = merge(self, state, *states)
       fn = accessor(module)
       out = fn(*args, **kwargs)
-      return out, flatten(module)
+      graphdef, flat_state = flatten(module)
+      state_ = State.from_flat_path(flat_state)
+      return out, (graphdef, state_)

     return CallableProxy(_apply, accessor)  # type: ignore
@@ -389,7 +392,7 @@ def _apply(

 def flatten(
   node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None
-) -> tuple[GraphDef[Node], GraphState]:
+) -> tuple[GraphDef[Node], FlatState[tp.Any]]:
   """Flattens a graph node into a (graphdef, state) pair.

   Args:
@@ -402,7 +405,7 @@
     ref_index = RefMap()
   flat_state: list[tuple[PathParts, StateLeaf]] = []
   graphdef = _graph_flatten((), ref_index, flat_state, node)
-  return graphdef, GraphState.from_flat_path(flat_state)
+  return graphdef, FlatState(flat_state)


 def _graph_flatten(
@@ -811,8 +814,11 @@ def split(
     ctx = (
       current_update_context(self.ctxtag) if self.ctxtag is not None else None
     )
-    graphdef, state = flatten(node, self.ref_index)
-    states = _split_state(state, filters)
+    graphdef, flat_state = flatten(node, self.ref_index)
+    flat_states = _split_state(flat_state, filters)
+    states = tuple(
+      State.from_flat_path(flat_state) for flat_state in flat_states
+    )
     if ctx is not None:
       if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
         index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
@@ -822,6 +828,47 @@ def split(

     return graphdef, *states

+  @tp.overload
+  def flatten(
+    self, graph_node: A, /
+  ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
+  @tp.overload
+  def flatten(
+    self, graph_node: A, first: filterlib.Filter, /
+  ) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
+  @tp.overload
+  def flatten(
+    self,
+    graph_node: A,
+    first: filterlib.Filter,
+    second: filterlib.Filter,
+    /,
+    *filters: filterlib.Filter,
+  ) -> tuple[
+    GraphDef[A],
+    FlatState[VariableState[tp.Any]],
+    tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]],
+  ]: ...
+  def flatten(
+    self, node: A, *filters: filterlib.Filter
+  ) -> tuple[
+    GraphDef[A], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]]
+  ]:
+    ctx = (
+      current_update_context(self.ctxtag) if self.ctxtag is not None else None
+    )
+    graphdef, flat_state = flatten(node, self.ref_index)
+    flat_states = _split_state(flat_state, filters)
+
+    if ctx is not None:
+      if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
+        index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
+        graphdef = dataclasses.replace(
+          graphdef, index_mapping=HashableMapping(index_to_index, copy=False)
+        )
+
+    return graphdef, *flat_states
+

 @contextlib.contextmanager
 def split_context(ctxtag: str | None = None):
@@ -874,6 +921,39 @@ def merge(
     )
     return node

+  def unflatten(
+    self,
+    graphdef: GraphDef[A],
+    flat_state: GraphFlatState,
+    /,
+    *flat_states: GraphFlatState,
+  ) -> A:
+    ctx = (
+      current_update_context(self.ctxtag) if self.ctxtag is not None else None
+    )
+    if (
+      ctx is not None
+      and isinstance(graphdef, NodeDef)
+      and graphdef.index_mapping is not None
+    ):
+      # outer merge (4), create index_ref_cache
+      assert ctx.ref_index is not None
+      index_ref_cache = compose_mapping_reversed(
+        ctx.ref_index, graphdef.index_mapping
+      )
+    else:
+      # inner merge (2)
+      index_ref_cache = None
+
+    state = FlatState.merge(flat_state, *flat_states).to_nested_state()
+    node = unflatten(
+      graphdef,
+      state,
+      index_ref=self.index_ref,
+      index_ref_cache=index_ref_cache,
+    )
+    return node
+

 @contextlib.contextmanager
 def merge_context(ctxtag: str | None = None):
@@ -1001,9 +1081,11 @@ def split(
     filters are passed, a single :class:`State` is returned.
     """
     ref_index: RefMap[tp.Any, Index] = RefMap()
-    graphdef, state = flatten(node, ref_index)
-    states = _split_state(state, filters)
-
+    graphdef, flat_state = flatten(node, ref_index)
+    states = tuple(
+      State.from_flat_path(flat_state)
+      for flat_state in _split_state(flat_state, filters)
+    )
     if self.index_ref is not None and isinstance(graphdef, NodeDef):
       index_to_index = compose_mapping(self.index_ref, ref_index)
       graphdef = dataclasses.replace(
@@ -1195,13 +1277,13 @@ def current_update_context(tag: str) -> UpdateContext:
 # --------------------------------------------------------

 def _split_state(
-  state: GraphState,
+  state: FlatState[tp.Any],
   filters: tuple[filterlib.Filter, ...],
-) -> tuple[GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
+) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]:
   if not filters:
     return (state,)
   states = state.split(*filters)
-  if isinstance(states, State):
+  if not isinstance(states, tuple):
     return (states,)
   assert len(states) > 0
   return states  # type: ignore[return-value]
@@ -1292,9 +1374,11 @@ def split(
   ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
   filters are passed, a single ``State`` is returned.
   """
-  graphdef, state = flatten(node)
-  states = _split_state(state, filters)
-  return graphdef, *states
+  graphdef, flat_state = flatten(node)
+  flat_states = _split_state(flat_state, filters)
+  states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
+  return graphdef, *states  # type: ignore[return-value]


 def merge(
   graphdef: GraphDef[A],
@@ -1486,6 +1570,7 @@ def state(
     One or more :class:`State` mappings.
   """
   _, state = flatten(node)
+  state = state.to_nested_state()

   states: GraphState | tuple[GraphState, ...]
   if len(filters) == 0:
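These changes are the substance of the "fast jit" title: module-level `flatten` now returns a `FlatState` of `(path, leaf)` pairs instead of eagerly building a nested `State`, and the split/merge contexts gain `flatten`/`unflatten` methods so transforms can round-trip state without constructing nested dicts on every call. A small sketch of the round trip, assuming the internal `graph` module (the public `nnx.split`/`nnx.merge` behavior is unchanged):

from flax import nnx
from flax.nnx import graph

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# fast path: flat (path, leaf) pairs, no nested State construction
with graph.split_context() as ctx:
  graphdef, flat_state = ctx.flatten(model)

with graph.merge_context() as ctx:
  restored = ctx.unflatten(graphdef, flat_state)

# public API still produces and consumes nested States
graphdef, state = nnx.split(model)
restored2 = nnx.merge(graphdef, state)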

flax/nnx/reprlib.py
Lines changed: 8 additions & 0 deletions

@@ -111,6 +111,14 @@ def __nnx_repr__(self):
     for key, value in self.items():
       yield Attr(repr(key), value)

+class SequenceReprMixin(tp.Sequence[A], Representable):
+  def __nnx_repr__(self):
+    yield Object(type='', value_sep='', start='[', end=']')
+
+    for value in self:
+      yield Attr('', value)
+

 @dataclasses.dataclass(repr=False)
 class PrettyMapping(Representable):
   mapping: tp.Mapping
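`SequenceReprMixin` complements the existing mapping mixin for list-like containers, rendering elements between `[` and `]`. A hypothetical container using it; the `Paths` class is invented for illustration, and only `__getitem__` and `__len__` are needed to satisfy `tp.Sequence`:

import typing as tp
from flax.nnx import reprlib

class Paths(reprlib.SequenceReprMixin[str]):
  def __init__(self, items: tp.Iterable[str]):
    self._items = list(items)

  def __getitem__(self, index):
    return self._items[index]

  def __len__(self) -> int:
    return len(self._items)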
