diff --git a/test/transactions/test_assign.py b/test/transactions/test_assign.py index 73d5b28f9..8675d4b46 100644 --- a/test/transactions/test_assign.py +++ b/test/transactions/test_assign.py @@ -20,6 +20,8 @@ ("normal", lambda mk, lay: mk(lay), lambda x: x, lambda r: r), ("rec", lambda mk, lay: mk([("x", lay)]), lambda x: {"x": x}, lambda r: r.x), ("dict", lambda mk, lay: {"x": mk(lay)}, lambda x: {"x": x}, lambda r: r["x"]), + ("list", lambda mk, lay: [mk(lay)], lambda x: {0: x}, lambda r: r[0]), + ("array", lambda mk, lay: Signal(data.ArrayLayout(reclayout2datalayout(lay), 1)), lambda x: {0: x}, lambda r: r[0]), ] diff --git a/transactron/utils/assign.py b/transactron/utils/assign.py index 0be471e80..57c9467c7 100644 --- a/transactron/utils/assign.py +++ b/transactron/utils/assign.py @@ -1,6 +1,6 @@ from enum import Enum from typing import Optional, TypeAlias, cast, TYPE_CHECKING -from collections.abc import Iterable, Mapping +from collections.abc import Sequence, Iterable, Mapping from amaranth import * from amaranth.hdl._ast import ArrayProxy from amaranth.lib import data @@ -21,11 +21,11 @@ class AssignType(Enum): ALL = 3 -AssignFields: TypeAlias = AssignType | Iterable[str] | Mapping[str, "AssignFields"] -AssignArg: TypeAlias = ValueLike | Mapping[str, "AssignArg"] +AssignFields: TypeAlias = AssignType | Iterable[str | int] | Mapping[str | int, "AssignFields"] +AssignArg: TypeAlias = ValueLike | Mapping[str, "AssignArg"] | Mapping[int, "AssignArg"] | Sequence["AssignArg"] -def arrayproxy_fields(proxy: ArrayProxy) -> Optional[set[str]]: +def arrayproxy_fields(proxy: ArrayProxy) -> Optional[set[str | int]]: def flatten_elems(proxy: ArrayProxy): for elem in proxy.elems: if isinstance(elem, ArrayProxy): @@ -38,15 +38,19 @@ def flatten_elems(proxy: ArrayProxy): return set.intersection(*[set(cast(data.View, el).shape().members.keys()) for el in elems]) -def assign_arg_fields(val: AssignArg) -> Optional[set[str]]: +def assign_arg_fields(val: AssignArg) -> Optional[set[str | int]]: if isinstance(val, ArrayProxy): return arrayproxy_fields(val) elif isinstance(val, data.View): layout = val.shape() if isinstance(layout, data.StructLayout): return set(k for k in layout.members) + if isinstance(layout, data.ArrayLayout): + return set(range(layout.length)) elif isinstance(val, dict): return set(val.keys()) + elif isinstance(val, list): + return set(range(len(val))) def assign( @@ -107,8 +111,18 @@ def assign( if lhs_fields is not None and rhs_fields is not None: # asserts for type checking - assert isinstance(lhs, ArrayProxy) or isinstance(lhs, Mapping) or isinstance(lhs, data.View) - assert isinstance(rhs, ArrayProxy) or isinstance(rhs, Mapping) or isinstance(rhs, data.View) + assert ( + isinstance(lhs, ArrayProxy) + or isinstance(lhs, Mapping) + or isinstance(lhs, Sequence) + or isinstance(lhs, data.View) + ) + assert ( + isinstance(rhs, ArrayProxy) + or isinstance(rhs, Mapping) + or isinstance(lhs, Sequence) + or isinstance(rhs, data.View) + ) if fields is AssignType.COMMON: names = lhs_fields & rhs_fields @@ -135,8 +149,8 @@ def assign( subfields = AssignType.ALL yield from assign( - lhs[name], - rhs[name], + lhs[name], # type: ignore + rhs[name], # type: ignore fields=subfields, lhs_strict=not isinstance(lhs, Mapping), rhs_strict=not isinstance(rhs, Mapping),