From 617866157e994578fe29ac4f30f430474d8e7b45 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Thu, 19 Dec 2024 11:19:24 +0100 Subject: [PATCH 1/2] Auto adapters for required methods --- test/testing/test_method_mock.py | 38 +++++++++++++++++++++++++++ transactron/core/method.py | 15 +++++++++-- transactron/lib/adapters.py | 4 +++ transactron/testing/infrastructure.py | 18 +++++++++---- 4 files changed, 68 insertions(+), 7 deletions(-) diff --git a/test/testing/test_method_mock.py b/test/testing/test_method_mock.py index 7ad198e..9db9561 100644 --- a/test/testing/test_method_mock.py +++ b/test/testing/test_method_mock.py @@ -10,6 +10,44 @@ from transactron.lib import * +class SimpleMethodMockTestCircuit(Elaboratable): + method: Required[Method] + wrapper: Provided[Method] + + def __init__(self, width: int): + self.method = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width})) + self.wrapper = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width})) + + def elaborate(self, platform): + m = TModule() + + @def_method(m, self.wrapper) + def _(input): + return {"output": self.method(m, input).output + 1} + + return m + + +class TestMethodMock(TestCaseWithSimulator): + async def process(self, sim: TestbenchContext): + for _ in range(20): + val = random.randrange(2**self.width) + ret = await self.dut.wrapper.call(sim, input=val) + assert ret.output == (val + 2) % 2**self.width + + @def_method_mock(lambda self: self.dut.method, enable=lambda _: random.randint(0, 1)) + def method_mock(self, input): + return {"output": input + 1} + + def test_method_mock_simple(self): + random.seed(42) + self.width = 4 + self.dut = SimpleTestCircuit(SimpleMethodMockTestCircuit(self.width)) + + with self.run_simulation(self.dut) as sim: + sim.add_testbench(self.process) + + class ReverseMethodMockTestCircuit(Elaboratable): def __init__(self, width): self.method = Method(i=StructLayout({"input": width}), o=StructLayout({"output": width})) diff --git a/transactron/core/method.py b/transactron/core/method.py index dce26a8..7484b86 100644 --- a/transactron/core/method.py +++ b/transactron/core/method.py @@ -1,9 +1,10 @@ from collections.abc import Sequence +import enum from transactron.utils import * from amaranth import * from amaranth import tracer -from typing import TYPE_CHECKING, Optional, Iterator, Unpack +from typing import TYPE_CHECKING, Annotated, Optional, Iterator, TypeAlias, TypeVar, Unpack from .transaction_base import * from contextlib import contextmanager from transactron.utils.assign import AssignArg @@ -19,7 +20,17 @@ from .transaction import Transaction # noqa: F401 -__all__ = ["Method", "Methods"] +__all__ = ["MethodDir", "Provided", "Required", "Method", "Methods"] + + +class MethodDir(enum.Enum): + PROVIDED = enum.auto() + REQUIRED = enum.auto() + + +_T = TypeVar("_T") +Provided: TypeAlias = Annotated[_T, MethodDir.PROVIDED] +Required: TypeAlias = Annotated[_T, MethodDir.REQUIRED] class Method(TransactionBase["Transaction | Method"]): diff --git a/transactron/lib/adapters.py b/transactron/lib/adapters.py index b06d99d..be8b5a8 100644 --- a/transactron/lib/adapters.py +++ b/transactron/lib/adapters.py @@ -125,6 +125,10 @@ def create( method = Method(name=name, i=i, o=o, src_loc=get_src_loc(src_loc)) return Adapter(method, **kwargs) + def update_args(self, **kwargs: Unpack[AdapterBodyParams]): + self.kwargs.update(kwargs) + return self + def set(self, with_validate_arguments: Optional[bool]): if with_validate_arguments is not None: self.with_validate_arguments = with_validate_arguments diff --git a/transactron/testing/infrastructure.py b/transactron/testing/infrastructure.py index 70580cf..43edd97 100644 --- a/transactron/testing/infrastructure.py +++ b/transactron/testing/infrastructure.py @@ -10,6 +10,8 @@ from amaranth import * from amaranth.sim import * from amaranth.sim._async import SimulatorContext +from transactron.core.method import MethodDir +from transactron.lib.adapters import Adapter from transactron.utils.dependencies import DependencyContext, DependencyManager from .testbenchio import TestbenchIO @@ -58,6 +60,7 @@ def __getattr__(self, name: str) -> Any: def elaborate(self, platform): def transform_methods_to_testbenchios( + adapter_type: type[Adapter] | type[AdapterTrans], container: _T_nested_collection[Method | Methods], ) -> tuple[ _T_nested_collection["TestbenchIO"], @@ -67,7 +70,7 @@ def transform_methods_to_testbenchios( tb_list = [] mc_list = [] for elem in container: - tb, mc = transform_methods_to_testbenchios(elem) + tb, mc = transform_methods_to_testbenchios(adapter_type, elem) tb_list.append(tb) mc_list.append(mc) return tb_list, ModuleConnector(*mc_list) @@ -75,24 +78,29 @@ def transform_methods_to_testbenchios( tb_dict = {} mc_dict = {} for name, elem in container.items(): - tb, mc = transform_methods_to_testbenchios(elem) + tb, mc = transform_methods_to_testbenchios(adapter_type, elem) tb_dict[name] = tb mc_dict[name] = mc return tb_dict, ModuleConnector(*mc_dict) elif isinstance(container, Methods): - tb_list = [TestbenchIO(AdapterTrans(method)) for method in container] + tb_list = [TestbenchIO(adapter_type(method)) for method in container] return list(tb_list), ModuleConnector(*tb_list) else: - tb = TestbenchIO(AdapterTrans(container)) + tb = TestbenchIO(adapter_type(container)) return tb, tb m = Module() m.submodules.dut = self._dut + hints = self._dut.__class__.__annotations__ for name, attr in vars(self._dut).items(): if guard_nested_collection(attr, Method | Methods) and attr: - tb_cont, mc = transform_methods_to_testbenchios(attr) + if name in hints and MethodDir.REQUIRED in hints[name].__metadata__: + adapter_type = Adapter + else: # PROVIDED is the default + adapter_type = AdapterTrans + tb_cont, mc = transform_methods_to_testbenchios(adapter_type, attr) self._io[name] = tb_cont m.submodules[name] = mc From 27e4480f3f6d6b2d31f415c2ddf2edfeb400fe20 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Thu, 19 Dec 2024 12:53:02 +0100 Subject: [PATCH 2/2] Updating kwargs in mocking --- transactron/testing/method_mock.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/transactron/testing/method_mock.py b/transactron/testing/method_mock.py index f3fdfe0..b6dc0a4 100644 --- a/transactron/testing/method_mock.py +++ b/transactron/testing/method_mock.py @@ -1,8 +1,9 @@ from contextlib import contextmanager import functools -from typing import Callable, Any, Optional +from typing import Callable, Any, Optional, Unpack from amaranth.sim._async import SimulatorContext +from transactron.core.body import AdapterBodyParams from transactron.lib.adapters import Adapter, AdapterBase from transactron.utils.transactron_helpers import async_mock_def_helper from .testbenchio import TestbenchIO @@ -21,7 +22,13 @@ def __init__( validate_arguments: Optional[Callable[..., bool]] = None, enable: Callable[[], bool] = lambda: True, delay: float = 0, + **kwargs: Unpack[AdapterBodyParams], ): + if isinstance(adapter, Adapter): + adapter.set(with_validate_arguments=validate_arguments is not None).update_args(**kwargs) + else: + assert validate_arguments is None + assert kwargs == {} self.adapter = adapter self.function = function self.validate_arguments = validate_arguments