Skip to content

Commit dd3c18b

Browse files
author
Flax Authors
committed
Merge pull request #4632 from google:add-max-repr-depth
PiperOrigin-RevId: 737815908
2 parents 42098ad + 1bdbab3 commit dd3c18b

File tree

2 files changed

+109
-36
lines changed

2 files changed

+109
-36
lines changed

flax/configurations.py

+60
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
class Config:
2525
flax_use_flaxlib: bool
26+
flax_max_repr_depth: int | None
2627
# See https://google.github.io/pytype/faq.html.
2728
_HAS_DYNAMIC_ATTRIBUTES = True
2829

@@ -122,6 +123,38 @@ def bool_flag(name: str, *, default: bool, help: str) -> FlagHolder[bool]:
122123
return fh
123124

124125

126+
def int_flag(name: str, *, default: int | None, help: str) -> FlagHolder[int]:
127+
"""Set up an integer flag.
128+
129+
Example::
130+
131+
num_foo = int_flag(
132+
name='flax_num_foo',
133+
default=42,
134+
help='Number of foo.',
135+
)
136+
137+
Now the ``FLAX_NUM_FOO`` shell environment variable can be used to
138+
control the process-level value of the flag, in addition to using e.g.
139+
``config.update("flax_num_foo", 42)`` directly.
140+
141+
Args:
142+
name: converted to lowercase to define the name of the flag. It is
143+
converted to uppercase to define the corresponding shell environment
144+
variable.
145+
default: a default value for the flag.
146+
help: used to populate the docstring of the returned flag holder object.
147+
148+
Returns:
149+
A flag holder object for accessing the value of the flag.
150+
"""
151+
name = name.lower()
152+
config._add_option(name, static_int_env(name.upper(), default))
153+
fh = FlagHolder[int](name, help)
154+
setattr(Config, name, property(lambda _: fh.value, doc=help))
155+
return fh
156+
157+
125158
def static_bool_env(varname: str, default: bool) -> bool:
126159
"""Read an environment variable and interpret it as a boolean.
127160
@@ -149,6 +182,27 @@ def static_bool_env(varname: str, default: bool) -> bool:
149182
)
150183

151184

185+
def static_int_env(varname: str, default: int | None) -> int | None:
186+
"""Read an environment variable and interpret it as an integer.
187+
188+
Args:
189+
varname: the name of the variable
190+
default: the default integer value
191+
Returns:
192+
integer return value derived from defaults and environment.
193+
Raises: ValueError if the environment variable is not an integer.
194+
"""
195+
val = os.getenv(varname)
196+
if val is None:
197+
return default
198+
try:
199+
return int(val)
200+
except ValueError:
201+
raise ValueError(
202+
f'invalid integer value {val!r} for environment {varname!r}'
203+
) from None
204+
205+
152206
@contextmanager
153207
def temp_flip_flag(var_name: str, var_value: bool):
154208
"""Context manager to temporarily flip feature flags for test functions.
@@ -211,4 +265,10 @@ def temp_flip_flag(var_name: str, var_value: bool):
211265
name='flax_use_flaxlib',
212266
default=False,
213267
help='Whether to use flaxlib for C++ acceleration.',
268+
)
269+
270+
flax_max_repr_depth = int_flag(
271+
name='flax_max_repr_depth',
272+
default=None,
273+
help='Maximum depth of reprs for nested flax objects. Default is None (no limit).',
214274
)

flax/nnx/reprlib.py

+49-36
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import sys
1818
import threading
1919
import typing as tp
20+
from flax import config as flax_config
2021

2122
A = tp.TypeVar('A')
2223
B = tp.TypeVar('B')
@@ -90,6 +91,7 @@ class Color(tp.NamedTuple):
9091
@dataclasses.dataclass
9192
class ReprContext(threading.local):
9293
current_color: Color = COLOR
94+
depth: int = 0
9395

9496

9597
REPR_CONTEXT = ReprContext()
@@ -172,51 +174,62 @@ def __str__(self) -> str:
172174

173175

174176
def get_repr(obj: Representable) -> str:
175-
if not isinstance(obj, Representable):
176-
raise TypeError(f'Object {obj!r} is not representable')
177-
178-
c = REPR_CONTEXT.current_color
179-
iterator = obj.__nnx_repr__()
180-
config = next(iterator)
177+
REPR_CONTEXT.depth += 1
178+
try:
179+
if not isinstance(obj, Representable):
180+
raise TypeError(f'Object {obj!r} is not representable')
181181

