-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat!: define wrappers around package that point to internals (#1573)
Closes #1561 So far just top level things, can add inner ones (cfg, blocks, etc.) later. BREAKING CHANGE: `Package` moved to new `hugr.package` module
- Loading branch information
Showing
7 changed files
with
232 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
"""HUGR package and pointed package interfaces.""" | ||
|
||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass, field | ||
from typing import TYPE_CHECKING, Generic, TypeVar, cast | ||
|
||
import hugr._serialization.extension as ext_s | ||
from hugr.ops import FuncDecl, FuncDefn, Op | ||
|
||
if TYPE_CHECKING: | ||
from hugr.ext import Extension | ||
from hugr.hugr.base import Hugr | ||
from hugr.hugr.node_port import Node | ||
|
||
__all__ = [ | ||
"Package", | ||
"PackagePointer", | ||
"ModulePointer", | ||
"ExtensionPointer", | ||
"NodePointer", | ||
"FuncDeclPointer", | ||
"FuncDefnPointer", | ||
] | ||
|
||
|
||
@dataclass(frozen=True) | ||
class Package: | ||
"""A package of HUGR modules and extensions. | ||
The HUGRs may refer to the included extensions or those not included. | ||
""" | ||
|
||
#: HUGR modules in the package. | ||
modules: list[Hugr] | ||
#: Extensions included in the package. | ||
extensions: list[Extension] = field(default_factory=list) | ||
|
||
def _to_serial(self) -> ext_s.Package: | ||
return ext_s.Package( | ||
modules=[m._to_serial() for m in self.modules], | ||
extensions=[e._to_serial() for e in self.extensions], | ||
) | ||
|
||
def to_json(self) -> str: | ||
return self._to_serial().model_dump_json() | ||
|
||
|
||
@dataclass(frozen=True) | ||
class PackagePointer: | ||
"""Classes that point to packages and their inner contents.""" | ||
|
||
#: Package pointed to. | ||
package: Package | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ModulePointer(PackagePointer): | ||
"""Pointer to a module in a package. | ||
Args: | ||
package: Package pointed to. | ||
module_index: Index of the module in the package. | ||
""" | ||
|
||
#: Index of the module in the package. | ||
module_index: int | ||
|
||
@property | ||
def module(self) -> Hugr: | ||
"""Hugr definition of the module.""" | ||
return self.package.modules[self.module_index] | ||
|
||
def to_executable_package(self) -> ExecutablePackage: | ||
"""Create an executable package from a module containing a main function. | ||
Raises: | ||
ValueError: If the module does not contain a main function. | ||
""" | ||
module = self.module | ||
try: | ||
main_node = next( | ||
n | ||
for n in module.children() | ||
if isinstance((f_def := module[n].op), FuncDefn) | ||
and f_def.f_name == "main" | ||
) | ||
except StopIteration as e: | ||
msg = "Module does not contain a main function" | ||
raise ValueError(msg) from e | ||
return ExecutablePackage(self.package, self.module_index, main_node) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ExtensionPointer(PackagePointer): | ||
"""Pointer to an extension in a package. | ||
Args: | ||
package: Package pointed to. | ||
extension_index: Index of the extension in the package. | ||
""" | ||
|
||
#: Index of the extension in the package. | ||
extension_index: int | ||
|
||
@property | ||
def extension(self) -> Extension: | ||
"""Extension definition.""" | ||
return self.package.extensions[self.extension_index] | ||
|
||
|
||
OpType = TypeVar("OpType", bound=Op) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class NodePointer(Generic[OpType], ModulePointer): | ||
"""Pointer to a node in a module. | ||
Args: | ||
package: Package pointed to. | ||
module_index: Index of the module in the package. | ||
node: Node pointed to | ||
""" | ||
|
||
#: Node pointed to. | ||
node: Node | ||
|
||
@property | ||
def node_op(self) -> OpType: | ||
"""Get the operation of the node.""" | ||
return cast(OpType, self.module[self.node].op) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class FuncDeclPointer(NodePointer[FuncDecl]): | ||
"""Pointer to a function declaration in a module. | ||
Args: | ||
package: Package pointed to. | ||
module_index: Index of the module in the package. | ||
node: Node containing the function declaration. | ||
""" | ||
|
||
@property | ||
def func_decl(self) -> FuncDecl: | ||
"""Function declaration.""" | ||
return self.node_op | ||
|
||
|
||
@dataclass(frozen=True) | ||
class FuncDefnPointer(NodePointer[FuncDefn]): | ||
"""Pointer to a function definition in a module. | ||
Args: | ||
package: Package pointed to. | ||
module_index: Index of the module in the package. | ||
node: Node containing the function definition | ||
""" | ||
|
||
@property | ||
def func_defn(self) -> FuncDefn: | ||
"""Function definition.""" | ||
return self.node_op | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ExecutablePackage(FuncDefnPointer): | ||
"""PackagePointer with a defined entrypoint node. | ||
Args: | ||
package: Package pointed to. | ||
module_index: Index of the module in the package. | ||
node: Node containing the entry point function definition. | ||
""" | ||
|
||
@property | ||
def entry_point_node(self) -> Node: | ||
"""Get the entry point node of the package.""" | ||
return self.node |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from hugr import tys | ||
from hugr.build.function import Module | ||
from hugr.package import ( | ||
FuncDeclPointer, | ||
FuncDefnPointer, | ||
ModulePointer, | ||
Package, | ||
PackagePointer, | ||
) | ||
|
||
from .conftest import validate | ||
|
||
|
||
def test_package(): | ||
mod = Module() | ||
f_id = mod.define_function("id", [tys.Qubit]) | ||
f_id.set_outputs(f_id.input_node[0]) | ||
|
||
mod2 = Module() | ||
f_id_decl = mod2.declare_function( | ||
"id", tys.PolyFuncType([], tys.FunctionType([tys.Qubit], [tys.Qubit])) | ||
) | ||
f_main = mod2.define_main([tys.Qubit]) | ||
q = f_main.input_node[0] | ||
call = f_main.call(f_id_decl, q) | ||
f_main.set_outputs(call) | ||
|
||
package = Package([mod.hugr, mod2.hugr]) | ||
validate(package) | ||
|
||
p = PackagePointer(package) | ||
assert p.package == package | ||
|
||
m = ModulePointer(package, 1) | ||
assert m.module == mod2.hugr | ||
|
||
f = FuncDeclPointer(package, 1, f_id_decl) | ||
assert f.func_decl == mod2.hugr[f_id_decl].op | ||
|
||
f = FuncDefnPointer(package, 0, f_id.to_node()) | ||
|
||
assert f.func_defn == mod.hugr[f_id.to_node()].op | ||
|
||
main = m.to_executable_package() | ||
assert main.entry_point_node == f_main.to_node() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters