Skip to content

Commit 84a4d99

Browse files
daiyippyglove authors
authored and
pyglove authors
committed
Bug fixes on default value inspection.
- `pg.Symbolic.sym_nondefault` to honor object-level default value vs. schema-level default value, and could work well with symbolic objects which do not use symbolic comparison. - `pg.format(hide_default_values=True)` to work with symbolic objects which do not use symbolic comparison. PiperOrigin-RevId: 584506769
1 parent 58b60c3 commit 84a4d99

File tree

2 files changed

+61
-21
lines changed

2 files changed

+61
-21
lines changed

pyglove/core/symbolic/dict.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -369,32 +369,50 @@ def _sym_missing(self) -> typing.Dict[str, Any]:
369369
def _sym_nondefault(self) -> typing.Dict[str, Any]:
370370
"""Returns non-default values as key/value pairs in a dict."""
371371
non_defaults = dict()
372-
if self._value_spec and self._value_spec.schema:
372+
if self._value_spec is not None and self._value_spec.schema:
373373
dict_schema = self._value_spec.schema
374-
matched_keys, unmatched_keys = dict_schema.resolve(self.keys())
375-
assert not unmatched_keys
374+
matched_keys, _ = dict_schema.resolve(self.keys())
376375
for key_spec, keys in matched_keys.items():
377376
value_spec = dict_schema[key_spec].value
378377
for key in keys:
379-
v = self.sym_getattr(key)
380-
child_has_non_defaults = False
381-
if isinstance(v, base.Symbolic):
382-
non_defaults_child = v.non_default_values(flatten=False)
383-
if non_defaults_child:
384-
non_defaults[key] = non_defaults_child
385-
child_has_non_defaults = True
386-
if not child_has_non_defaults and value_spec.default != v:
387-
non_defaults[key] = v
378+
diff = self._diff_base(self.sym_getattr(key), value_spec.default)
379+
if pg_typing.MISSING_VALUE != diff:
380+
non_defaults[key] = diff
388381
else:
389382
for k, v in self.sym_items():
390383
if isinstance(v, base.Symbolic):
391-
non_defaults_child = v.non_default_values(flatten=False)
384+
non_defaults_child = v.sym_nondefault(flatten=False)
392385
if non_defaults_child:
393386
non_defaults[k] = non_defaults_child
394387
else:
395388
non_defaults[k] = v
396389
return non_defaults
397390

391+
def _diff_base(self, value: Any, base_value: Any) -> Any:
392+
"""Computes the diff between a value and a base value."""
393+
if base.eq(value, base_value):
394+
return pg_typing.MISSING_VALUE
395+
396+
if (isinstance(value, list)
397+
or not isinstance(value, base.Symbolic)
398+
or pg_typing.MISSING_VALUE == base_value):
399+
return value
400+
401+
if value.__class__ is base_value.__class__:
402+
getter = lambda x, k: x.sym_getattr(k)
403+
elif isinstance(value, dict) and isinstance(base_value, dict):
404+
getter = lambda x, k: x[k]
405+
else:
406+
return value
407+
408+
diff = {}
409+
for k, v in value.sym_items():
410+
base_v = getter(base_value, k)
411+
child_diff = self._diff_base(v, base_v)
412+
if pg_typing.MISSING_VALUE != child_diff:
413+
diff[k] = child_diff
414+
return diff
415+
398416
def seal(self, sealed: bool = True) -> 'Dict':
399417
"""Seals or unseals current object from further modification."""
400418
if self.is_sealed == sealed:
@@ -796,7 +814,7 @@ def sym_jsonify(
796814
value = self.sym_getattr(key)
797815
if pg_typing.MISSING_VALUE == value:
798816
continue
799-
if hide_default_values and value == field.default_value:
817+
if hide_default_values and base.eq(value, field.default_value):
800818
continue
801819
json_repr[key] = base.to_json(
802820
value, hide_default_values=hide_default_values, **kwargs)
@@ -886,7 +904,7 @@ def _should_include_key(key):
886904
if pg_typing.MISSING_VALUE == v:
887905
if hide_missing_values:
888906
continue
889-
elif hide_default_values and v == field.default_value:
907+
elif hide_default_values and base.eq(v, field.default_value):
890908
continue
891909
field_list.append((field, key, v))
892910
else:

pyglove/core/symbolic/dict_test.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -988,16 +988,32 @@ def test_sym_missing(self):
988988
self.assertEqual(sd.sym_missing(), {})
989989

990990
def test_sym_nondefault(self):
991-
# Refer to `test_non_default_values` for more details.
992-
sd = Dict(x=1, value_spec=pg_typing.Dict([
991+
992+
class A(pg_object.Object):
993+
x: int
994+
use_symbolic_comparison = False
995+
996+
class B(pg_object.Object):
997+
y: int = 1
998+
use_symbolic_comparison = True
999+
1000+
sd = Dict(x=1, y=dict(a1=A(1)), value_spec=pg_typing.Dict([
9931001
('x', pg_typing.Int(default=0)),
9941002
('y', pg_typing.Dict([
995-
('z', pg_typing.Int(default=1))
1003+
('z', pg_typing.Int(default=1)),
1004+
('a1', pg_typing.Object(A)),
1005+
('a2', pg_typing.Object(A, default=A(1))),
1006+
('b', pg_typing.Object(B, default=B(2))),
9961007
])),
9971008
]))
998-
self.assertEqual(sd.sym_nondefault(), {'x': 1})
999-
sd.rebind({'y.z': 2}, x=0)
1000-
self.assertEqual(sd.sym_nondefault(), {'y.z': 2})
1009+
self.assertTrue(base.eq(sd.sym_nondefault(), {'x': 1, 'y.a1': A(1)}))
1010+
sd.rebind({'y.z': 2, 'y.a2': A(2), 'y.b': B(1)}, x=0)
1011+
self.assertTrue(
1012+
base.eq(
1013+
sd.sym_nondefault(),
1014+
{'y.z': 2, 'y.a1': A(1), 'y.a2.x': 2, 'y.b.y': 1}
1015+
)
1016+
)
10011017

10021018
# Test inferred value as the default value.
10031019
sd = Dict(
@@ -1892,9 +1908,15 @@ def __eq__(self, other):
18921908
self.assertEqual(base.from_json_str(sd.to_json_str(), value_spec=spec), sd)
18931909

18941910
def test_hide_default_values(self):
1911+
1912+
class A(pg_object.Object):
1913+
x: int = 1
1914+
use_symbolic_comparison = False
1915+
18951916
sd = Dict.partial(
18961917
x=1,
18971918
value_spec=pg_typing.Dict([
1919+
('v', pg_typing.Object(A, default=A(1))),
18981920
('w', pg_typing.Str()),
18991921
('x', pg_typing.Int()),
19001922
('y', pg_typing.Str().noneable()),

0 commit comments

Comments
 (0)