182-
if not isinstance(config, Object):
183-
raise TypeError(f'First item must be Config, got {type(config).__name__}')
182+
c = REPR_CONTEXT.current_color
183+
iterator = obj.__nnx_repr__()
184+
config = next(iterator)
184185

185-
kv_sep = f'{c.SEP}{config.kv_sep}{c.END}'
186+
if not isinstance(config, Object):
187+
raise TypeError(f'First item must be Config, got {type(config).__name__}')
186188

187-
def _repr_elem(elem: tp.Any) -> str:
188-
if not isinstance(elem, Attr):
189-
raise TypeError(f'Item must be Elem, got {type(elem).__name__}')
189+
kv_sep = f'{c.SEP}{config.kv_sep}{c.END}'
190190

191-
value_repr = elem.value if elem.use_raw_value else colorized(elem.value)
192-
value_repr = value_repr.replace('\n', '\n' + config.indent)
193-
key = elem.key if elem.use_raw_key else f'{c.ATTRIBUTE}{elem.key}{c.END}'
194-
indent = '' if config.same_line else config.indent
191+
def _repr_elem(elem: tp.Any) -> str:
192+
if not isinstance(elem, Attr):
193+
raise TypeError(f'Item must be Elem, got {type(elem).__name__}')
195194

196-
return f'{indent}{elem.start}{key}{kv_sep}{value_repr}{elem.end}'
195+
value_repr = elem.value if elem.use_raw_value else colorized(elem.value)
196+
value_repr = value_repr.replace('\n', '\n' + config.indent)
197+
key = elem.key if elem.use_raw_key else f'{c.ATTRIBUTE}{elem.key}{c.END}'
198+
indent = '' if config.same_line else config.indent
197199

198-
elems = config.elem_sep.join(map(_repr_elem, iterator))
200+
return f'{indent}{elem.start}{key}{kv_sep}{value_repr}{elem.end}'
199201

200-
if elems:
201-
if config.same_line:
202-
elems_repr = elems
203-
comment = ''
202+
max_depth_reached = (
203+
flax_config.flax_max_repr_depth is not None
204+
and REPR_CONTEXT.depth > flax_config.flax_max_repr_depth
205+
)
206+
if max_depth_reached:
207+
elems = '...'
204208
else:
205-
elems_repr = '\n' + elems + '\n'
206-
comment = f'{c.COMMENT}{config.comment}{c.END}'
207-
else:
208-
elems_repr = config.empty_repr
209-
comment = ''
210-
211-
type_repr = (
212-
config.type if isinstance(config.type, str) else config.type.__name__
213-
)
214-
type_repr = f'{c.TYPE}{type_repr}{c.END}' if type_repr else ''
215-
start = f'{c.PAREN}{config.start}{c.END}' if config.start else ''
216-
end = f'{c.PAREN}{config.end}{c.END}' if config.end else ''
209+
elems = config.elem_sep.join(map(_repr_elem, iterator))
210+
211+
if elems:
212+
if config.same_line or max_depth_reached:
213+
elems_repr = elems
214+
comment = ''
215+
else:
216+
elems_repr = '\n' + elems + '\n'
217+
comment = f'{c.COMMENT}{config.comment}{c.END}'
218+
else:
219+
elems_repr = config.empty_repr
220+
comment = ''
217221

218-
out = f'{type_repr}{start}{comment}{elems_repr}{end}'
219-
return out
222+
type_repr = (
223+
config.type if isinstance(config.type, str) else config.type.__name__
224+
)
225+
type_repr = f'{c.TYPE}{type_repr}{c.END}' if type_repr else ''
226+
start = f'{c.PAREN}{config.start}{c.END}' if config.start else ''
227+
end = f'{c.PAREN}{config.end}{c.END}' if config.end else ''
228+
229+
out = f'{type_repr}{start}{comment}{elems_repr}{end}'
230+
return out
231+
finally:
232+
REPR_CONTEXT.depth -= 1
220233

221234
class MappingReprMixin(Representable):
222235
def __nnx_repr__(self):

0 commit comments

Comments
 (0)