From bd94bf321cd7f2c726bb8847fc277c02046fcf98 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Fri, 17 Jan 2025 12:58:29 +0100 Subject: [PATCH 1/2] Move StableSelectingNetwork out of lib --- .../amaranth_ext/test_elaboratables.py} | 8 +- transactron/lib/connectors.py | 80 ------------------ .../utils/amaranth_ext/elaboratables.py | 81 +++++++++++++++++++ 3 files changed, 85 insertions(+), 84 deletions(-) rename test/{test_connectors.py => utils/amaranth_ext/test_elaboratables.py} (82%) diff --git a/test/test_connectors.py b/test/utils/amaranth_ext/test_elaboratables.py similarity index 82% rename from test/test_connectors.py rename to test/utils/amaranth_ext/test_elaboratables.py index 030c991..2b69ab7 100644 --- a/test/test_connectors.py +++ b/test/utils/amaranth_ext/test_elaboratables.py @@ -1,7 +1,7 @@ import pytest import random -from transactron.lib import StableSelectingNetwork +from transactron.utils import StableSelectingNetwork from transactron.testing import TestCaseWithSimulator, TestbenchContext @@ -9,7 +9,7 @@ class TestStableSelectingNetwork(TestCaseWithSimulator): @pytest.mark.parametrize("n", [2, 3, 7, 8]) def test(self, n: int): - m = StableSelectingNetwork(n, [("data", 8)]) + m = StableSelectingNetwork(n, 8) random.seed(42) @@ -22,13 +22,13 @@ async def process(sim: TestbenchContext): expected_output_prefix = [] for i in range(n): sim.set(m.valids[i], valids[i]) - sim.set(m.inputs[i].data, inputs[i]) + sim.set(m.inputs[i], inputs[i]) if valids[i]: expected_output_prefix.append(inputs[i]) for i in range(total): - out = sim.get(m.outputs[i].data) + out = sim.get(m.outputs[i]) assert out == expected_output_prefix[i] assert sim.get(m.output_cnt) == total diff --git a/transactron/lib/connectors.py b/transactron/lib/connectors.py index 723660f..417bdf3 100644 --- a/transactron/lib/connectors.py +++ b/transactron/lib/connectors.py @@ -12,7 +12,6 @@ "Connect", "ConnectTrans", "ManyToOneConnectTrans", - "StableSelectingNetwork", "Pipe", ] @@ -343,82 +342,3 @@ def elaborate(self, platform): ) return m - - -class StableSelectingNetwork(Elaboratable): - """A network that groups inputs with a valid bit set. - - The circuit takes `n` inputs with a valid signal each and - on the output returns a grouped and consecutive sequence of the provided - input signals. The order of valid inputs is preserved. - - For example for input (0 is an invalid input): - 0, a, 0, d, 0, 0, e - - The circuit will return: - a, d, e, 0, 0, 0, 0 - - The circuit uses a divide and conquer algorithm. - The recursive call takes two bit vectors and each of them - is already properly sorted, for example: - v1 = [a, b, 0, 0]; v2 = [c, d, e, 0] - - Now by shifting left v2 and merging it with v1, we get the result: - v = [a, b, c, d, e, 0, 0, 0] - - Thus, the network has depth log_2(n). - - """ - - def __init__(self, n: int, layout: MethodLayout): - self.n = n - self.layout = from_method_layout(layout) - - self.inputs = [Signal(self.layout) for _ in range(n)] - self.valids = [Signal() for _ in range(n)] - - self.outputs = [Signal(self.layout) for _ in range(n)] - self.output_cnt = Signal(range(n + 1)) - - def elaborate(self, platform): - m = TModule() - - current_level = [] - for i in range(self.n): - current_level.append((Array([self.inputs[i]]), self.valids[i])) - - # Create the network using the bottom-up approach. - while len(current_level) >= 2: - next_level = [] - while len(current_level) >= 2: - a, cnt_a = current_level.pop(0) - b, cnt_b = current_level.pop(0) - - total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1) - m.d.comb += total_cnt.eq(cnt_a + cnt_b) - - total_len = len(a) + len(b) - merged = Array(Signal(self.layout) for _ in range(total_len)) - - for i in range(len(a)): - m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i])) - for i in range(len(b)): - m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a])) - - next_level.append((merged, total_cnt)) - - # If we had an odd number of elements on the current level, - # move the item left to the next level. - if len(current_level) == 1: - next_level.append(current_level.pop(0)) - - current_level = next_level - - last_level, total_cnt = current_level.pop(0) - - for i in range(self.n): - m.d.comb += self.outputs[i].eq(last_level[i]) - - m.d.comb += self.output_cnt.eq(total_cnt) - - return m diff --git a/transactron/utils/amaranth_ext/elaboratables.py b/transactron/utils/amaranth_ext/elaboratables.py index ed6b571..60c9654 100644 --- a/transactron/utils/amaranth_ext/elaboratables.py +++ b/transactron/utils/amaranth_ext/elaboratables.py @@ -3,6 +3,7 @@ from typing import Literal, Optional, overload from collections.abc import Iterable from amaranth import * +from amaranth_types import ShapeLike from transactron.utils._typing import HasElaborate, ModuleLike, ValueLike __all__ = [ @@ -13,6 +14,7 @@ "RoundRobin", "MultiPriorityEncoder", "RingMultiPriorityEncoder", + "StableSelectingNetwork", ] @@ -530,3 +532,82 @@ def elaborate(self, platform): m.d.comb += self.outputs[k].eq(corrected_out) m.d.comb += self.valids[k].eq(multi_enc.valids[k]) return m + + +class StableSelectingNetwork(Elaboratable): + """A network that groups inputs with a valid bit set. + + The circuit takes `n` inputs with a valid signal each and + on the output returns a grouped and consecutive sequence of the provided + input signals. The order of valid inputs is preserved. + + For example for input (0 is an invalid input): + 0, a, 0, d, 0, 0, e + + The circuit will return: + a, d, e, 0, 0, 0, 0 + + The circuit uses a divide and conquer algorithm. + The recursive call takes two bit vectors and each of them + is already properly sorted, for example: + v1 = [a, b, 0, 0]; v2 = [c, d, e, 0] + + Now by shifting left v2 and merging it with v1, we get the result: + v = [a, b, c, d, e, 0, 0, 0] + + Thus, the network has depth log_2(n). + + """ + + def __init__(self, n: int, shape: ShapeLike): + self.n = n + self.shape = shape + + self.inputs = [Signal(shape) for _ in range(n)] + self.valids = [Signal() for _ in range(n)] + + self.outputs = [Signal(shape) for _ in range(n)] + self.output_cnt = Signal(range(n + 1)) + + def elaborate(self, platform): + m = Module() + + current_level = [] + for i in range(self.n): + current_level.append((Array([self.inputs[i]]), self.valids[i])) + + # Create the network using the bottom-up approach. + while len(current_level) >= 2: + next_level = [] + while len(current_level) >= 2: + a, cnt_a = current_level.pop(0) + b, cnt_b = current_level.pop(0) + + total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1) + m.d.comb += total_cnt.eq(cnt_a + cnt_b) + + total_len = len(a) + len(b) + merged = Array(Signal(self.shape) for _ in range(total_len)) + + for i in range(len(a)): + m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i])) + for i in range(len(b)): + m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a])) + + next_level.append((merged, total_cnt)) + + # If we had an odd number of elements on the current level, + # move the item left to the next level. + if len(current_level) == 1: + next_level.append(current_level.pop(0)) + + current_level = next_level + + last_level, total_cnt = current_level.pop(0) + + for i in range(self.n): + m.d.comb += self.outputs[i].eq(last_level[i]) + + m.d.comb += self.output_cnt.eq(total_cnt) + + return m From 334ca3d77e787a9e8deb4ce47b923414f4ea62a4 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Fri, 17 Jan 2025 13:14:28 +0100 Subject: [PATCH 2/2] More sensible interfaces --- .../utils/amaranth_ext/elaboratables.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/transactron/utils/amaranth_ext/elaboratables.py b/transactron/utils/amaranth_ext/elaboratables.py index 60c9654..253f9bb 100644 --- a/transactron/utils/amaranth_ext/elaboratables.py +++ b/transactron/utils/amaranth_ext/elaboratables.py @@ -3,6 +3,7 @@ from typing import Literal, Optional, overload from collections.abc import Iterable from amaranth import * +from amaranth.lib.data import ArrayLayout from amaranth_types import ShapeLike from transactron.utils._typing import HasElaborate, ModuleLike, ValueLike @@ -272,13 +273,13 @@ def __init__(self, input_width: int, outputs_count: int): self.outputs_count = outputs_count self.input = Signal(self.input_width) - self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)] - self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)] + self.outputs = Signal(ArrayLayout(range(self.input_width), self.outputs_count)) + self.valids = Signal(self.outputs_count) @staticmethod def create( m: Module, input_width: int, input: ValueLike, outputs_count: int = 1, name: Optional[str] = None - ) -> list[tuple[Signal, Signal]]: + ) -> list[tuple[Value, Value]]: """Syntax sugar for creating MultiPriorityEncoder This static method allows to use MultiPriorityEncoder in a more functional @@ -327,12 +328,10 @@ def create( except AttributeError: setattr(m.submodules, name, prio_encoder) m.d.comb += prio_encoder.input.eq(input) - return list(zip(prio_encoder.outputs, prio_encoder.valids)) + return [(prio_encoder.outputs[i], prio_encoder.valids[i]) for i in range(outputs_count)] @staticmethod - def create_simple( - m: Module, input_width: int, input: ValueLike, name: Optional[str] = None - ) -> tuple[Signal, Signal]: + def create_simple(m: Module, input_width: int, input: ValueLike, name: Optional[str] = None) -> tuple[Value, Value]: """Syntax sugar for creating MultiPriorityEncoder This is the same as `create` function, but with `outputs_count` hardcoded to 1. @@ -422,8 +421,8 @@ def __init__(self, input_width: int, outputs_count: int): self.input = Signal(self.input_width) self.first = Signal(range(self.input_width)) self.last = Signal(range(self.input_width)) - self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)] - self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)] + self.outputs = Signal(ArrayLayout(range(self.input_width), self.outputs_count)) + self.valids = Signal(self.outputs_count) @staticmethod def create( @@ -434,7 +433,7 @@ def create( last: ValueLike, outputs_count: int = 1, name: Optional[str] = None, - ) -> list[tuple[Signal, Signal]]: + ) -> list[tuple[Value, Value]]: """Syntax sugar for creating RingMultiPriorityEncoder This static method allows to use RingMultiPriorityEncoder in a more functional @@ -493,12 +492,12 @@ def create( m.d.comb += prio_encoder.input.eq(input) m.d.comb += prio_encoder.first.eq(first) m.d.comb += prio_encoder.last.eq(last) - return list(zip(prio_encoder.outputs, prio_encoder.valids)) + return [(prio_encoder.outputs[i], prio_encoder.valids[i]) for i in range(outputs_count)] @staticmethod def create_simple( m: Module, input_width: int, input: ValueLike, first: ValueLike, last: ValueLike, name: Optional[str] = None - ) -> tuple[Signal, Signal]: + ) -> tuple[Value, Value]: """Syntax sugar for creating RingMultiPriorityEncoder This is the same as `create` function, but with `outputs_count` hardcoded to 1. @@ -563,10 +562,10 @@ def __init__(self, n: int, shape: ShapeLike): self.n = n self.shape = shape - self.inputs = [Signal(shape) for _ in range(n)] - self.valids = [Signal() for _ in range(n)] + self.inputs = Signal(ArrayLayout(shape, n)) + self.valids = Signal(n) - self.outputs = [Signal(shape) for _ in range(n)] + self.outputs = Signal(ArrayLayout(shape, n)) self.output_cnt = Signal(range(n + 1)) def elaborate(self, platform):