30
30
CallableProxy ,
31
31
DelayedAccessor ,
32
32
)
33
- from flax .nnx .statelib import State
33
+ from flax .nnx .statelib import FlatState , State
34
34
from flax .nnx import variablelib
35
35
from flax .nnx .variablelib import Variable , VariableState
36
36
from flax .typing import Key , PathParts , is_key_like
53
53
StateLeaf = VariableState [tp .Any ]
54
54
NodeLeaf = Variable [tp .Any ]
55
55
GraphState = State [Key , StateLeaf ]
56
+ GraphFlatState = FlatState [StateLeaf ]
56
57
57
58
58
59
def is_state_leaf (x : tp .Any ) -> tpe .TypeGuard [StateLeaf ]:
@@ -377,7 +378,9 @@ def _apply(
377
378
module = merge (self , state , * states )
378
379
fn = accessor (module )
379
380
out = fn (* args , ** kwargs )
380
- return out , flatten (module )
381
+ graphdef , flat_state = flatten (module )
382
+ state_ = State .from_flat_path (flat_state )
383
+ return out , (graphdef , state_ )
381
384
382
385
return CallableProxy (_apply , accessor ) # type: ignore
383
386
@@ -389,7 +392,7 @@ def _apply(
389
392
390
393
def flatten (
391
394
node : Node , / , ref_index : RefMap [tp .Any , Index ] | None = None
392
- ) -> tuple [GraphDef [Node ], GraphState ]:
395
+ ) -> tuple [GraphDef [Node ], FlatState [ tp . Any ] ]:
393
396
"""Flattens a graph node into a (graphdef, state) pair.
394
397
395
398
Args:
@@ -402,7 +405,7 @@ def flatten(
402
405
ref_index = RefMap ()
403
406
flat_state : list [tuple [PathParts , StateLeaf ]] = []
404
407
graphdef = _graph_flatten ((), ref_index , flat_state , node )
405
- return graphdef , GraphState . from_flat_path (flat_state )
408
+ return graphdef , FlatState (flat_state )
406
409
407
410
408
411
def _graph_flatten (
@@ -811,8 +814,11 @@ def split(
811
814
ctx = (
812
815
current_update_context (self .ctxtag ) if self .ctxtag is not None else None
813
816
)
814
- graphdef , state = flatten (node , self .ref_index )
815
- states = _split_state (state , filters )
817
+ graphdef , flat_state = flatten (node , self .ref_index )
818
+ flat_states = _split_state (flat_state , filters )
819
+ states = tuple (
820
+ State .from_flat_path (flat_state ) for flat_state in flat_states
821
+ )
816
822
if ctx is not None :
817
823
if ctx .index_ref is not None and isinstance (graphdef , NodeDef ):
818
824
index_to_index = compose_mapping (ctx .index_ref , self .ref_index )
@@ -822,6 +828,47 @@ def split(
822
828
823
829
return graphdef , * states
824
830
831
@tp.overload
def flatten(
  self, graph_node: A, /
) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
@tp.overload
def flatten(
  self, graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], FlatState[VariableState[tp.Any]]]: ...
@tp.overload
def flatten(
  self,
  graph_node: A,
  first: filterlib.Filter,
  second: filterlib.Filter,
  /,
  *filters: filterlib.Filter,
) -> tuple[
  GraphDef[A],
  FlatState[VariableState[tp.Any]],
  tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]],
]: ...
def flatten(
  self, node: A, *filters: filterlib.Filter
) -> tuple[
  GraphDef[A], tpe.Unpack[tuple[FlatState[VariableState[tp.Any]], ...]]
]:
  """Flattens ``node`` and partitions its flat state by ``filters``.

  The node is flattened against ``self.ref_index`` and the resulting flat
  state is handed to ``_split_state`` together with ``filters``; the pieces
  are returned as-is, without being converted to nested states.

  Args:
    node: the graph node to flatten.
    *filters: filters forwarded to ``_split_state`` to partition the flat
      state.

  Returns:
    A tuple whose first element is the graphdef, followed by every flat
    state produced by ``_split_state``.
  """
  # Look up the update context only when this context carries a tag.
  if self.ctxtag is None:
    update_ctx = None
  else:
    update_ctx = current_update_context(self.ctxtag)

  graphdef, whole_flat_state = flatten(node, self.ref_index)
  pieces = _split_state(whole_flat_state, filters)

  attach_index_mapping = (
    update_ctx is not None
    and update_ctx.index_ref is not None
    and isinstance(graphdef, NodeDef)
  )
  if attach_index_mapping:
    # Store the composition of the update context's ``index_ref`` with this
    # context's ``ref_index`` on a copy of the graphdef.
    composed = compose_mapping(update_ctx.index_ref, self.ref_index)
    graphdef = dataclasses.replace(
      graphdef, index_mapping=HashableMapping(composed, copy=False)
    )

  return graphdef, *pieces
871
+
825
872
826
873
@contextlib .contextmanager
827
874
def split_context (ctxtag : str | None = None ):
@@ -874,6 +921,39 @@ def merge(
874
921
)
875
922
return node
876
923
924
def unflatten(
  self,
  graphdef: GraphDef[A],
  flat_state: GraphFlatState,
  /,
  *flat_states: GraphFlatState,
) -> A:
  """Rebuilds a node from ``graphdef`` and one or more flat states.

  All flat states are merged with ``FlatState.merge``, converted with
  ``to_nested_state`` and passed to the module-level ``unflatten`` along
  with ``self.index_ref``.

  Args:
    graphdef: the graph definition describing the node to rebuild.
    flat_state: the first flat state.
    *flat_states: additional flat states merged with ``flat_state``.

  Returns:
    The node returned by the module-level ``unflatten``.
  """
  # Look up the update context only when this context carries a tag.
  update_ctx = None
  if self.ctxtag is not None:
    update_ctx = current_update_context(self.ctxtag)

  # inner merge (2): no cache unless the conditions below hold
  ref_cache = None
  if (
    update_ctx is not None
    and isinstance(graphdef, NodeDef)
    and graphdef.index_mapping is not None
  ):
    # outer merge (4), create index_ref_cache
    assert update_ctx.ref_index is not None
    ref_cache = compose_mapping_reversed(
      update_ctx.ref_index, graphdef.index_mapping
    )

  merged_flat_state = FlatState.merge(flat_state, *flat_states)
  nested_state = merged_flat_state.to_nested_state()
  return unflatten(
    graphdef,
    nested_state,
    index_ref=self.index_ref,
    index_ref_cache=ref_cache,
  )
956
+
877
957
878
958
@contextlib .contextmanager
879
959
def merge_context (ctxtag : str | None = None ):
@@ -1001,9 +1081,11 @@ def split(
1001
1081
filters are passed, a single :class:`State` is returned.
1002
1082
"""
1003
1083
ref_index : RefMap [tp .Any , Index ] = RefMap ()
1004
- graphdef , state = flatten (node , ref_index )
1005
- states = _split_state (state , filters )
1006
-
1084
+ graphdef , flat_state = flatten (node , ref_index )
1085
+ states = tuple (
1086
+ State .from_flat_path (flat_state )
1087
+ for flat_state in _split_state (flat_state , filters )
1088
+ )
1007
1089
if self .index_ref is not None and isinstance (graphdef , NodeDef ):
1008
1090
index_to_index = compose_mapping (self .index_ref , ref_index )
1009
1091
graphdef = dataclasses .replace (
@@ -1195,13 +1277,13 @@ def current_update_context(tag: str) -> UpdateContext:
1195
1277
# --------------------------------------------------------
1196
1278
1197
1279
def _split_state(
  state: FlatState[tp.Any],
  filters: tuple[filterlib.Filter, ...],
) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]:
  """Splits ``state`` by ``filters``, always returning a non-empty tuple.

  Args:
    state: the flat state to partition; only its ``split`` method is used.
    filters: filters forwarded to ``state.split``. When empty, ``state`` is
      returned untouched and ``split`` is not called.

  Returns:
    A tuple of states: ``(state,)`` when there are no filters, the result of
    ``state.split(*filters)`` when that is already a tuple, or that result
    wrapped in a one-element tuple otherwise.
  """
  if not filters:
    return (state,)

  parts = state.split(*filters)
  if isinstance(parts, tuple):
    assert len(parts) > 0
    return parts  # type: ignore[return-value]
  # ``split`` returned a bare object rather than a tuple; normalize it.
  return (parts,)
@@ -1292,9 +1374,11 @@ def split(
1292
1374
``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no
1293
1375
filters are passed, a single ``State`` is returned.
1294
1376
"""
1295
- graphdef , state = flatten (node )
1296
- states = _split_state (state , filters )
1297
- return graphdef , * states
1377
+ graphdef , flat_state = flatten (node )
1378
+ flat_states = _split_state (flat_state , filters )
1379
+ states = tuple (State .from_flat_path (flat_state ) for flat_state in flat_states )
1380
+ return graphdef , * states # type: ignore[return-value]
1381
+
1298
1382
1299
1383
def merge (
1300
1384
graphdef : GraphDef [A ],
@@ -1486,6 +1570,7 @@ def state(
1486
1570
One or more :class:`State` mappings.
1487
1571
"""
1488
1572
_ , state = flatten (node )
1573
+ state = state .to_nested_state ()
1489
1574
1490
1575
states : GraphState | tuple [GraphState , ...]
1491
1576
if len (filters ) == 0 :
0 commit comments