Skip to content

Commit 0a2ea67

Browse files
daiyippyglove authors
authored and
pyglove authors
committed
Introduces context manager pg.str_format and pg.repr_format.
These two (thread_local) context managers are used for controlling the behavior of __str__ and __repr__ for `pg.Formattable` objects. PiperOrigin-RevId: 590276970
1 parent a4dacb5 commit 0a2ea67

File tree

5 files changed

+138
-13
lines changed

5 files changed

+138
-13
lines changed

pyglove/core/object_utils/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
"""
7171
# pylint: enable=line-too-long
7272
# pylint: disable=g-bad-import-order
73+
# pylint: disable=g-importing-member
7374

7475
# Common traits.
7576
from pyglove.core.object_utils.json_conversion import Nestable
@@ -116,6 +117,10 @@
116117
from pyglove.core.object_utils.formatting import BracketType
117118
from pyglove.core.object_utils.formatting import bracket_chars
118119

120+
# Context managers for defining the default format for __str__ and __repr__.
121+
from pyglove.core.object_utils.common_traits import str_format
122+
from pyglove.core.object_utils.common_traits import repr_format
123+
119124
# Handling code generation.
120125
from pyglove.core.object_utils.codegen import make_function
121126

@@ -144,4 +149,5 @@
144149
from pyglove.core.object_utils.error_utils import catch_errors
145150
from pyglove.core.object_utils.error_utils import CatchErrorsContext
146151

152+
# pylint: enable=g-importing-member
147153
# pylint: enable=g-bad-import-order

pyglove/core/object_utils/common_traits.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,22 @@
1919
"""
2020

2121
import abc
22-
from typing import Any, Dict, Optional
22+
from typing import Any, ContextManager, Dict, Optional
23+
from pyglove.core.object_utils import thread_local
24+
25+
26+
_TLS_STR_FORMAT_KWARGS = '_str_format_kwargs'
27+
_TLS_REPR_FORMAT_KWARGS = '_repr_format_kwargs'
28+
29+
30+
def str_format(**kwargs) -> ContextManager[Dict[str, Any]]:
31+
"""Context manager for setting the default format kwargs for __str__."""
32+
return thread_local.thread_local_arg_scope(_TLS_STR_FORMAT_KWARGS, **kwargs)
33+
34+
35+
def repr_format(**kwargs) -> ContextManager[Dict[str, Any]]:
36+
"""Context manager for setting the default format kwargs for __repr__."""
37+
return thread_local.thread_local_arg_scope(_TLS_REPR_FORMAT_KWARGS, **kwargs)
2338

2439

