Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ArrayLayouts in assign #640

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/transactions/test_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
("normal", lambda mk, lay: mk(lay), lambda x: x, lambda r: r),
("rec", lambda mk, lay: mk([("x", lay)]), lambda x: {"x": x}, lambda r: r.x),
("dict", lambda mk, lay: {"x": mk(lay)}, lambda x: {"x": x}, lambda r: r["x"]),
("list", lambda mk, lay: [mk(lay)], lambda x: {0: x}, lambda r: r[0]),
("array", lambda mk, lay: Signal(data.ArrayLayout(reclayout2datalayout(lay), 1)), lambda x: {0: x}, lambda r: r[0]),
]


Expand Down
32 changes: 23 additions & 9 deletions transactron/utils/assign.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from typing import Optional, TypeAlias, cast, TYPE_CHECKING
from collections.abc import Iterable, Mapping
from collections.abc import Sequence, Iterable, Mapping
from amaranth import *
from amaranth.hdl._ast import ArrayProxy
from amaranth.lib import data
Expand All @@ -21,11 +21,11 @@ class AssignType(Enum):
ALL = 3


AssignFields: TypeAlias = AssignType | Iterable[str] | Mapping[str, "AssignFields"]
AssignArg: TypeAlias = ValueLike | Mapping[str, "AssignArg"]
AssignFields: TypeAlias = AssignType | Iterable[str | int] | Mapping[str | int, "AssignFields"]
AssignArg: TypeAlias = ValueLike | Mapping[str, "AssignArg"] | Mapping[int, "AssignArg"] | Sequence["AssignArg"]


def arrayproxy_fields(proxy: ArrayProxy) -> Optional[set[str]]:
def arrayproxy_fields(proxy: ArrayProxy) -> Optional[set[str | int]]:
def flatten_elems(proxy: ArrayProxy):
for elem in proxy.elems:
if isinstance(elem, ArrayProxy):
Expand All @@ -38,15 +38,19 @@ def flatten_elems(proxy: ArrayProxy):
return set.intersection(*[set(cast(data.View, el).shape().members.keys()) for el in elems])


def assign_arg_fields(val: AssignArg) -> Optional[set[str]]:
def assign_arg_fields(val: AssignArg) -> Optional[set[str | int]]:
if isinstance(val, ArrayProxy):
return arrayproxy_fields(val)
elif isinstance(val, data.View):
layout = val.shape()
if isinstance(layout, data.StructLayout):
return set(k for k in layout.members)
if isinstance(layout, data.ArrayLayout):
return set(range(layout.length))
elif isinstance(val, dict):
return set(val.keys())
elif isinstance(val, list):
return set(range(len(val)))


def assign(
Expand Down Expand Up @@ -107,8 +111,18 @@ def assign(

if lhs_fields is not None and rhs_fields is not None:
# asserts for type checking
assert isinstance(lhs, ArrayProxy) or isinstance(lhs, Mapping) or isinstance(lhs, data.View)
assert isinstance(rhs, ArrayProxy) or isinstance(rhs, Mapping) or isinstance(rhs, data.View)
assert (
isinstance(lhs, ArrayProxy)
or isinstance(lhs, Mapping)
or isinstance(lhs, Sequence)
or isinstance(lhs, data.View)
)
assert (
isinstance(rhs, ArrayProxy)
or isinstance(rhs, Mapping)
or isinstance(lhs, Sequence)
or isinstance(rhs, data.View)
)

if fields is AssignType.COMMON:
names = lhs_fields & rhs_fields
Expand All @@ -135,8 +149,8 @@ def assign(
subfields = AssignType.ALL

yield from assign(
lhs[name],
rhs[name],
lhs[name], # type: ignore
rhs[name], # type: ignore
fields=subfields,
lhs_strict=not isinstance(lhs, Mapping),
rhs_strict=not isinstance(rhs, Mapping),
Expand Down