Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move StableSelectingNetwork out of lib #40

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pytest
import random

from transactron.lib import StableSelectingNetwork
from transactron.utils import StableSelectingNetwork
from transactron.testing import TestCaseWithSimulator, TestbenchContext


class TestStableSelectingNetwork(TestCaseWithSimulator):

@pytest.mark.parametrize("n", [2, 3, 7, 8])
def test(self, n: int):
m = StableSelectingNetwork(n, [("data", 8)])
m = StableSelectingNetwork(n, 8)

random.seed(42)

Expand All @@ -22,13 +22,13 @@ async def process(sim: TestbenchContext):
expected_output_prefix = []
for i in range(n):
sim.set(m.valids[i], valids[i])
sim.set(m.inputs[i].data, inputs[i])
sim.set(m.inputs[i], inputs[i])

if valids[i]:
expected_output_prefix.append(inputs[i])

for i in range(total):
out = sim.get(m.outputs[i].data)
out = sim.get(m.outputs[i])
assert out == expected_output_prefix[i]

assert sim.get(m.output_cnt) == total
Expand Down
80 changes: 0 additions & 80 deletions transactron/lib/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"Connect",
"ConnectTrans",
"ManyToOneConnectTrans",
"StableSelectingNetwork",
"Pipe",
]

Expand Down Expand Up @@ -343,82 +342,3 @@ def elaborate(self, platform):
)

return m


class StableSelectingNetwork(Elaboratable):
"""A network that groups inputs with a valid bit set.

The circuit takes `n` inputs with a valid signal each and
on the output returns a grouped and consecutive sequence of the provided
input signals. The order of valid inputs is preserved.

For example for input (0 is an invalid input):
0, a, 0, d, 0, 0, e

The circuit will return:
a, d, e, 0, 0, 0, 0

The circuit uses a divide and conquer algorithm.
The recursive call takes two bit vectors and each of them
is already properly sorted, for example:
v1 = [a, b, 0, 0]; v2 = [c, d, e, 0]

Now by shifting left v2 and merging it with v1, we get the result:
v = [a, b, c, d, e, 0, 0, 0]

Thus, the network has depth log_2(n).

"""

def __init__(self, n: int, layout: MethodLayout):
self.n = n
self.layout = from_method_layout(layout)

self.inputs = [Signal(self.layout) for _ in range(n)]
self.valids = [Signal() for _ in range(n)]

self.outputs = [Signal(self.layout) for _ in range(n)]
self.output_cnt = Signal(range(n + 1))

def elaborate(self, platform):
m = TModule()

current_level = []
for i in range(self.n):
current_level.append((Array([self.inputs[i]]), self.valids[i]))

# Create the network using the bottom-up approach.
while len(current_level) >= 2:
next_level = []
while len(current_level) >= 2:
a, cnt_a = current_level.pop(0)
b, cnt_b = current_level.pop(0)

total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1)
m.d.comb += total_cnt.eq(cnt_a + cnt_b)

total_len = len(a) + len(b)
merged = Array(Signal(self.layout) for _ in range(total_len))

for i in range(len(a)):
m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i]))
for i in range(len(b)):
m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a]))

next_level.append((merged, total_cnt))

# If we had an odd number of elements on the current level,
# move the item left to the next level.
if len(current_level) == 1:
next_level.append(current_level.pop(0))

current_level = next_level

last_level, total_cnt = current_level.pop(0)

for i in range(self.n):
m.d.comb += self.outputs[i].eq(last_level[i])

m.d.comb += self.output_cnt.eq(total_cnt)

return m
104 changes: 92 additions & 12 deletions transactron/utils/amaranth_ext/elaboratables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Literal, Optional, overload
from collections.abc import Iterable
from amaranth import *
from amaranth.lib.data import ArrayLayout
from amaranth_types import ShapeLike
from transactron.utils._typing import HasElaborate, ModuleLike, ValueLike

__all__ = [
Expand All @@ -13,6 +15,7 @@
"RoundRobin",
"MultiPriorityEncoder",
"RingMultiPriorityEncoder",
"StableSelectingNetwork",
]


Expand Down Expand Up @@ -270,13 +273,13 @@ def __init__(self, input_width: int, outputs_count: int):
self.outputs_count = outputs_count

self.input = Signal(self.input_width)
self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)]
self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)]
self.outputs = Signal(ArrayLayout(range(self.input_width), self.outputs_count))
self.valids = Signal(self.outputs_count)

