Skip to content

Commit

Permalink
fix(py): allow conditional cases to be defined out of order (#1599)
Browse files Browse the repository at this point in the history
Closes #1596

Was tempted to just change the type of `cases` since who would be using
it anyway but I've tried to be good and deprecate instead
  • Loading branch information
ss2165 authored Oct 23, 2024
1 parent e04fcc5 commit 583d21d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 15 deletions.
55 changes: 40 additions & 15 deletions hugr-py/src/hugr/build/cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from contextlib import AbstractContextManager
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from typing_extensions import Self
Expand All @@ -19,6 +19,7 @@
if TYPE_CHECKING:
from hugr.hugr.node_port import Node, ToNode, Wire
from hugr.tys import TypeRow
import warnings


class Case(DfBase[ops.Case]):
Expand Down Expand Up @@ -104,8 +105,8 @@ class Conditional(ParentBuilder[ops.Conditional], AbstractContextManager):
Conditional(sum_ty=Bool, other_inputs=[Qubit])
"""

#: map from case index to node holding the :class:`Case <hugr.ops.Case>`
cases: dict[int, Node | None]
#: builders for each case and whether they have been built by the user yet
_case_builders: list[tuple[Case, bool]] = field(default_factory=list)

def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None:
root_op = ops.Conditional(sum_ty, other_inputs)
Expand All @@ -115,13 +116,40 @@ def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None:
def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None:
self.hugr = hugr
self.parent_node = root
self.cases = {i: None for i in range(n_cases)}
self._case_builders = []

for case_id in range(n_cases):
new_case = Case.new_nested(
ops.Case(self.parent_op.nth_inputs(case_id)),
self.hugr,
self.parent_node,
)
new_case._parent_cond = self
self._case_builders.append((new_case, False))

@property
def cases(self) -> dict[int, Node | None]:
"""Map from case index to node holding the :class:`Case <hugr.ops.Case>`.
DEPRECATED
"""
# TODO remove in 0.10
warnings.warn(
"The 'cases' property is deprecated and"
" will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)
return {
i: case.parent_node if b else None
for i, (case, b) in enumerate(self._case_builders)
}

def __enter__(self) -> Self:
return self

def __exit__(self, *args) -> None:
if any(c is None for c in self.cases.values()):
if not all(built for _, built in self._case_builders):
msg = "All cases must be added before exiting context."
raise ConditionalError(msg)
return None
Expand Down Expand Up @@ -185,18 +213,15 @@ def add_case(self, case_id: int) -> Case:
>>> with cond.add_case(0) as case:\
case.set_outputs(*case.inputs())
"""
if case_id not in self.cases:
if case_id >= len(self._case_builders):
msg = f"Case {case_id} out of possible range."
raise ConditionalError(msg)
input_types = self.parent_op.nth_inputs(case_id)
new_case = Case.new_nested(
ops.Case(input_types),
self.hugr,
self.parent_node,
)
new_case._parent_cond = self
self.cases[case_id] = new_case.parent_node
return new_case
case, built = self._case_builders[case_id]
if built:
msg = f"Case {case_id} already built."
raise ConditionalError(msg)
self._case_builders[case_id] = (case, True)
return case

# TODO insert_case

Expand Down
10 changes: 10 additions & 0 deletions hugr-py/tests/test_cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,13 @@ def test_complex_tail_loop() -> None:
h.set_outputs(*tl[:3])

validate(h.hugr)


def test_conditional_bug() -> None:
# bug with case ordering https://github.com/CQCL/hugr/issues/1596
cond = Conditional(tys.Either([tys.USize()], [tys.Unit]), [])
with cond.add_case(1) as case:
case.set_outputs()
with cond.add_case(0) as case:
case.set_outputs()
validate(cond.hugr)

0 comments on commit 583d21d

Please sign in to comment.