diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 291f7e5c8..6420bffff 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -5,6 +5,7 @@ import pydantic as pd from pydantic_extra_types.semantic_version import SemanticVersion # noqa: TCH002 +from hugr.hugr.base import Hugr from hugr.utils import deser_it from .ops import Value @@ -156,5 +157,14 @@ class Package(ConfiguredBaseModel): def get_version(cls) -> str: return serialization_version() + def deserialize(self) -> package.Package: + return package.Package( + modules=[Hugr._from_serial(m) for m in self.modules], + extensions=[e.deserialize() for e in self.extensions], + ) + -from hugr import ext # noqa: E402 +from hugr import ( # noqa: E402 + ext, + package, +) diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 6bed102d6..3910b7ae7 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -272,6 +272,22 @@ def _to_serial(self) -> ext_s.Extension: operations={k: v._to_serial() for k, v in self.operations.items()}, ) + def to_json(self) -> str: + """Serialize the extension to a JSON string.""" + return self._to_serial().model_dump_json() + + @classmethod + def from_json(cls, json_str: str) -> Extension: + """Deserialize a JSON string to a Extension object. + + Args: + json_str: The JSON string representing a Extension. + + Returns: + The deserialized Extension object. + """ + return ext_s.Extension.model_validate_json(json_str).deserialize() + def add_op_def(self, op_def: OpDef) -> OpDef: """Add an operation definition to the extension. diff --git a/hugr-py/src/hugr/package.py b/hugr-py/src/hugr/package.py index f7c5a9fc7..f9cd36a3b 100644 --- a/hugr-py/src/hugr/package.py +++ b/hugr-py/src/hugr/package.py @@ -46,6 +46,18 @@ def _to_serial(self) -> ext_s.Package: def to_json(self) -> str: return self._to_serial().model_dump_json() + @classmethod + def from_json(cls, json_str: str) -> Package: + """Deserialize a JSON string to a Package object. + + Args: + json_str: The JSON string representing a Package. + + Returns: + The deserialized Package object. + """ + return ext_s.Package.model_validate_json(json_str).deserialize() + @dataclass(frozen=True) class PackagePointer: diff --git a/hugr-py/tests/serialization/test_extension.py b/hugr-py/tests/serialization/test_extension.py index bd08055ec..4b105c393 100644 --- a/hugr-py/tests/serialization/test_extension.py +++ b/hugr-py/tests/serialization/test_extension.py @@ -8,6 +8,7 @@ TypeDef, TypeDefBound, ) +from hugr._serialization.ops import Module, OpType from hugr._serialization.serial_hugr import SerialHugr, serialization_version from hugr._serialization.tys import ( FunctionType, @@ -109,6 +110,8 @@ def test_extension(): dumped_json = ext.model_dump_json() assert Extension.model_validate_json(dumped_json) == ext + hugr_ext = ext.deserialize() + assert hugr_ext.from_json(hugr_ext.to_json()) == hugr_ext def test_package(): @@ -123,9 +126,14 @@ def test_package(): operations={}, ) ext_load = Extension.model_validate_json(EXAMPLE) + package = Package( - extensions=[ext, ext_load], modules=[SerialHugr(nodes=[], edges=[])] + extensions=[ext, ext_load], + modules=[SerialHugr(nodes=[OpType(root=Module(parent=0))], edges=[])], ) package_load = Package.model_validate_json(package.model_dump_json()) assert package == package_load + + hugr_package = package.deserialize() + assert hugr_package.from_json(hugr_package.to_json()) == hugr_package