Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial idea how to clean up unittests.
Browse files Browse the repository at this point in the history
Lekcyjna committed Apr 2, 2024
1 parent 5add626 commit a47dcaa
Showing 7 changed files with 100 additions and 83 deletions.
14 changes: 8 additions & 6 deletions test/fu/functional_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import asdict, dataclass
from itertools import product
import random
import pytest
from collections import deque
from typing import Generic, TypeVar

@@ -93,7 +94,8 @@ def compute_result(i1: int, i2: int, i_imm: int, pc: int, fn: _T, xlen: int) ->
"""
raise NotImplementedError

def setUp(self):
@pytest.fixture(autouse=True)
def setup(self, configure_dependency_context):
self.gen_params = GenParams(test_core_config)

self.report_mock = TestbenchIO(Adapter(i=self.gen_params.get(ExceptionRegisterLayouts).report))
@@ -149,7 +151,7 @@ def consumer(self):
while self.responses:
expected = self.responses.pop()
result = yield from self.m.accept.call()
self.assertDictEqual(expected, result)
assert expected== result
yield from self.random_wait(self.max_wait)

def producer(self):
@@ -162,19 +164,19 @@ def exception_consumer(self):
while self.exceptions:
expected = self.exceptions.pop()
result = yield from self.report_mock.call()
self.assertDictEqual(expected, result)
assert expected== result
yield from self.random_wait(self.max_wait)

# keep partialy dependent tests from hanging up and detect extra calls
yield Passive()
result = yield from self.report_mock.call()
self.assertFalse(True, "unexpected report call")
assert not True, "unexpected report call"

def pipeline_verifier(self):
yield Passive()
while True:
self.assertTrue((yield self.m.issue.adapter.iface.ready))
self.assertEqual((yield self.m.issue.adapter.en), (yield self.m.issue.adapter.done))
assert (yield self.m.issue.adapter.iface.ready)
assert (yield self.m.issue.adapter.en)== (yield self.m.issue.adapter.done)
yield

def run_standard_fu_test(self, pipeline_test=False):
2 changes: 1 addition & 1 deletion test/fu/test_jb_unit.py
Original file line number Diff line number Diff line change
@@ -144,7 +144,7 @@ def compute_result_auipc(i1: int, i2: int, i_imm: int, pc: int, fn: JumpBranchFn
),
],
)
class JumpBranchUnitTest(FunctionalUnitTestCase[JumpBranchFn.Fn]):
class TestJumpBranchUnit(FunctionalUnitTestCase[JumpBranchFn.Fn]):
compute_result = compute_result
zero_imm = False

2 changes: 1 addition & 1 deletion test/transactions/test_branches.py
Original file line number Diff line number Diff line change
@@ -96,4 +96,4 @@ def test_conflict_removal(self):
cgr, _, _ = tm._conflict_graph(MethodMap(tm.transactions))

for s in cgr.values():
self.assertFalse(s)
assert not s
55 changes: 28 additions & 27 deletions test/transactions/test_transactions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest.case import TestCase
import pytest
from amaranth import *
from amaranth.sim import *

@@ -32,37 +33,37 @@ def __init__(self):
Transaction(manager=mgr)

T()
self.assertEqual(mgr.transactions[0].name, "T")
assert mgr.transactions[0].name== "T"

t = Transaction(name="x", manager=mgr)
self.assertEqual(t.name, "x")
assert t.name== "x"

t = Transaction(manager=mgr)
self.assertEqual(t.name, "t")
assert t.name== "t"

m = Method(name="x")
self.assertEqual(m.name, "x")
assert m.name== "x"

m = Method()
self.assertEqual(m.name, "m")
assert m.name== "m"


class TestScheduler(TestCaseWithSimulator):
def count_test(self, sched, cnt):
self.assertEqual(sched.count, cnt)
self.assertEqual(len(sched.requests), cnt)
self.assertEqual(len(sched.grant), cnt)
self.assertEqual(len(sched.valid), 1)
assert sched.count== cnt
assert len(sched.requests)== cnt
assert len(sched.grant)== cnt
assert len(sched.valid)== 1

def sim_step(self, sched, request, expected_grant):
yield sched.requests.eq(request)
yield

if request == 0:
self.assertFalse((yield sched.valid))
assert not (yield sched.valid)
else:
self.assertEqual((yield sched.grant), expected_grant)
self.assertTrue((yield sched.valid))
assert (yield sched.grant)== expected_grant
assert (yield sched.valid)

def test_single(self):
sched = Scheduler(1)
@@ -147,7 +148,7 @@ def tgt(x: int):
self.out1_expected.append(x)

def chk(x: int):
self.assertEqual(x, self.in_expected.popleft())
assert x== self.in_expected.popleft()

return self.make_process(self.m.in1, prob, self.in1_stream, tgt, chk)

@@ -156,7 +157,7 @@ def tgt(x: int):
self.out2_expected.append(x)

def chk(x: int):
self.assertEqual(x, self.in_expected.popleft())
assert x== self.in_expected.popleft()

return self.make_process(self.m.in2, prob, self.in2_stream, tgt, chk)

@@ -170,7 +171,7 @@ def chk(x: int):
elif self.out2_expected and x == self.out2_expected[0]:
self.out2_expected.popleft()
else:
self.fail("%d not found in any of the queues" % x)
assert False, "%d not found in any of the queues" % x

return self.make_process(self.m.out, prob, self.out_stream, tgt, chk)

@@ -195,9 +196,9 @@ def test_calls(self, name, prob1, prob2, probout):
sim.add_sync_process(self.make_in2_process(prob2))
sim.add_sync_process(self.make_out_process(probout))

self.assertFalse(self.in_expected)
self.assertFalse(self.out1_expected)
self.assertFalse(self.out2_expected)
assert not self.in_expected
assert not self.out1_expected
assert not self.out2_expected


class SchedulingTestCircuit(Elaboratable):
@@ -284,12 +285,12 @@ def process():
yield m.r1.eq(r1)
yield m.r2.eq(r2)
yield Settle()
self.assertNotEqual((yield m.t1), (yield m.t2))
assert (yield m.t1)!= (yield m.t2)
if r1 == 1 and r2 == 1:
if priority == Priority.LEFT:
self.assertTrue((yield m.t1))
assert (yield m.t1)
if priority == Priority.RIGHT:
self.assertTrue((yield m.t2))
assert (yield m.t2)

with self.run_simulation(m) as sim:
sim.add_process(process)
@@ -301,7 +302,7 @@ def test_unsatisfiable(self, priority: Priority):
import graphlib

if priority != Priority.UNDEFINED:
cm = self.assertRaises(graphlib.CycleError)
cm = pytest.raises(graphlib.CycleError)
else:
cm = contextlib.nullcontext()

@@ -369,8 +370,8 @@ def process():
yield m.r1.eq(r1)
yield m.r2.eq(r2)
yield
self.assertEqual((yield m.t1), r1)
self.assertEqual((yield m.t2), r1 * r2)
assert (yield m.t1)== r1
assert (yield m.t2)== r1 * r2

with self.run_simulation(m) as sim:
sim.add_sync_process(process)
@@ -415,8 +416,8 @@ def process():
yield m.r1.eq(r1)
yield m.r2.eq(r2)
yield
self.assertEqual((yield m.t1), r1)
self.assertFalse((yield m.t2))
assert (yield m.t1)== r1
assert not (yield m.t2)

with self.run_simulation(m) as sim:
sim.add_sync_process(process)
@@ -441,6 +442,6 @@ class TestSingleCaller(TestCaseWithSimulator):
def test_single_caller(self):
m = SingleCallerTestCircuit()

with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
with self.run_simulation(m):
pass
2 changes: 1 addition & 1 deletion test/utils/test_fifo.py
Original file line number Diff line number Diff line change
@@ -70,7 +70,7 @@ def target():

v = yield from fifoc.fifo_read.call_result()
if v is not None:
self.assertEqual(v["data"], expq.pop())
assert v["data"]==expq.pop()

yield from fifoc.fifo_read.disable()

4 changes: 2 additions & 2 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -73,7 +73,7 @@ def elaborate(self, platform):
class TestPopcount(TestCaseWithSimulator):
size: int

def setUp(self):
def setup_method(self):
random.seed(14)
self.test_number = 40
self.m = PopcountTestCircuit(self.size)
@@ -82,7 +82,7 @@ def check(self, n):
yield self.m.sig_in.eq(n)
yield Settle()
out_popcount = yield self.m.sig_out
self.assertEqual(out_popcount, n.bit_count(), f"{n:x}")
assert out_popcount==n.bit_count(), f"{n:x}"

def process(self):
for i in range(self.test_number):
104 changes: 59 additions & 45 deletions transactron/testing/infrastructure.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import sys
import pytest
import os
import random
import unittest
import functools
from contextlib import contextmanager, nullcontext
from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast, TypeAlias
from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast, TypeAlias, Optional
from abc import ABC
from amaranth import *
from amaranth.sim import *
@@ -202,27 +202,14 @@ def run(self) -> bool:
return not self.advance()


class TestCaseWithSimulator(unittest.TestCase):
class TestCaseWithSimulator():
dependency_manager: DependencyManager

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@pytest.fixture(autouse=True)
def configure_dependency_context(self, request):
self.dependency_manager = DependencyManager()

def wrap(f: Callable[[], None]):
@functools.wraps(f)
def wrapper():
with DependencyContext(self.dependency_manager):
f()

return wrapper

for k in dir(self):
if k.startswith("test") or k == "setUp":
f = getattr(self, k)
if isinstance(f, Callable):
setattr(self, k, wrap(getattr(self, k)))
with DependencyContext(self.dependency_manager):
yield

def add_class_mocks(self, sim: PysimSimulator) -> None:
for key in dir(self):
@@ -239,47 +226,74 @@ def add_all_mocks(self, sim: PysimSimulator, frame_locals: dict) -> None:
self.add_class_mocks(sim)
self.add_local_mocks(sim, frame_locals)

@contextmanager
def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True):
@pytest.fixture(autouse=True)
def configure_traces(self, request):
traces_file = None
if "__TRANSACTRON_DUMP_TRACES" in os.environ:
traces_file = unittest.TestCase.id(self)
traces_file = ".".join(request.node.nodeid.split("/"))
self._transactron_infrastructure_traces_file = traces_file


@pytest.fixture(autouse=True)
def fixture_sim_processes_to_add(self):
# By default return empty lists, it will be updated by other fixtures based on needs
self._transactron_sim_processes_to_add : list[Callable[[], Optional[Callable]]] = []

@pytest.fixture(autouse=True)
def configure_profiles(self, request, fixture_sim_processes_to_add, configure_dependency_context):
profile=None
if "__TRANSACTRON_PROFILE" in os.environ:
def f():
nonlocal profile
try:
transaction_manager = DependencyContext.get().get_dependency(TransactionManagerKey())
profile = Profile()
return profiler_process(transaction_manager, profile)
except KeyError:
pass
return None

self._transactron_sim_processes_to_add.append(f)

yield

if profile is not None:
profile_dir = "test/__profiles__"
profile_file = ".".join(request.node.nodeid.split("/"))
os.makedirs(profile_dir, exist_ok=True)
profile.encode(f"{profile_dir}/{profile_file}.json")

@pytest.fixture(autouse=True)
def configure_logging(self, fixture_sim_processes_to_add):
def on_error():
assert False, "Simulation finished due to an error"

log_level = parse_logging_level(os.environ["__TRANSACTRON_LOG_LEVEL"])
log_filter = os.environ["__TRANSACTRON_LOG_FILTER"]
self._transactron_sim_processes_to_add.append(lambda: make_logging_process(log_level, log_filter, on_error))


@contextmanager
def run_simulation(self, module: HasElaborate, max_cycles: float = 10e4, add_transaction_module=True):
clk_period = 1e-6
sim = PysimSimulator(
module,
max_cycles=max_cycles,
add_transaction_module=add_transaction_module,
traces_file=traces_file,
traces_file=self._transactron_infrastructure_traces_file,
clk_period=clk_period,
)
self.add_all_mocks(sim, sys._getframe(2).f_locals)

yield sim

profile = None
if "__TRANSACTRON_PROFILE" in os.environ and isinstance(sim.tested_module, TransactionModule):
profile = Profile()
sim.add_sync_process(
profiler_process(sim.tested_module.manager.get_dependency(TransactionManagerKey()), profile)
)

def on_error():
self.assertTrue(False, "Simulation finished due to an error")

log_level = parse_logging_level(os.environ["__TRANSACTRON_LOG_LEVEL"])
log_filter = os.environ["__TRANSACTRON_LOG_FILTER"]
sim.add_sync_process(make_logging_process(log_level, log_filter, on_error))
for f in self._transactron_sim_processes_to_add:
ret = f()
if ret is not None:
sim.add_sync_process(ret)

res = sim.run()

if profile is not None:
profile_dir = "test/__profiles__"
profile_file = unittest.TestCase.id(self)
os.makedirs(profile_dir, exist_ok=True)
profile.encode(f"{profile_dir}/{profile_file}.json")

self.assertTrue(res, "Simulation time limit exceeded")
assert res, "Simulation time limit exceeded"

def tick(self, cycle_cnt: int = 1):
"""

0 comments on commit a47dcaa

Please sign in to comment.