Skip to content

Commit

Permalink
Support ArrayLayouts in assign (kuznia-rdzeni/coreblocks#640)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored Apr 2, 2024
1 parent 7e61c11 commit 5651f90
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions transactron/utils/assign.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from typing import Optional, TypeAlias, cast, TYPE_CHECKING
from collections.abc import Iterable, Mapping
from collections.abc import Sequence, Iterable, Mapping
from amaranth import *
from amaranth.hdl._ast import ArrayProxy
from amaranth.lib import data
Expand All @@ -21,11 +21,11 @@ class AssignType(Enum):
ALL = 3


AssignFields: TypeAlias = AssignType | Iterable[str] | Mapping[str, "AssignFields"]
AssignArg: TypeAlias = ValueLike | Mapping[str, "AssignArg"]
AssignFields: TypeAlias = AssignType | Iterable[str | int] | Mapping[str | int, "AssignFields"]
AssignArg: TypeAlias = ValueLike | Mapping[str, "AssignArg"] | Mapping[int, "AssignArg"] | Sequence["AssignArg"]


def arrayproxy_fields(proxy: ArrayProxy) -> Optional[set[str]]:
def arrayproxy_fields(proxy: ArrayProxy) -> Optional[set[str | int]]:
def flatten_elems(proxy: ArrayProxy):
for elem in proxy.elems:
if isinstance(elem, ArrayProxy):
Expand All @@ -38,15 +38,19 @@ def flatten_elems(proxy: ArrayProxy):
return set.intersection(*[set(cast(data.View, el).shape().members.keys()) for el in elems])


def assign_arg_fields(val: AssignArg) -> Optional[set[str]]:
def assign_arg_fields(val: AssignArg) -> Optional[set[str | int]]:
if isinstance(val, ArrayProxy):
return arrayproxy_fields(val)
elif isinstance(val, data.View):
layout = val.shape()
if isinstance(layout, data.StructLayout):
return set(k for k in layout.members)
if isinstance(layout, data.ArrayLayout):
return set(range(layout.length))
elif isinstance(val, dict):
return set(val.keys())
elif isinstance(val, list):
return set(range(len(val)))


def assign(
Expand Down Expand Up @@ -107,8 +111,18 @@ def assign(

if lhs_fields is not None and rhs_fields is not None:
# asserts for type checking
assert isinstance(lhs, ArrayProxy) or isinstance(lhs, Mapping) or isinstance(lhs, data.View)
assert isinstance(rhs, ArrayProxy) or isinstance(rhs, Mapping) or isinstance(rhs, data.View)
assert (
isinstance(lhs, ArrayProxy)
or isinstance(lhs, Mapping)
or isinstance(lhs, Sequence)
or isinstance(lhs, data.View)
)
assert (
isinstance(rhs, ArrayProxy)
or isinstance(rhs, Mapping)
or isinstance(lhs, Sequence)
or isinstance(rhs, data.View)
)

if fields is AssignType.COMMON:
names = lhs_fields & rhs_fields
Expand All @@ -135,8 +149,8 @@ def assign(
subfields = AssignType.ALL

yield from assign(
lhs[name],
rhs[name],
lhs[name], # type: ignore
rhs[name], # type: ignore
fields=subfields,
lhs_strict=not isinstance(lhs, Mapping),
rhs_strict=not isinstance(rhs, Mapping),
Expand Down

0 comments on commit 5651f90

Please sign in to comment.