diff --git a/test/utils/amaranth_ext/test_min_value.py b/test/utils/amaranth_ext/test_min_value.py new file mode 100644 index 0000000..7a43c43 --- /dev/null +++ b/test/utils/amaranth_ext/test_min_value.py @@ -0,0 +1,36 @@ +from transactron.testing import * +from transactron.utils.amaranth_ext import min_value +import random + + +class MinValueCircuit(Elaboratable): + def __init__(self, bits: int, num: int): + self.inputs = [Signal(bits) for _ in range(num)] + self.output = Signal(bits) + + def elaborate(self, platform): + m = Module() + + m.d.comb += self.output.eq(min_value(m, self.inputs)) + + return m + + +class TestMinValue(TestCaseWithSimulator): + def test_min_value(self): + bits = 4 + num = 3 + num_tests = 100 + + circ = MinValueCircuit(bits, num) + + async def testbench(sim: TestbenchContext): + for _ in range(num_tests): + vals = [random.randrange(2**bits) for _ in circ.inputs] + for sig, val in zip(circ.inputs, vals): + sim.set(sig, val) + assert sim.get(circ.output) == min(vals) + await sim.tick() + + with self.run_simulation(circ) as sim: + sim.add_testbench(testbench) diff --git a/transactron/utils/amaranth_ext/functions.py b/transactron/utils/amaranth_ext/functions.py index b273c0a..0de73d2 100644 --- a/transactron/utils/amaranth_ext/functions.py +++ b/transactron/utils/amaranth_ext/functions.py @@ -5,7 +5,7 @@ from amaranth.lib import data from collections.abc import Iterable, Mapping -from amaranth_types.types import ValueLike, ShapeLike +from amaranth_types.types import ValueLike, ShapeLike, ModuleLike from transactron.utils._typing import SignalBundle __all__ = [ @@ -17,6 +17,7 @@ "flatten_signals", "shape_of", "const_of", + "min_value" ] @@ -132,3 +133,26 @@ def const_of(value: int, shape: ShapeLike) -> Any: return shape.from_bits(value) else: return C(value, Shape.cast(shape)) + + +def min_value(m: ModuleLike, values: Iterable[Value]) -> Value: + values = list(values) + assert all(not value.shape().signed for value in values) # signed currently unsupported + result = Signal(max(len(value) for value in values)) + + # extend inputs to constant width + new_values = list(Signal.like(result) for _ in values) + for sig, value in zip(new_values, values): + m.d.comb += sig.eq(value) + values = new_values + + for i in reversed(range(0, len(result))): + res = Signal() + m.d.comb += res.eq(Cat(value[i] for value in values).all()) + m.d.comb += result[i].eq(res) + new_values = list(Signal.like(result) for _ in values) + for sig, value in zip(new_values, values): + m.d.comb += sig.eq(Mux(value[i] & ~res, -1, value)) + values = new_values + + return result