Skip to content

Commit

Permalink
[nnx] add tabulate
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jan 8, 2025
1 parent e2134af commit 3406332
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 53 deletions.
5 changes: 0 additions & 5 deletions flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,6 @@ def tabulate(
Total Parameters: 50 (200 B)
**Note**: rows order in the table does not represent execution order,
instead it aligns with the order of keys in `variables` which are sorted
alphabetically.
**Note**: `vjp_flops` returns `0` if the module is not differentiable.
Args:
Expand Down
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,5 @@
from .extract import to_tree as to_tree
from .extract import from_tree as from_tree
from .extract import NodeStates as NodeStates
from .summary import tabulate as tabulate
from . import traversals as traversals
4 changes: 3 additions & 1 deletion flax/nnx/filterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def to_predicate(filter: Filter) -> Predicate:
else:
raise TypeError(f'Invalid collection filter: {filter:!r}. ')

def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]:
def filters_to_predicates(
filters: tp.Sequence[Filter],
) -> tuple[Predicate, ...]:
for i, filter_ in enumerate(filters):
if filter_ in (..., True) and i != len(filters) - 1:
remaining_filters = filters[i + 1 :]
Expand Down
301 changes: 301 additions & 0 deletions flax/nnx/summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file


import io
import typing as tp
from itertools import groupby
from types import MappingProxyType

import jax
import rich.console
import rich.table
import rich.text
import yaml
import jax.numpy as jnp

from flax.nnx import graph, rnglib, variablelib


def tabulate(
obj,
depth: int | None = None,
table_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}),
column_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}),
console_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}),
) -> str:
"""Creates a summary of the graph object represented as a table.
The table summarizes the object's state and metadata. The table is
structured as follows:
- The first column represents the path of the object in the graph.
- The second column represents the type of the object.
- The following columns provide information about the object's state,
grouped by Variable types.
Example:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.bn = nnx.BatchNorm(dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.2, rngs=rngs)
...
... def __call__(self, x):
... return nnx.relu(self.dropout(self.bn(self.linear(x))))
...
>>> class Foo(nnx.Module):
... def __init__(self, rngs: nnx.Rngs):
... self.block1 = Block(32, 128, rngs=rngs)
... self.block2 = Block(128, 10, rngs=rngs)
...
... def __call__(self, x):
... return self.block2(self.block1(x))
...
>>> foo = Foo(nnx.Rngs(0))
>>> # print(nnx.tabulate(foo))
Foo Summary
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ type ┃ BatchStat ┃ Param ┃ RngState ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ block1/bn │ BatchNorm │ mean: float32[128] │ bias: float32[128] │ │
│ │ │ var: float32[128] │ scale: float32[128] │ │
│ │ │ │ │ │
│ │ │ 256 (1.0 KB) │ 256 (1.0 KB) │ │
├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤
│ block1/dropout/rngs/default │ RngStream │ │ │ count: │
│ │ │ │ │ value: uint32[] │
│ │ │ │ │ tag: default │
│ │ │ │ │ key: │
│ │ │ │ │ value: key<fry>[] │
│ │ │ │ │ tag: default │
│ │ │ │ │ │
│ │ │ │ │ 2 (12 B) │
├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤
│ block1/linear │ Linear │ │ bias: float32[128] │ │
│ │ │ │ kernel: float32[32,128] │ │
│ │ │ │ │ │
│ │ │ │ 4,224 (16.9 KB) │ │
├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤
│ block2/bn │ BatchNorm │ mean: float32[10] │ bias: float32[10] │ │
│ │ │ var: float32[10] │ scale: float32[10] │ │
│ │ │ │ │ │
│ │ │ 20 (80 B) │ 20 (80 B) │ │
├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤
│ block2/linear │ Linear │ │ bias: float32[10] │ │
│ │ │ │ kernel: float32[128,10] │ │
│ │ │ │ │ │
│ │ │ │ 1,290 (5.2 KB) │ │
├─────────────────────────────┼───────────┼────────────────────┼─────────────────────────┼─────────────────────┤
│ │ Total │ 276 (1.1 KB) │ 5,790 (23.2 KB) │ 2 (12 B) │
└─────────────────────────────┴───────────┴────────────────────┴─────────────────────────┴─────────────────────┘
Total Parameters: 6,068 (24.3 KB)
Note that ``block2/dropout`` is not shown in the table because it shares the
same ``RngState`` with ``block1/dropout``.
Args:
obj: A object to summarize. It can a pytree or a graph objects
such as nnx.Module or nnx.Optimizer.
depth: The depth of the table.
table_kwargs: An optional dictionary with additional keyword arguments
that are passed to ``rich.table.Table`` constructor.
column_kwargs: An optional dictionary with additional keyword arguments
that are passed to ``rich.table.Table.add_column`` when adding columns to
the table.
console_kwargs: An optional dictionary with additional keyword arguments
that are passed to ``rich.console.Console`` when rendering the table.
Default arguments are ``{'force_terminal': True, 'force_jupyter':
False}``.
Returns:
A string summarizing the object.
"""
_console_kwargs = {'force_terminal': True, 'force_jupyter': False}
_console_kwargs.update(console_kwargs)
state = graph.state(obj)
graph_map = dict(graph.iter_graph(obj))
flat_state = sorted(state.flat_state())

def key_fn(
path_state: tuple[graph.PathParts, variablelib.VariableState[tp.Any]],
):
path, _ = path_state
if depth is None or len(path) <= depth:
return path[:-1]
else:
return path[:depth]

