Skip to content

Commit

Permalink
utils: add cyclic_mask (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
piotro888 authored Jan 14, 2025
1 parent bc15c17 commit 60f0540
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
54 changes: 54 additions & 0 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
popcount,
count_leading_zeros,
count_trailing_zeros,
cyclic_mask,
)
from amaranth.utils import ceil_log2

Expand Down Expand Up @@ -191,3 +192,56 @@ async def process(self, sim: TestbenchContext):
def test_count_trailing_zeros(self, size):
with self.run_simulation(self.m) as sim:
sim.add_testbench(self.process)


class GenCyclicMaskTestCircuit(Elaboratable):
def __init__(self, xlen: int):
self.start = Signal(range(xlen))
self.end = Signal(range(xlen))
self.sig_out = Signal(xlen)
self.xlen = xlen

def elaborate(self, platform):
m = Module()

m.d.comb += self.sig_out.eq(cyclic_mask(self.xlen, self.start, self.end))

return m


@pytest.mark.parametrize("size", [1, 2, 3, 5, 8])
class TestGenCyclicMask(TestCaseWithSimulator):
@pytest.fixture(scope="function", autouse=True)
def setup_fixture(self, size):
self.size = size
random.seed(14)
self.test_number = 40
self.m = GenCyclicMaskTestCircuit(self.size)

async def check(self, sim: TestbenchContext, start, end):
sim.set(self.m.start, start)
sim.set(self.m.end, end)
await sim.delay(1e-6)
out = sim.get(self.m.sig_out)

expected = 0
for i in range(min(start, end), max(start, end) + 1):
expected |= 1 << i

if end < start:
expected ^= (1 << self.size) - 1
expected |= 1 << start
expected |= 1 << end

assert out == expected

async def process(self, sim: TestbenchContext):
for _ in range(self.test_number):
start = random.randrange(self.size)
end = random.randrange(self.size)
await self.check(sim, start, end)
await sim.delay(1e-6)

def test_count_trailing_zeros(self, size):
with self.run_simulation(self.m) as sim:
sim.add_testbench(self.process)
21 changes: 21 additions & 0 deletions transactron/utils/amaranth_ext/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"popcount",
"count_leading_zeros",
"count_trailing_zeros",
"cyclic_mask",
"flatten_signals",
"shape_of",
"const_of",
Expand Down Expand Up @@ -76,6 +77,26 @@ def count_trailing_zeros(s: Value) -> Value:
return count_leading_zeros(s[::-1])


def cyclic_mask(bits: int, start: Value, end: Value):
"""
Generate `bits` bit-wide mask with ones from `start` to `end` position, including both ends.
If `end` value is < than `start` the mask wraps around.
"""
start = start.as_unsigned()
end = end.as_unsigned()

# start <= end
length = (end - start + 1).as_unsigned()
mask_se = ((1 << length) - 1) << start

# start > end
left = (1 << (end + 1)) - 1
right = (1 << ((bits - start).as_unsigned())) - 1
mask_es = left | (right << start)

return Mux(start <= end, mask_se, mask_es)


def flatten_signals(signals: SignalBundle) -> Iterable[Signal]:
"""
Flattens input data, which can be either a signal, a record, a list (or a dict) of SignalBundle items.
Expand Down

0 comments on commit 60f0540

Please sign in to comment.