From 2ba547da0456297b53f4408fb2795f662d736a43 Mon Sep 17 00:00:00 2001 From: Marek Materzok Date: Fri, 17 Jan 2025 13:14:28 +0100 Subject: [PATCH] More sensible interfaces --- .../utils/amaranth_ext/elaboratables.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/transactron/utils/amaranth_ext/elaboratables.py b/transactron/utils/amaranth_ext/elaboratables.py index 60c9654..253f9bb 100644 --- a/transactron/utils/amaranth_ext/elaboratables.py +++ b/transactron/utils/amaranth_ext/elaboratables.py @@ -3,6 +3,7 @@ from typing import Literal, Optional, overload from collections.abc import Iterable from amaranth import * +from amaranth.lib.data import ArrayLayout from amaranth_types import ShapeLike from transactron.utils._typing import HasElaborate, ModuleLike, ValueLike @@ -272,13 +273,13 @@ def __init__(self, input_width: int, outputs_count: int): self.outputs_count = outputs_count self.input = Signal(self.input_width) - self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)] - self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)] + self.outputs = Signal(ArrayLayout(range(self.input_width), self.outputs_count)) + self.valids = Signal(self.outputs_count) @staticmethod def create( m: Module, input_width: int, input: ValueLike, outputs_count: int = 1, name: Optional[str] = None - ) -> list[tuple[Signal, Signal]]: + ) -> list[tuple[Value, Value]]: """Syntax sugar for creating MultiPriorityEncoder This static method allows to use MultiPriorityEncoder in a more functional @@ -327,12 +328,10 @@ def create( except AttributeError: setattr(m.submodules, name, prio_encoder) m.d.comb += prio_encoder.input.eq(input) - return list(zip(prio_encoder.outputs, prio_encoder.valids)) + return [(prio_encoder.outputs[i], prio_encoder.valids[i]) for i in range(outputs_count)] @staticmethod - def create_simple( - m: Module, input_width: int, input: ValueLike, name: Optional[str] = None - ) -> tuple[Signal, Signal]: + def create_simple(m: Module, input_width: int, input: ValueLike, name: Optional[str] = None) -> tuple[Value, Value]: """Syntax sugar for creating MultiPriorityEncoder This is the same as `create` function, but with `outputs_count` hardcoded to 1. @@ -422,8 +421,8 @@ def __init__(self, input_width: int, outputs_count: int): self.input = Signal(self.input_width) self.first = Signal(range(self.input_width)) self.last = Signal(range(self.input_width)) - self.outputs = [Signal(range(self.input_width), name=f"output_{i}") for i in range(self.outputs_count)] - self.valids = [Signal(name=f"valid_{i}") for i in range(self.outputs_count)] + self.outputs = Signal(ArrayLayout(range(self.input_width), self.outputs_count)) + self.valids = Signal(self.outputs_count) @staticmethod def create( @@ -434,7 +433,7 @@ def create( last: ValueLike, outputs_count: int = 1, name: Optional[str] = None, - ) -> list[tuple[Signal, Signal]]: + ) -> list[tuple[Value, Value]]: """Syntax sugar for creating RingMultiPriorityEncoder This static method allows to use RingMultiPriorityEncoder in a more functional @@ -493,12 +492,12 @@ def create( m.d.comb += prio_encoder.input.eq(input) m.d.comb += prio_encoder.first.eq(first) m.d.comb += prio_encoder.last.eq(last) - return list(zip(prio_encoder.outputs, prio_encoder.valids)) + return [(prio_encoder.outputs[i], prio_encoder.valids[i]) for i in range(outputs_count)] @staticmethod def create_simple( m: Module, input_width: int, input: ValueLike, first: ValueLike, last: ValueLike, name: Optional[str] = None - ) -> tuple[Signal, Signal]: + ) -> tuple[Value, Value]: """Syntax sugar for creating RingMultiPriorityEncoder This is the same as `create` function, but with `outputs_count` hardcoded to 1. @@ -563,10 +562,10 @@ def __init__(self, n: int, shape: ShapeLike): self.n = n self.shape = shape - self.inputs = [Signal(shape) for _ in range(n)] - self.valids = [Signal() for _ in range(n)] + self.inputs = Signal(ArrayLayout(shape, n)) + self.valids = Signal(n) - self.outputs = [Signal(shape) for _ in range(n)] + self.outputs = Signal(ArrayLayout(shape, n)) self.output_cnt = Signal(range(n + 1)) def elaborate(self, platform):