rows = groupby(flat_state, key_fn)
table = sorted((path, list(flat_states)) for path, flat_states in rows)

state_types_set = {variable_state.type for _, variable_state in flat_state}
# replace RngKey and RngCount with RngState
if rnglib.RngKey in state_types_set:
state_types_set.remove(rnglib.RngKey)
state_types_set.add(rnglib.RngState)
if rnglib.RngCount in state_types_set:
state_types_set.remove(rnglib.RngCount)
state_types_set.add(rnglib.RngState)
# sort based on MRO
state_types = _sort_variable_types(state_types_set)

rich_table = rich.table.Table(
show_header=True,
show_lines=True,
show_footer=True,
title=f'{type(obj).__name__} Summary',
**table_kwargs,
)

rich_table.add_column('path', **column_kwargs)
rich_table.add_column('type', **column_kwargs)

for state_type in state_types:
rich_table.add_column(state_type.__name__, **column_kwargs)

for key_path, row_states in table:
row: list[str] = []
node = graph_map[key_path]
type_state_groups = variablelib.split_flat_state(row_states, state_types)
path_str = '/'.join(map(str, key_path))
node_type = type(node).__name__
row.extend([path_str, node_type])

for state_type, type_path_and_states in zip(state_types, type_state_groups):
attributes = {}
for state_path, variable_state in type_path_and_states:
if len(state_path) == len(key_path) + 1:
name = str(state_path[-1])
value = variable_state.value
value_repr = _render_array(value) if _has_shape_dtype(value) else ''
metadata = variable_state.get_metadata()

if metadata:
attributes[name] = {
'value': value_repr,
**metadata,
}
elif value_repr:
attributes[name] = value_repr

if attributes:
col_repr = _as_yaml_str(attributes) + '\n\n'
else:
col_repr = ''

type_states = [state for _, state in type_path_and_states]
size_, bytes_ = _size_and_bytes(type_states)
col_repr += f'[bold]{_size_and_bytes_repr(size_, bytes_)}[/bold]'
row.append(col_repr)

rich_table.add_row(*row)

rich_table.columns[1].footer = rich.text.Text.from_markup(
'Total', justify='right'
)
flat_states = variablelib.split_flat_state(flat_state, state_types)

for i, (state_type, type_path_and_states) in enumerate(
zip(state_types, flat_states)
):
type_states = [state for _, state in type_path_and_states]
size_, bytes_ = _size_and_bytes(type_states)
size_repr = _size_and_bytes_repr(size_, bytes_)
rich_table.columns[i + 2].footer = size_repr

rich_table.caption_style = 'bold'
rich_table.caption = (
f'\nTotal Parameters: {_size_and_bytes_repr(*_size_and_bytes(state))}'
)

return '\n' + _get_rich_repr(rich_table, _console_kwargs) + '\n'


def _get_rich_repr(obj, console_kwargs):
f = io.StringIO()
console = rich.console.Console(file=f, **console_kwargs)
console.print(obj)
return f.getvalue()


def _size_and_bytes(pytree: tp.Any) -> tuple[int, int]:
leaves = jax.tree.leaves(pytree)
size = sum(x.size for x in leaves if hasattr(x, 'size'))
num_bytes = sum(
x.size * x.dtype.itemsize for x in leaves if hasattr(x, 'size')
)
return size, num_bytes


def _size_and_bytes_repr(size: int, num_bytes: int) -> str:
if not size:
return ''
bytes_repr = _bytes_repr(num_bytes)
return f'{size:,} [dim]({bytes_repr})[/dim]'


def _bytes_repr(num_bytes):
count, units = (
(f'{num_bytes / 1e9 :,.1f}', 'GB')
if num_bytes > 1e9
else (f'{num_bytes / 1e6 :,.1f}', 'MB')
if num_bytes > 1e6
else (f'{num_bytes / 1e3 :,.1f}', 'KB')
if num_bytes > 1e3
else (f'{num_bytes:,}', 'B')
)

return f'{count} {units}'


def _has_shape_dtype(value):
return hasattr(value, 'shape') and hasattr(value, 'dtype')


def _as_yaml_str(value) -> str:
if (hasattr(value, '__len__') and len(value) == 0) or value is None:
return ''

file = io.StringIO()
yaml.safe_dump(
value,
file,
default_flow_style=False,
indent=2,
sort_keys=False,
explicit_end=False,
)
return file.getvalue().replace('\n...', '').replace("'", '').strip()


def _render_array(x):
shape, dtype = jnp.shape(x), jnp.result_type(x)
shape_repr = ','.join(str(x) for x in shape)
return f'[dim]{dtype}[/dim][{shape_repr}]'


def _sort_variable_types(types: tp.Iterable[type]) -> list[type]:
def _variable_parents_count(t: type):
return sum(1 for p in t.mro() if issubclass(p, variablelib.Variable))

type_sort_key = {t: (-_variable_parents_count(t), t.__name__) for t in types}
return sorted(types, key=lambda t: type_sort_key[t])
2 changes: 1 addition & 1 deletion flax/nnx/variablelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ def wrapper(*args):

def split_flat_state(
flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]],
filters: tuple[filterlib.Filter, ...],
filters: tp.Sequence[filterlib.Filter],
) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]:
predicates = filterlib.filters_to_predicates(filters)
# we have n + 1 states, where n is the number of predicates
Expand Down
Loading

0 comments on commit 3406332

Please sign in to comment.