Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: to/from json for extension/package #1575

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion hugr-py/src/hugr/_serialization/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pydantic as pd
from pydantic_extra_types.semantic_version import SemanticVersion # noqa: TCH002

from hugr.hugr.base import Hugr
from hugr.utils import deser_it

from .ops import Value
Expand Down Expand Up @@ -156,5 +157,14 @@ class Package(ConfiguredBaseModel):
def get_version(cls) -> str:
return serialization_version()

def deserialize(self) -> package.Package:
return package.Package(
modules=[Hugr._from_serial(m) for m in self.modules],
extensions=[e.deserialize() for e in self.extensions],
)


from hugr import ext # noqa: E402
from hugr import ( # noqa: E402
ext,
package,
)
16 changes: 16 additions & 0 deletions hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,22 @@ def _to_serial(self) -> ext_s.Extension:
operations={k: v._to_serial() for k, v in self.operations.items()},
)

def to_json(self) -> str:
"""Serialize the extension to a JSON string."""
return self._to_serial().model_dump_json()

@classmethod
def from_json(cls, json_str: str) -> Extension:
"""Deserialize a JSON string to a Extension object.
Args:
json_str: The JSON string representing a Extension.
Returns:
The deserialized Extension object.
"""
return ext_s.Extension.model_validate_json(json_str).deserialize()

def add_op_def(self, op_def: OpDef) -> OpDef:
"""Add an operation definition to the extension.
Expand Down
12 changes: 12 additions & 0 deletions hugr-py/src/hugr/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ def _to_serial(self) -> ext_s.Package:
def to_json(self) -> str:
return self._to_serial().model_dump_json()

@classmethod
def from_json(cls, json_str: str) -> Package:
"""Deserialize a JSON string to a Package object.
Args:
json_str: The JSON string representing a Package.
Returns:
The deserialized Package object.
"""
return ext_s.Package.model_validate_json(json_str).deserialize()


@dataclass(frozen=True)
class PackagePointer:
Expand Down
10 changes: 9 additions & 1 deletion hugr-py/tests/serialization/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
TypeDef,
TypeDefBound,
)
from hugr._serialization.ops import Module, OpType
from hugr._serialization.serial_hugr import SerialHugr, serialization_version
from hugr._serialization.tys import (
FunctionType,
Expand Down Expand Up @@ -109,6 +110,8 @@ def test_extension():
dumped_json = ext.model_dump_json()

assert Extension.model_validate_json(dumped_json) == ext
hugr_ext = ext.deserialize()
assert hugr_ext.from_json(hugr_ext.to_json()) == hugr_ext


def test_package():
Expand All @@ -123,9 +126,14 @@ def test_package():
operations={},
)
ext_load = Extension.model_validate_json(EXAMPLE)

package = Package(
extensions=[ext, ext_load], modules=[SerialHugr(nodes=[], edges=[])]
extensions=[ext, ext_load],
modules=[SerialHugr(nodes=[OpType(root=Module(parent=0))], edges=[])],
)

package_load = Package.model_validate_json(package.model_dump_json())
assert package == package_load

hugr_package = package.deserialize()
assert hugr_package.from_json(hugr_package.to_json()) == hugr_package
Loading