@staticmethod
def create(
m: Module, input_width: int, input: ValueLike, outputs_count: int = 1, name: Optional[str] = None
) -> list[tuple[Signal, Signal]]:
) -> list[tuple[Value, Value]]:
"""Syntax sugar for creating MultiPriorityEncoder

This static method allows to use MultiPriorityEncoder in a more functional
Expand Down Expand Up @@ -325,12 +328,10 @@ def create(
except AttributeError:
setattr(m.submodules, name, prio_encoder)
m.d.comb += prio_encoder.input.eq(input)
return list(zip(prio_encoder.outputs, prio_encoder.valids))
return [(prio_encoder.outputs[i], prio_encoder.valids[i]) for i in range(outputs_count)]

@staticmethod
def create_simple(
m: Module, input_width: int, input: ValueLike, name: Optional[str] = None
) -> tuple[Signal, Signal]:
def create_simple(m: Module, input_width: int, input: ValueLike, name: Optional[str] = None) -> tuple[Value, Value]:
"""Syntax sugar for creating MultiPriorityEncoder

This is the same as `create` function, but with `outputs_count` hardcoded to 1.
Expand Down Expand Up @@ -420,8 +421,8 @@ def __init__(self, input_width: int, outputs_count: int):
self.input = Signal(self.input_width)
self.first = Signal(range(self.input_width))
self.last = Signal(range(self.input_width))
self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)]
self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)]
self.outputs = Signal(ArrayLayout(range(self.input_width), self.outputs_count))
self.valids = Signal(self.outputs_count)

@staticmethod
def create(
Expand All @@ -432,7 +433,7 @@ def create(
last: ValueLike,
outputs_count: int = 1,
name: Optional[str] = None,
) -> list[tuple[Signal, Signal]]:
) -> list[tuple[Value, Value]]:
"""Syntax sugar for creating RingMultiPriorityEncoder

This static method allows to use RingMultiPriorityEncoder in a more functional
Expand Down Expand Up @@ -491,12 +492,12 @@ def create(
m.d.comb += prio_encoder.input.eq(input)
m.d.comb += prio_encoder.first.eq(first)
m.d.comb += prio_encoder.last.eq(last)
return list(zip(prio_encoder.outputs, prio_encoder.valids))
return [(prio_encoder.outputs[i], prio_encoder.valids[i]) for i in range(outputs_count)]

@staticmethod
def create_simple(
m: Module, input_width: int, input: ValueLike, first: ValueLike, last: ValueLike, name: Optional[str] = None
) -> tuple[Signal, Signal]:
) -> tuple[Value, Value]:
"""Syntax sugar for creating RingMultiPriorityEncoder

This is the same as `create` function, but with `outputs_count` hardcoded to 1.
Expand Down Expand Up @@ -530,3 +531,82 @@ def elaborate(self, platform):
m.d.comb += self.outputs[k].eq(corrected_out)
m.d.comb += self.valids[k].eq(multi_enc.valids[k])
return m


class StableSelectingNetwork(Elaboratable):
"""A network that groups inputs with a valid bit set.

The circuit takes `n` inputs with a valid signal each and
on the output returns a grouped and consecutive sequence of the provided
input signals. The order of valid inputs is preserved.

For example for input (0 is an invalid input):
0, a, 0, d, 0, 0, e

The circuit will return:
a, d, e, 0, 0, 0, 0

The circuit uses a divide and conquer algorithm.
The recursive call takes two bit vectors and each of them
is already properly sorted, for example:
v1 = [a, b, 0, 0]; v2 = [c, d, e, 0]

Now by shifting left v2 and merging it with v1, we get the result:
v = [a, b, c, d, e, 0, 0, 0]

Thus, the network has depth log_2(n).

"""

def __init__(self, n: int, shape: ShapeLike):
self.n = n
self.shape = shape

self.inputs = Signal(ArrayLayout(shape, n))
self.valids = Signal(n)

self.outputs = Signal(ArrayLayout(shape, n))
self.output_cnt = Signal(range(n + 1))

def elaborate(self, platform):
m = Module()

current_level = []
for i in range(self.n):
current_level.append((Array([self.inputs[i]]), self.valids[i]))

# Create the network using the bottom-up approach.
while len(current_level) >= 2:
next_level = []
while len(current_level) >= 2:
a, cnt_a = current_level.pop(0)
b, cnt_b = current_level.pop(0)

total_cnt = Signal(max(len(cnt_a), len(cnt_b)) + 1)
m.d.comb += total_cnt.eq(cnt_a + cnt_b)

total_len = len(a) + len(b)
merged = Array(Signal(self.shape) for _ in range(total_len))

for i in range(len(a)):
m.d.comb += merged[i].eq(Mux(cnt_a <= i, b[i - cnt_a], a[i]))
for i in range(len(b)):
m.d.comb += merged[len(a) + i].eq(Mux(len(a) + i - cnt_a >= len(b), 0, b[len(a) + i - cnt_a]))

next_level.append((merged, total_cnt))

# If we had an odd number of elements on the current level,
# move the item left to the next level.
if len(current_level) == 1:
next_level.append(current_level.pop(0))

current_level = next_level

last_level, total_cnt = current_level.pop(0)

for i in range(self.n):
m.d.comb += self.outputs[i].eq(last_level[i])

m.d.comb += self.output_cnt.eq(total_cnt)

return m
Loading