Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 8 additions & 34 deletions flax/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,34 +188,6 @@ class AttrPriority(enum.IntEnum):
LOW = 100


class PriorityStr(str):
_priority: AttrPriority

def __new__(cls, priority: AttrPriority, value: str):
obj = super().__new__(cls, value)
obj._priority = priority
return obj

def _check_and_get_priority(self, other) -> AttrPriority:
if not isinstance(other, (str, PriorityStr)):
raise NotImplementedError(
f'Cannot compare {type(self)} with {type(other)}'
)
if isinstance(other, PriorityStr):
return other._priority
return AttrPriority.DEFAULT

def __lt__(self, other) -> bool:
other_priority = self._check_and_get_priority(other)
if self._priority == other_priority:
return super().__lt__(other)
return self._priority < other_priority

def __gt__(self, other) -> bool:
other_priority = self._check_and_get_priority(other)
if self._priority == other_priority:
return super().__gt__(other)
return self._priority > other_priority

class ModuleBase:
if tp.TYPE_CHECKING:
Expand All @@ -241,7 +213,7 @@ def _getattr(self, name: str) -> tp.Any:
return value

def _setattr(self, name: str, value: tp.Any) -> None:
if self.scope is not None:
if getattr(self, 'scope', None) is not None:
if name in vars(self) and isinstance(
state := vars(self)[name], ModuleState
):
Expand All @@ -254,11 +226,13 @@ def _setattr(self, name: str, value: tp.Any) -> None:

def _graph_node_flatten(self):
nodes = vars(self).copy()
keys = (
PriorityStr(self.attr_priorities.get(k, AttrPriority.DEFAULT), k)
for k in nodes.keys()
)
sorted_nodes = list((k, nodes[k]) for k in sorted(keys))
def get_priority(k):
if k in ('scope', '_pytree__state', 'attr_priorities'):
return AttrPriority.HIGH
return self.attr_priorities.get(k, AttrPriority.DEFAULT)

sorted_keys = sorted(nodes.keys(), key=lambda k: (get_priority(k), k))
sorted_nodes = list((k, nodes[k]) for k in sorted_keys)
return sorted_nodes, type(self)

def set_attr_priority(self, name: str, value: AttrPriority):
Expand Down
Loading