Skip to content

Commit

Permalink
Adapt transactron.lib.storage to amaranth.lib.data (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
piotro888 authored Jan 17, 2025
1 parent 60f0540 commit 2bcd9d3
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 48 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "transactron"
dynamic = ["version"]
dependencies = [
"amaranth == 0.5.3",
"amaranth-stubs @ git+https://github.com/kuznia-rdzeni/amaranth-stubs.git@edb302b001433edf4c8568190adc9bd0c0039f45",
"amaranth-stubs @ git+https://github.com/kuznia-rdzeni/amaranth-stubs.git@a93c5da4b939065b3c60534a8bd143865f3929c4",
"dataclasses-json == 0.6.3",
"tabulate == 0.9.0"
]
Expand Down
18 changes: 11 additions & 7 deletions test/lib/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from hypothesis import given, settings, Phase
from transactron.testing import *
from transactron.lib.storage import *
from transactron.utils.transactron_helpers import make_layout


class TestContentAddressableMemory(TestCaseWithSimulator):
Expand Down Expand Up @@ -157,8 +158,8 @@ def test_mem(
data_width = 6
m = SimpleTestCircuit(
MemoryBank(
data_layout=[("data", data_width)],
elem_count=max_addr,
shape=make_layout(("data_field", data_width)),
depth=max_addr,
transparent=transparent,
read_ports=read_ports,
write_ports=write_ports,
Expand All @@ -175,7 +176,7 @@ async def process(sim: TestbenchContext):
for cycle in range(test_count):
d = random.randrange(2**data_width)
a = random.randrange(max_addr)
await m.write[i].call(sim, data={"data": d}, addr=a)
await m.write[i].call(sim, data={"data_field": d}, addr=a)
await sim.delay(1e-9 * (i + 2 if not transparent else i))
data[a] = d
await self.random_wait(sim, writer_rand)
Expand All @@ -202,7 +203,7 @@ async def process(sim: TestbenchContext):
await self.random_wait(sim, reader_resp_rand or 1, min_cycle_cnt=1)
await sim.delay(1e-9 * (write_ports + 3))
d = read_req_queues[i].popleft()
assert (await m.read_resp[i].call(sim)).data == d
assert (await m.read_resp[i].call(sim)).data.data_field == d
await self.random_wait(sim, reader_resp_rand)

return process
Expand Down Expand Up @@ -230,7 +231,10 @@ def test_mem(self, max_addr: int, writer_rand: int, reader_rand: int, seed: int,
data_width = 6
m = SimpleTestCircuit(
AsyncMemoryBank(
data_layout=[("data", data_width)], elem_count=max_addr, read_ports=read_ports, write_ports=write_ports
shape=make_layout(("data_field", data_width)),
depth=max_addr,
read_ports=read_ports,
write_ports=write_ports,
),
)

Expand All @@ -243,7 +247,7 @@ async def process(sim: TestbenchContext):
for cycle in range(test_count):
d = random.randrange(2**data_width)
a = random.randrange(max_addr)
await m.write[i].call(sim, data={"data": d}, addr=a)
await m.write[i].call(sim, data={"data_field": d}, addr=a)
await sim.delay(1e-9 * (i + 2))
data[a] = d
await self.random_wait(sim, writer_rand, min_cycle_cnt=1)
Expand All @@ -257,7 +261,7 @@ async def process(sim: TestbenchContext):
d = await m.read[i].call(sim, addr=a)
await sim.delay(1e-9)
expected_d = data[a]
assert d["data"] == expected_d
assert d["data"]["data_field"] == expected_d
await self.random_wait(sim, reader_rand, min_cycle_cnt=1)

return process
Expand Down
5 changes: 3 additions & 2 deletions transactron/lib/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from transactron import Method, def_method, TModule
from transactron.lib import FIFO, AsyncMemoryBank, logging
from transactron.utils.dependencies import ListKey, DependencyContext, SimpleKey
from transactron.utils.transactron_helpers import make_layout

__all__ = [
"MetricRegisterModel",
Expand Down Expand Up @@ -661,7 +662,7 @@ def elaborate(self, platform):
epoch_width = bits_for(self.max_latency)

m.submodules.slots = self.slots = AsyncMemoryBank(
data_layout=[("epoch", epoch_width)], elem_count=self.slots_number
shape=make_layout(("epoch", epoch_width)), depth=self.slots_number
)
m.submodules.histogram = self.histogram

Expand Down Expand Up @@ -690,7 +691,7 @@ def _(slot: Value):
ret = self.slots.read(m, addr=slot)
# The result of substracting two unsigned n-bit is a signed (n+1)-bit value,
# so we need to cast the result and discard the most significant bit.
duration = (epoch - ret.epoch).as_unsigned()[:-1]
duration = (epoch - ret.data.epoch).as_unsigned()[:-1]
self.histogram.add(m, duration)

return m
Expand Down
94 changes: 56 additions & 38 deletions transactron/lib/storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from amaranth import *
from amaranth.lib.data import View
from amaranth.utils import *
import amaranth.lib.memory as memory
from amaranth_types import ShapeLike
import amaranth_types.memory as amemory

from transactron.utils.transactron_helpers import from_method_layout, make_layout
Expand Down Expand Up @@ -35,25 +37,26 @@ class MemoryBank(Elaboratable):
def __init__(
self,
*,
data_layout: LayoutList,
elem_count: int,
shape: ShapeLike,
depth: int,
granularity: Optional[int] = None,
transparent: bool = False,
read_ports: int = 1,
write_ports: int = 1,
memory_type: amemory.AbstractMemoryConstructor[int, Value] = memory.Memory,
memory_type: amemory.AbstractMemoryConstructor[ShapeLike, Value] = memory.Memory,
src_loc: int | SrcLoc = 0,
):
"""
Parameters
----------
data_layout: method layout
shape: ShapeLike
The format of structures stored in the Memory.
elem_count: int
depth: int
Number of elements stored in Memory.
granularity: Optional[int]
Granularity of write, forwarded to Amaranth. If `None` the whole structure is always saved at once.
If not, the width of `data_layout` is split into `granularity` parts, which can be saved independently.
Granularity of write. If `None` the whole structure is always saved at once.
If not, shape is split into `granularity` parts, which can be saved independently (according to
`amaranth.lib.memory` granularity logic).
transparent: bool
Read port transparency, false by default. When a read port is transparent, if a given memory address
is read and written in the same clock cycle, the read returns the written value instead of the value
Expand All @@ -67,37 +70,43 @@ def __init__(
Alternatively, the source location to use instead of the default.
"""
self.src_loc = get_src_loc(src_loc)
self.data_layout = make_layout(*data_layout)
self.elem_count = elem_count
self.shape = shape
self.depth = depth
self.granularity = granularity
self.width = from_method_layout(self.data_layout).size
self.addr_width = bits_for(self.elem_count - 1)
self.addr_width = bits_for(self.depth - 1)
self.transparent = transparent
self.reads_ports = read_ports
self.writes_ports = write_ports
self.memory_type = memory_type

self.read_reqs_layout: LayoutList = [("addr", self.addr_width)]
write_layout = [("addr", self.addr_width), ("data", self.data_layout)]
self.read_resps_layout = make_layout(("data", self.shape))
write_layout = [("addr", self.addr_width), ("data", self.shape)]
if self.granularity is not None:
write_layout.append(("mask", self.width // self.granularity))
# use Amaranth lib.memory granularity rule checks and width
amaranth_write_port_sig = memory.WritePort.Signature(
addr_width=0,
shape=self.shape, # type: ignore
granularity=granularity,
)
write_layout.append(("mask", amaranth_write_port_sig.members["en"].shape))
self.writes_layout = make_layout(*write_layout)

self.read_req = Methods(read_ports, i=self.read_reqs_layout, src_loc=self.src_loc)
self.read_resp = Methods(read_ports, o=self.data_layout, src_loc=self.src_loc)
self.read_resp = Methods(read_ports, o=self.read_resps_layout, src_loc=self.src_loc)
self.write = Methods(write_ports, i=self.writes_layout, src_loc=self.src_loc)

def elaborate(self, platform) -> TModule:
m = TModule()

m.submodules.mem = self.mem = mem = self.memory_type(shape=self.width, depth=self.elem_count, init=[])
write_port = [mem.write_port() for _ in range(self.writes_ports)]
m.submodules.mem = self.mem = mem = self.memory_type(shape=self.shape, depth=self.depth, init=[])
write_port = [mem.write_port(granularity=self.granularity) for _ in range(self.writes_ports)]
read_port = [
mem.read_port(transparent_for=write_port if self.transparent else []) for _ in range(self.reads_ports)
]
read_output_valid = [Signal() for _ in range(self.reads_ports)]
overflow_valid = [Signal() for _ in range(self.reads_ports)]
overflow_data = [Signal(self.width) for _ in range(self.reads_ports)]
overflow_data = [Signal(self.shape) for _ in range(self.reads_ports)]

# The read request method can be called at most twice when not reading the response.
# The first result is stored in the overflow buffer, the second - in the read value buffer of the memory.
Expand All @@ -114,7 +123,9 @@ def _(i: int):
m.d.sync += overflow_valid[i].eq(0)
with m.Else():
m.d.sync += read_output_valid[i].eq(0)
return Mux(overflow_valid[i], overflow_data[i], read_port[i].data)

# Amaranth Mux drops lib.data Layouts
return {"data": View(self.shape, Mux(overflow_valid[i], overflow_data[i], read_port[i].data))}

for i in range(self.reads_ports):
m.d.comb += read_port[i].en.eq(0) # because the init value is 1
Expand All @@ -123,12 +134,12 @@ def _(i: int):
def _(i: int, addr):
m.d.sync += read_output_valid[i].eq(1)
m.d.comb += read_port[i].en.eq(1)
m.d.comb += read_port[i].addr.eq(addr)
m.d.av_comb += read_port[i].addr.eq(addr)

@def_methods(m, self.write)
def _(i: int, arg):
m.d.comb += write_port[i].addr.eq(arg.addr)
m.d.comb += write_port[i].data.eq(arg.data)
m.d.av_comb += write_port[i].addr.eq(arg.addr)
m.d.av_comb += write_port[i].data.eq(arg.data)
if self.granularity is None:
m.d.comb += write_port[i].en.eq(1)
else:
Expand Down Expand Up @@ -254,24 +265,25 @@ class AsyncMemoryBank(Elaboratable):
def __init__(
self,
*,
data_layout: LayoutList,
elem_count: int,
shape: ShapeLike,
depth: int,
granularity: Optional[int] = None,
read_ports: int = 1,
write_ports: int = 1,
memory_type: amemory.AbstractMemoryConstructor[int, Value] = memory.Memory,
memory_type: amemory.AbstractMemoryConstructor[ShapeLike, Value] = memory.Memory,
src_loc: int | SrcLoc = 0,
):
"""
Parameters
----------
data_layout: method layout
shape: ShapeLike
The format of structures stored in the Memory.
elem_count: int
depth: int
Number of elements stored in Memory.
granularity: Optional[int]
Granularity of write, forwarded to Amaranth. If `None` the whole structure is always saved at once.
If not, the width of `data_layout` is split into `granularity` parts, which can be saved independently.
Granularity of write. If `None` the whole structure is always saved at once.
If not, shape is split into `granularity` parts, which can be saved independently (according to
`amaranth.lib.memory` granularity logic).
read_ports: int
Number of read ports.
write_ports: int
Expand All @@ -281,36 +293,42 @@ def __init__(
Alternatively, the source location to use instead of the default.
"""
self.src_loc = get_src_loc(src_loc)
self.data_layout = make_layout(*data_layout)
self.elem_count = elem_count
self.shape = shape
self.depth = depth
self.granularity = granularity
self.width = from_method_layout(self.data_layout).size
self.addr_width = bits_for(self.elem_count - 1)
self.addr_width = bits_for(self.depth - 1)
self.reads_ports = read_ports
self.writes_ports = write_ports
self.memory_type = memory_type

self.read_reqs_layout: LayoutList = [("addr", self.addr_width)]
write_layout = [("addr", self.addr_width), ("data", self.data_layout)]
self.read_resps_layout: LayoutList = [("data", self.shape)]
write_layout = [("addr", self.addr_width), ("data", self.shape)]
if self.granularity is not None:
write_layout.append(("mask", self.width // self.granularity))
# use Amaranth lib.memory granularity rule checks and width
amaranth_write_port_sig = memory.WritePort.Signature(
addr_width=0,
shape=shape, # type: ignore
granularity=granularity,
)
write_layout.append(("mask", amaranth_write_port_sig.members["en"].shape))
self.writes_layout = make_layout(*write_layout)

self.read = Methods(read_ports, i=self.read_reqs_layout, o=self.data_layout, src_loc=self.src_loc)
self.read = Methods(read_ports, i=self.read_reqs_layout, o=self.read_resps_layout, src_loc=self.src_loc)
self.write = Methods(write_ports, i=self.writes_layout, src_loc=self.src_loc)

def elaborate(self, platform) -> TModule:
m = TModule()

mem = self.memory_type(shape=self.width, depth=self.elem_count, init=[])
mem = self.memory_type(shape=self.shape, depth=self.depth, init=[])
m.submodules.mem = self.mem = mem
write_port = [mem.write_port() for _ in range(self.writes_ports)]
write_port = [mem.write_port(granularity=self.granularity) for _ in range(self.writes_ports)]
read_port = [mem.read_port(domain="comb") for _ in range(self.reads_ports)]

@def_methods(m, self.read)
def _(i: int, addr):
m.d.comb += read_port[i].addr.eq(addr)
return read_port[i].data
return {"data": read_port[i].data}

@def_methods(m, self.write)
def _(i: int, arg):
Expand Down

0 comments on commit 2bcd9d3

Please sign in to comment.