2540
class Formattable(metaclass=abc.ABCMeta):
@@ -59,11 +74,15 @@ def format(self,
5974

6075
def __str__(self) -> str:
6176
"""Returns the full (maybe multi-line) representation of this object."""
62-
return self.format(**self.__str_format_kwargs__)
77+
kwargs = dict(self.__str_format_kwargs__)
78+
kwargs.update(thread_local.thread_local_kwargs(_TLS_STR_FORMAT_KWARGS))
79+
return self.format(**kwargs)
6380

6481
def __repr__(self) -> str:
6582
"""Returns a single-line representation of this object."""
66-
return self.format(**self.__repr_format_kwargs__)
83+
kwargs = dict(self.__repr_format_kwargs__)
84+
kwargs.update(thread_local.thread_local_kwargs(_TLS_REPR_FORMAT_KWARGS))
85+
return self.format(**kwargs)
6786

6887

6988
class MaybePartial(metaclass=abc.ABCMeta):
@@ -156,4 +175,3 @@ def ensure_explicit_method_override(
156175
f'{method} is a PyGlove managed method. If you do need to override '
157176
'it, please decorate the method with `@pg.explicit_method_override`.')
158177
raise TypeError(error_message)
159-

pyglove/core/object_utils/common_traits_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,36 @@
1717
from pyglove.core.object_utils import common_traits
1818

1919

20+
class Foo(common_traits.Formattable):
21+
22+
def format(self, compact: bool = False, verbose: bool = True, **kwargs):
23+
return f'{self.__class__.__name__}(compact={compact}, verbose={verbose})'
24+
25+
26+
class FormattableTest(unittest.TestCase):
27+
28+
def test_formattable(self):
29+
foo = Foo()
30+
self.assertEqual(repr(foo), 'Foo(compact=True, verbose=True)')
31+
self.assertEqual(str(foo), 'Foo(compact=False, verbose=True)')
32+
33+
def test_formattable_with_custom_format(self):
34+
class Bar(Foo):
35+
__str_format_kwargs__ = {'compact': False, 'verbose': False}
36+
__repr_format_kwargs__ = {'compact': True, 'verbose': False}
37+
38+
bar = Bar()
39+
self.assertEqual(repr(bar), 'Bar(compact=True, verbose=False)')
40+
self.assertEqual(str(bar), 'Bar(compact=False, verbose=False)')
41+
42+
def test_formattable_with_context_managers(self):
43+
foo = Foo()
44+
with common_traits.str_format(verbose=False):
45+
with common_traits.repr_format(compact=False):
46+
self.assertEqual(repr(foo), 'Foo(compact=False, verbose=True)')
47+
self.assertEqual(str(foo), 'Foo(compact=False, verbose=False)')
48+
49+
2050
class ExplicitlyOverrideTest(unittest.TestCase):
2151

2252
def test_explicitly_override(self):

pyglove/core/object_utils/thread_local.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515

1616
import contextlib
1717
import threading
18-
from typing import Any, Callable, Iterator
18+
from typing import Any, Callable, Dict, Iterator
1919

20-
from pyglove.core.object_utils.missing import MISSING_VALUE
21-
22-
_RAISE_IF_NOT_FOUND = (MISSING_VALUE,)
20+
_MISSING = KeyError()
21+
_RAISE_IF_NOT_FOUND = ValueError()
2322
_thread_local_state = threading.local()
2423

2524

@@ -40,6 +39,25 @@ def thread_local_value_scope(
4039
thread_local_del(key)
4140

4241

42+
@contextlib.contextmanager
43+
def thread_local_arg_scope(key: str, **kwargs) -> Iterator[Dict[str, Any]]:
44+
"""Context manager to update args associated with key."""
45+
previous_kwargs = thread_local_peek(key, {})
46+
current_kwargs = previous_kwargs.copy()
47+
current_kwargs.update(kwargs)
48+
49+
try:
50+
thread_local_push(key, current_kwargs)
51+
yield current_kwargs
52+
finally:
53+
thread_local_pop(key)
54+
55+
56+
def thread_local_kwargs(key: str) -> Dict[str, Any]:
57+
"""Returns the args associated with key in current thread."""
58+
return thread_local_peek(key, {})
59+
60+
4361
def thread_local_has(key: str) -> bool:
4462
"""Deletes thread-local value by key."""
4563
return hasattr(_thread_local_state, key)
@@ -69,8 +87,8 @@ def thread_local_map(
6987
value_fn: Callable[[Any], Any],
7088
default_initial_value: Any = _RAISE_IF_NOT_FOUND) -> Any:
7189
"""Map a thread-local value."""
72-
value = thread_local_get(key, MISSING_VALUE)
73-
if value == MISSING_VALUE:
90+
value = thread_local_get(key, _MISSING)
91+
if value is _MISSING:
7492
value = default_initial_value
7593
if value is _RAISE_IF_NOT_FOUND:
7694
raise ValueError(f'Key {key!r} does not exist in thread-local storage.')
@@ -112,10 +130,25 @@ def thread_local_push(key: str, value: Any) -> None:
112130
)
113131

114132

133+
def thread_local_peek(
134+
key: str, default_value: Any = _RAISE_IF_NOT_FOUND
135+
) -> Any:
136+
"""Peaks a value at stack top."""
137+
stack = thread_local_get(key, _MISSING)
138+
if stack is _MISSING or not stack:
139+
if default_value is _RAISE_IF_NOT_FOUND:
140+
raise ValueError(
141+
f'Stack associated with key {key!r} does not exist in thread-local '
142+
'storage or is empty.'
143+
)
144+
return default_value
145+
return stack[-1]
146+
147+
115148
def thread_local_pop(key: str, default_value: Any = _RAISE_IF_NOT_FOUND) -> Any:
116149
"""Pops a value from a stack identified by key."""
117-
stack = thread_local_get(key, MISSING_VALUE)
118-
if stack == MISSING_VALUE:
150+
stack = thread_local_get(key, _MISSING)
151+
if stack is _MISSING:
119152
if default_value is _RAISE_IF_NOT_FOUND:
120153
raise ValueError(f'Key {key!r} does not exist in thread-local storage.')
121154
return default_value

pyglove/core/object_utils/thread_local_test.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,33 @@ def _fn():
8484
return _fn
8585
self.assert_thread_func([thread_fun(i) for i in range(5)], 2)
8686

87+
def test_thread_local_arg_scope(self):
88+
with thread_local.thread_local_arg_scope('arg_scope', x=1, y=2):
89+
self.assertEqual(
90+
thread_local.thread_local_kwargs('arg_scope'), dict(x=1, y=2)
91+
)
92+
with thread_local.thread_local_arg_scope('arg_scope', y=3, z=4):
93+
self.assertEqual(
94+
thread_local.thread_local_kwargs('arg_scope'), dict(x=1, y=3, z=4)
95+
)
96+
self.assertEqual(
97+
thread_local.thread_local_kwargs('arg_scope'), dict(x=1, y=2)
98+
)
99+
self.assertEqual(thread_local.thread_local_kwargs('arg_scope'), dict())
100+
101+
# Test thread locality.
102+
def thread_fun(i):
103+
def _fn():
104+
with thread_local.thread_local_arg_scope('arg_scope', x=i):
105+
self.assertEqual(
106+
thread_local.thread_local_kwargs('arg_scope'), dict(x=i)
107+
)
108+
self.assertEqual(thread_local.thread_local_kwargs('arg_scope'), dict())
109+
110+
return _fn
111+
112+
self.assert_thread_func([thread_fun(i) for i in range(5)], 2)
113+
87114
def test_thread_local_increment_decrement(self):
88115
k = 'z'
89116
self.assertEqual(thread_local.thread_local_increment(k, 5), 6)
@@ -103,18 +130,29 @@ def _fn():
103130
return _fn
104131
self.assert_thread_func([thread_fun(i) for i in range(5)], 2)
105132

106-
def test_thread_local_push_pop(self):
133+
def test_thread_local_push_peak_pop(self):
107134
k = 'p'
108135
self.assertFalse(thread_local.thread_local_has(k))
109136
thread_local.thread_local_push(k, 1)
110137
self.assertEqual(thread_local.thread_local_get(k), [1])
138+
self.assertEqual(thread_local.thread_local_peek(k), 1)
111139
thread_local.thread_local_push(k, 2)
112140
self.assertEqual(thread_local.thread_local_get(k), [1, 2])
141+
self.assertEqual(thread_local.thread_local_peek(k), 2)
113142
self.assertEqual(thread_local.thread_local_pop(k), 2)
114143
self.assertEqual(thread_local.thread_local_get(k), [1])
144+
self.assertEqual(thread_local.thread_local_peek(k), 1)
115145
self.assertEqual(thread_local.thread_local_pop(k), 1)
146+
147+
with self.assertRaisesRegex(
148+
ValueError, 'Stack associated with key .* does not exist'
149+
):
150+
thread_local.thread_local_peek(k)
151+
self.assertEqual(thread_local.thread_local_peek(k, -1), -1)
152+
116153
with self.assertRaisesRegex(IndexError, 'pop from empty list'):
117154
thread_local.thread_local_pop(k)
155+
118156
self.assertEqual(thread_local.thread_local_pop(k, -1), -1)
119157
with self.assertRaisesRegex(
120158
ValueError, 'Key .* does not exist in thread-local storage'):

0 commit comments

Comments
 (0)