Skip to content

Commit

Permalink
More sensible interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
tilk committed Jan 17, 2025
1 parent 6643132 commit 2ba547d
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions transactron/utils/amaranth_ext/elaboratables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Literal, Optional, overload
from collections.abc import Iterable
from amaranth import *
from amaranth.lib.data import ArrayLayout
from amaranth_types import ShapeLike
from transactron.utils._typing import HasElaborate, ModuleLike, ValueLike

Expand Down Expand Up @@ -272,13 +273,13 @@ def __init__(self, input_width: int, outputs_count: int):
self.outputs_count = outputs_count

self.input = Signal(self.input_width)
self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)]
self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)]
self.outputs = Signal(ArrayLayout(range(self.input_width), self.outputs_count))
self.valids = Signal(self.outputs_count)

@staticmethod
def create(
m: Module, input_width: int, input: ValueLike, outputs_count: int = 1, name: Optional[str] = None
) -> list[tuple[Signal, Signal]]:
) -> list[tuple[Value, Value]]:
"""Syntax sugar for creating MultiPriorityEncoder
This static method allows to use MultiPriorityEncoder in a more functional
Expand Down Expand Up @@ -327,12 +328,10 @@ def create(
except AttributeError:
setattr(m.submodules, name, prio_encoder)
m.d.comb += prio_encoder.input.eq(input)
return list(zip(prio_encoder.outputs, prio_encoder.valids))
return [(prio_encoder.outputs[i], prio_encoder.valids[i]) for i in range(outputs_count)]

@staticmethod
def create_simple(
m: Module, input_width: int, input: ValueLike, name: Optional[str] = None
) -> tuple[Signal, Signal]:
def create_simple(m: Module, input_width: int, input: ValueLike, name: Optional[str] = None) -> tuple[Value, Value]:
"""Syntax sugar for creating MultiPriorityEncoder
This is the same as `create` function, but with `outputs_count` hardcoded to 1.
Expand Down Expand Up @@ -422,8 +421,8 @@ def __init__(self, input_width: int, outputs_count: int):
self.input = Signal(self.input_width)
self.first = Signal(range(self.input_width))
self.last = Signal(range(self.input_width))
self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)]
self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)]
self.outputs = Signal(ArrayLayout(range(self.input_width), self.outputs_count))
self.valids = Signal(self.outputs_count)

@staticmethod
def create(
Expand All @@ -434,7 +433,7 @@ def create(
last: ValueLike,
outputs_count: int = 1,
name: Optional[str] = None,
) -> list[tuple[Signal, Signal]]:
) -> list[tuple[Value, Value]]:
"""Syntax sugar for creating RingMultiPriorityEncoder
This static method allows to use RingMultiPriorityEncoder in a more functional
Expand Down Expand Up @@ -493,12 +492,12 @@ def create(
m.d.comb += prio_encoder.input.eq(input)
m.d.comb += prio_encoder.first.eq(first)
m.d.comb += prio_encoder.last.eq(last)
return list(zip(prio_encoder.outputs, prio_encoder.valids))
return [(prio_encoder.outputs[i], prio_encoder.valids[i]) for i in range(outputs_count)]

@staticmethod
def create_simple(
m: Module, input_width: int, input: ValueLike, first: ValueLike, last: ValueLike, name: Optional[str] = None
) -> tuple[Signal, Signal]:
) -> tuple[Value, Value]:
"""Syntax sugar for creating RingMultiPriorityEncoder
This is the same as `create` function, but with `outputs_count` hardcoded to 1.
Expand Down Expand Up @@ -563,10 +562,10 @@ def __init__(self, n: int, shape: ShapeLike):
self.n = n
self.shape = shape

self.inputs = [Signal(shape) for _ in range(n)]
self.valids = [Signal() for _ in range(n)]
self.inputs = Signal(ArrayLayout(shape, n))
self.valids = Signal(n)

self.outputs = [Signal(shape) for _ in range(n)]
self.outputs = Signal(ArrayLayout(shape, n))
self.output_cnt = Signal(range(n + 1))

def elaborate(self, platform):
Expand Down

0 comments on commit 2ba547d

Please sign in to comment.