|
17 | 17 | import sys
|
18 | 18 | import threading
|
19 | 19 | import typing as tp
|
| 20 | +from flax import config as flax_config |
20 | 21 |
|
21 | 22 | A = tp.TypeVar('A')
|
22 | 23 | B = tp.TypeVar('B')
|
@@ -90,6 +91,7 @@ class Color(tp.NamedTuple):
|
90 | 91 | @dataclasses.dataclass
|
91 | 92 | class ReprContext(threading.local):
|
92 | 93 | current_color: Color = COLOR
|
| 94 | + depth: int = 0 |
93 | 95 |
|
94 | 96 |
|
95 | 97 | REPR_CONTEXT = ReprContext()
|
@@ -172,51 +174,62 @@ def __str__(self) -> str:
|
172 | 174 |
|
173 | 175 |
|
174 | 176 | def get_repr(obj: Representable) -> str:
|
175 |
| - if not isinstance(obj, Representable): |
176 |
| - raise TypeError(f'Object {obj!r} is not representable') |
177 |
| - |
178 |
| - c = REPR_CONTEXT.current_color |
179 |
| - iterator = obj.__nnx_repr__() |
180 |
| - config = next(iterator) |
| 177 | + REPR_CONTEXT.depth += 1 |
| 178 | + try: |
| 179 | + if not isinstance(obj, Representable): |
| 180 | + raise TypeError(f'Object {obj!r} is not representable') |
181 | 181 |
|
182 |
| - if not isinstance(config, Object): |
183 |
| - raise TypeError(f'First item must be Config, got {type(config).__name__}') |
| 182 | + c = REPR_CONTEXT.current_color |
| 183 | + iterator = obj.__nnx_repr__() |
| 184 | + config = next(iterator) |
184 | 185 |
|
185 |
| - kv_sep = f'{c.SEP}{config.kv_sep}{c.END}' |
| 186 | + if not isinstance(config, Object): |
| 187 | + raise TypeError(f'First item must be Config, got {type(config).__name__}') |
186 | 188 |
|
187 |
| - def _repr_elem(elem: tp.Any) -> str: |
188 |
| - if not isinstance(elem, Attr): |
189 |
| - raise TypeError(f'Item must be Elem, got {type(elem).__name__}') |
| 189 | + kv_sep = f'{c.SEP}{config.kv_sep}{c.END}' |
190 | 190 |
|
191 |
| - value_repr = elem.value if elem.use_raw_value else colorized(elem.value) |
192 |
| - value_repr = value_repr.replace('\n', '\n' + config.indent) |
193 |
| - key = elem.key if elem.use_raw_key else f'{c.ATTRIBUTE}{elem.key}{c.END}' |
194 |
| - indent = '' if config.same_line else config.indent |
| 191 | + def _repr_elem(elem: tp.Any) -> str: |
| 192 | + if not isinstance(elem, Attr): |
| 193 | + raise TypeError(f'Item must be Elem, got {type(elem).__name__}') |
195 | 194 |
|
196 |
| - return f'{indent}{elem.start}{key}{kv_sep}{value_repr}{elem.end}' |
| 195 | + value_repr = elem.value if elem.use_raw_value else colorized(elem.value) |
| 196 | + value_repr = value_repr.replace('\n', '\n' + config.indent) |
| 197 | + key = elem.key if elem.use_raw_key else f'{c.ATTRIBUTE}{elem.key}{c.END}' |
| 198 | + indent = '' if config.same_line else config.indent |
197 | 199 |
|
198 |
| - elems = config.elem_sep.join(map(_repr_elem, iterator)) |
| 200 | + return f'{indent}{elem.start}{key}{kv_sep}{value_repr}{elem.end}' |
199 | 201 |
|
200 |
| - if elems: |
201 |
| - if config.same_line: |
202 |
| - elems_repr = elems |
203 |
| - comment = '' |
| 202 | + max_depth_reached = ( |
| 203 | + flax_config.flax_max_repr_depth is not None |
| 204 | + and REPR_CONTEXT.depth > flax_config.flax_max_repr_depth |
| 205 | + ) |
| 206 | + if max_depth_reached: |
| 207 | + elems = '...' |
204 | 208 | else:
|
205 |
| - elems_repr = '\n' + elems + '\n' |
206 |
| - comment = f'{c.COMMENT}{config.comment}{c.END}' |
207 |
| - else: |
208 |
| - elems_repr = config.empty_repr |
209 |
| - comment = '' |
210 |
| - |
211 |
| - type_repr = ( |
212 |
| - config.type if isinstance(config.type, str) else config.type.__name__ |
213 |
| - ) |
214 |
| - type_repr = f'{c.TYPE}{type_repr}{c.END}' if type_repr else '' |
215 |
| - start = f'{c.PAREN}{config.start}{c.END}' if config.start else '' |
216 |
| - end = f'{c.PAREN}{config.end}{c.END}' if config.end else '' |
| 209 | + elems = config.elem_sep.join(map(_repr_elem, iterator)) |
| 210 | + |
| 211 | + if elems: |
| 212 | + if config.same_line or max_depth_reached: |
| 213 | + elems_repr = elems |
| 214 | + comment = '' |
| 215 | + else: |
| 216 | + elems_repr = '\n' + elems + '\n' |
| 217 | + comment = f'{c.COMMENT}{config.comment}{c.END}' |
| 218 | + else: |
| 219 | + elems_repr = config.empty_repr |
| 220 | + comment = '' |
217 | 221 |
|
218 |
| - out = f'{type_repr}{start}{comment}{elems_repr}{end}' |
219 |
| - return out |
| 222 | + type_repr = ( |
| 223 | + config.type if isinstance(config.type, str) else config.type.__name__ |
| 224 | + ) |
| 225 | + type_repr = f'{c.TYPE}{type_repr}{c.END}' if type_repr else '' |
| 226 | + start = f'{c.PAREN}{config.start}{c.END}' if config.start else '' |
| 227 | + end = f'{c.PAREN}{config.end}{c.END}' if config.end else '' |
| 228 | + |
| 229 | + out = f'{type_repr}{start}{comment}{elems_repr}{end}' |
| 230 | + return out |
| 231 | + finally: |
| 232 | + REPR_CONTEXT.depth -= 1 |
220 | 233 |
|
221 | 234 | class MappingReprMixin(Representable):
|
222 | 235 | def __nnx_repr__(self):
|
|
0 commit comments