Skip to content

Commit

Permalink
Support module alias for deserialization.
Browse files Browse the repository at this point in the history
This allows users to load previous serialized objects after module rename.

PiperOrigin-RevId: 584729455
  • Loading branch information
daiyip authored and pyglove authors committed Nov 22, 2023
1 parent 121b5ef commit 536bde1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
22 changes: 21 additions & 1 deletion pyglove/core/object_utils/json_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self):
# class will always be picked up when there are multiple wrapper classes
# registered for a user class.
self._type_to_cls_map = dict()
self._prefix_mapping = dict()

def register(
self, type_name: str, cls: Type[Any], override_existing: bool = False
Expand All @@ -80,14 +81,28 @@ def register(
f'{self._type_to_cls_map[type_name].__name__}.')
self._type_to_cls_map[type_name] = cls

def add_module_alias(self, module: str, alias: str) -> None:
"""Maps a module name to another name. Usually due to rename."""
self._prefix_mapping[alias] = module

def is_registered(self, type_name: str) -> bool:
"""Returns whether a type name is registered."""
return type_name in self._type_to_cls_map

def class_from_typename(
self, type_name: str) -> Optional[Type[Any]]:
"""Get class from type name."""
return self._type_to_cls_map.get(type_name, None)
cls = self._type_to_cls_map.get(type_name, None)
if cls is None:
# Modules could be renamed, to load legacy serialized objects, we
# use prefix mapping to get to their latest registry.
for k, v in self._prefix_mapping.items():
if type_name.startswith(f'{k}.'):
remapped_type_name = type_name.replace(k, v)
cls = self._type_to_cls_map.get(remapped_type_name, None)
if cls is not None:
break
return cls

def iteritems(self) -> Iterable[Tuple[str, Type[Any]]]:
"""Iterate type registry."""
Expand Down Expand Up @@ -205,6 +220,11 @@ def register(
"""
cls._TYPE_REGISTRY.register(type_name, subclass, override_existing)

@classmethod
def add_module_alias(cls, source_name: str, target_name: str) -> None:
"""Adds a module alias so previous serialized objects could be loaded."""
cls._TYPE_REGISTRY.add_module_alias(source_name, target_name)

@classmethod
def is_registered(cls, type_name: str) -> bool:
"""Returns True if a type name is registered. Otherwise False."""
Expand Down
7 changes: 7 additions & 0 deletions pyglove/core/object_utils/json_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ def __ne__(self, other):
self.assertEqual(json_conversion.from_json(json_value),
[(T(), 2), {'y': T(3)}])

# Test module alias.
json_conversion.JSONConvertible.add_module_alias(T.__module__, 'mymodule')
self.assertEqual(
json_conversion.from_json({'_type': f'mymodule.{T.__qualname__}'}),
T()
)

# Test bad cases.
with self.assertRaisesRegex(
ValueError, 'Tuple should have at least one element besides .*'):
Expand Down

0 comments on commit 536bde1

Please sign in to comment.