Skip to content

Commit

Permalink
Improved SimpleTestCircuit (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk authored Jan 9, 2025
1 parent 9548e01 commit bc15c17
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 8 deletions.
38 changes: 38 additions & 0 deletions test/testing/test_method_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,44 @@
from transactron.lib import *


class SimpleMethodMockTestCircuit(Elaboratable):
method: Required[Method]
wrapper: Provided[Method]

def __init__(self, width: int):
self.method = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width}))
self.wrapper = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width}))

def elaborate(self, platform):
m = TModule()

@def_method(m, self.wrapper)
def _(input):
return {"output": self.method(m, input).output + 1}

return m


class TestMethodMock(TestCaseWithSimulator):
async def process(self, sim: TestbenchContext):
for _ in range(20):
val = random.randrange(2**self.width)
ret = await self.dut.wrapper.call(sim, input=val)
assert ret.output == (val + 2) % 2**self.width

@def_method_mock(lambda self: self.dut.method, enable=lambda _: random.randint(0, 1))
def method_mock(self, input):
return {"output": input + 1}

def test_method_mock_simple(self):
random.seed(42)
self.width = 4
self.dut = SimpleTestCircuit(SimpleMethodMockTestCircuit(self.width))

with self.run_simulation(self.dut) as sim:
sim.add_testbench(self.process)


class ReverseMethodMockTestCircuit(Elaboratable):
def __init__(self, width):
self.method = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width}))
Expand Down
15 changes: 13 additions & 2 deletions transactron/core/method.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections.abc import Sequence
import enum

from transactron.utils import *
from amaranth import *
from amaranth import tracer
from typing import TYPE_CHECKING, Optional, Iterator, Unpack
from typing import TYPE_CHECKING, Annotated, Optional, Iterator, TypeAlias, TypeVar, Unpack
from .transaction_base import *
from contextlib import contextmanager
from transactron.utils.assign import AssignArg
Expand All @@ -19,7 +20,17 @@
from .transaction import Transaction # noqa: F401


__all__ = ["Method", "Methods"]
__all__ = ["MethodDir", "Provided", "Required", "Method", "Methods"]


class MethodDir(enum.Enum):
PROVIDED = enum.auto()
REQUIRED = enum.auto()


_T = TypeVar("_T")
Provided: TypeAlias = Annotated[_T, MethodDir.PROVIDED]
Required: TypeAlias = Annotated[_T, MethodDir.REQUIRED]


class Method(TransactionBase["Transaction | Method"]):
Expand Down
4 changes: 4 additions & 0 deletions transactron/lib/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ def create(
method = Method(name=name, i=i, o=o, src_loc=get_src_loc(src_loc))
return Adapter(method, **kwargs)

def update_args(self, **kwargs: Unpack[AdapterBodyParams]):
self.kwargs.update(kwargs)
return self

def set(self, with_validate_arguments: Optional[bool]):
if with_validate_arguments is not None:
self.with_validate_arguments = with_validate_arguments
Expand Down
18 changes: 13 additions & 5 deletions transactron/testing/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from amaranth import *
from amaranth.sim import *
from amaranth.sim._async import SimulatorContext
from transactron.core.method import MethodDir
from transactron.lib.adapters import Adapter

from transactron.utils.dependencies import DependencyContext, DependencyManager
from .testbenchio import TestbenchIO
Expand Down Expand Up @@ -58,6 +60,7 @@ def __getattr__(self, name: str) -> Any:

def elaborate(self, platform):
def transform_methods_to_testbenchios(
adapter_type: type[Adapter] | type[AdapterTrans],
container: _T_nested_collection[Method | Methods],
) -> tuple[
_T_nested_collection["TestbenchIO"],
Expand All @@ -67,32 +70,37 @@ def transform_methods_to_testbenchios(
tb_list = []
mc_list = []
for elem in container:
tb, mc = transform_methods_to_testbenchios(elem)
tb, mc = transform_methods_to_testbenchios(adapter_type, elem)
tb_list.append(tb)
mc_list.append(mc)
return tb_list, ModuleConnector(*mc_list)
elif isinstance(container, dict):
tb_dict = {}
mc_dict = {}
for name, elem in container.items():
tb, mc = transform_methods_to_testbenchios(elem)
tb, mc = transform_methods_to_testbenchios(adapter_type, elem)
tb_dict[name] = tb
mc_dict[name] = mc
return tb_dict, ModuleConnector(*mc_dict)
elif isinstance(container, Methods):
tb_list = [TestbenchIO(AdapterTrans(method)) for method in container]
tb_list = [TestbenchIO(adapter_type(method)) for method in container]
return list(tb_list), ModuleConnector(*tb_list)
else:
tb = TestbenchIO(AdapterTrans(container))
tb = TestbenchIO(adapter_type(container))
return tb, tb

m = Module()

m.submodules.dut = self._dut
hints = self._dut.__class__.__annotations__

for name, attr in vars(self._dut).items():
if guard_nested_collection(attr, Method | Methods) and attr:
tb_cont, mc = transform_methods_to_testbenchios(attr)
if name in hints and MethodDir.REQUIRED in hints[name].__metadata__:
adapter_type = Adapter
else: # PROVIDED is the default
adapter_type = AdapterTrans
tb_cont, mc = transform_methods_to_testbenchios(adapter_type, attr)
self._io[name] = tb_cont
m.submodules[name] = mc

Expand Down
9 changes: 8 additions & 1 deletion transactron/testing/method_mock.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from contextlib import contextmanager
import functools
from typing import Callable, Any, Optional
from typing import Callable, Any, Optional, Unpack

from amaranth.sim._async import SimulatorContext
from transactron.core.body import AdapterBodyParams
from transactron.lib.adapters import Adapter, AdapterBase
from transactron.utils.transactron_helpers import async_mock_def_helper
from .testbenchio import TestbenchIO
Expand All @@ -21,7 +22,13 @@ def __init__(
validate_arguments: Optional[Callable[..., bool]] = None,
enable: Callable[[], bool] = lambda: True,
delay: float = 0,
**kwargs: Unpack[AdapterBodyParams],
):
if isinstance(adapter, Adapter):
adapter.set(with_validate_arguments=validate_arguments is not None).update_args(**kwargs)
else:
assert validate_arguments is None
assert kwargs == {}
self.adapter = adapter
self.function = function
self.validate_arguments = validate_arguments
Expand Down

0 comments on commit bc15c17

Please sign in to comment.