From 40c201414af3ff687ebd585673feca339b3009aa Mon Sep 17 00:00:00 2001 From: John Demme Date: Fri, 20 Dec 2024 18:53:00 -0500 Subject: [PATCH] [PyCDE] Add fork, join, and merge channel functions (#8011) - .fork creates two new channels, waits until they are both available, then accepts an input. It also buffers the output channels to avoid combinational loops. - Channel.join waits on two channels then creates a message on the one output channel containing a struct with field 'a' equal to input channel A's content and likewise for channel B. - Channel.merge funnels two channels together into a single output stream. This is functionality which really should be handled by the DC dialect but it's not ready for prime time. --- .../PyCDE/integration_test/esi_advanced.py | 101 ++++++++++++++++++ .../test_software/esi_advanced.py | 53 +++++++++ frontends/PyCDE/src/pycde/bsp/__init__.py | 14 +++ frontends/PyCDE/src/pycde/signals.py | 18 ++++ frontends/PyCDE/src/pycde/types.py | 62 ++++++++++- frontends/PyCDE/test/test_esi_advanced.py | 88 +++++++++++++++ 6 files changed, 332 insertions(+), 4 deletions(-) create mode 100644 frontends/PyCDE/integration_test/esi_advanced.py create mode 100644 frontends/PyCDE/integration_test/test_software/esi_advanced.py create mode 100644 frontends/PyCDE/test/test_esi_advanced.py diff --git a/frontends/PyCDE/integration_test/esi_advanced.py b/frontends/PyCDE/integration_test/esi_advanced.py new file mode 100644 index 000000000000..f0967d598157 --- /dev/null +++ b/frontends/PyCDE/integration_test/esi_advanced.py @@ -0,0 +1,101 @@ +# REQUIRES: esi-runtime, esi-cosim, rtl-sim +# RUN: rm -rf %t +# RUN: mkdir %t && cd %t +# RUN: %PYTHON% %s %t 2>&1 +# RUN: esi-cosim.py -- %PYTHON% %S/test_software/esi_advanced.py cosim env + +import sys + +from pycde import generator, Clock, Module, Reset, System +from pycde.bsp import get_bsp +from pycde.common import InputChannel, OutputChannel, Output +from pycde.types import Bits, UInt +from pycde import esi + + 
+class Merge(Module): + clk = Clock() + rst = Reset() + a = InputChannel(UInt(8)) + b = InputChannel(UInt(8)) + + x = OutputChannel(UInt(8)) + + @generator + def build(ports): + chan = ports.a.type.merge(ports.a, ports.b) + ports.x = chan + + +class Join(Module): + clk = Clock() + rst = Reset() + a = InputChannel(UInt(8)) + b = InputChannel(UInt(8)) + + x = OutputChannel(UInt(9)) + + @generator + def build(ports): + joined = ports.a.type.join(ports.a, ports.b) + ports.x = joined.transform(lambda x: x.a + x.b) + + +class Fork(Module): + clk = Clock() + rst = Reset() + a = InputChannel(UInt(8)) + + x = OutputChannel(UInt(8)) + y = OutputChannel(UInt(8)) + + @generator + def build(ports): + x, y = ports.a.fork(ports.clk, ports.rst) + ports.x = x + ports.y = y + + +class Top(Module): + clk = Clock() + rst = Reset() + + @generator + def build(ports): + clk = ports.clk + rst = ports.rst + merge_a = esi.ChannelService.from_host(esi.AppID("merge_a"), + UInt(8)).buffer(clk, rst, 1) + merge_b = esi.ChannelService.from_host(esi.AppID("merge_b"), + UInt(8)).buffer(clk, rst, 1) + merge = Merge("merge_i8", + clk=ports.clk, + rst=ports.rst, + a=merge_a, + b=merge_b) + esi.ChannelService.to_host(esi.AppID("merge_x"), + merge.x.buffer(clk, rst, 1)) + + join_a = esi.ChannelService.from_host(esi.AppID("join_a"), + UInt(8)).buffer(clk, rst, 1) + join_b = esi.ChannelService.from_host(esi.AppID("join_b"), + UInt(8)).buffer(clk, rst, 1) + join = Join("join_i8", clk=ports.clk, rst=ports.rst, a=join_a, b=join_b) + esi.ChannelService.to_host( + esi.AppID("join_x"), + join.x.buffer(clk, rst, 1).transform(lambda x: x.as_uint(16))) + + fork_a = esi.ChannelService.from_host(esi.AppID("fork_a"), + UInt(8)).buffer(clk, rst, 1) + fork = Fork("fork_i8", clk=ports.clk, rst=ports.rst, a=fork_a) + esi.ChannelService.to_host(esi.AppID("fork_x"), fork.x.buffer(clk, rst, 1)) + esi.ChannelService.to_host(esi.AppID("fork_y"), fork.y.buffer(clk, rst, 1)) + + +if __name__ == "__main__": + bsp = 
get_bsp(sys.argv[2] if len(sys.argv) > 2 else None) + s = System(bsp(Top), name="ESIAdvanced", output_directory=sys.argv[1]) + s.generate() + s.run_passes() + s.compile() + s.package() diff --git a/frontends/PyCDE/integration_test/test_software/esi_advanced.py b/frontends/PyCDE/integration_test/test_software/esi_advanced.py new file mode 100644 index 000000000000..aa07f66522f8 --- /dev/null +++ b/frontends/PyCDE/integration_test/test_software/esi_advanced.py @@ -0,0 +1,53 @@ +import esiaccel as esi +import sys + +platform = sys.argv[1] +acc = esi.AcceleratorConnection(platform, sys.argv[2]) + +d = acc.build_accelerator() + +merge_a = d.ports[esi.AppID("merge_a")].write_port("data") +merge_a.connect() +merge_b = d.ports[esi.AppID("merge_b")].write_port("data") +merge_b.connect() +merge_x = d.ports[esi.AppID("merge_x")].read_port("data") +merge_x.connect() + +for i in range(10, 15): + merge_a.write(i) + merge_b.write(i + 10) + x1 = merge_x.read() + x2 = merge_x.read() + print(f"merge_a: {i}, merge_b: {i + 10}, " + f"merge_x 1: {x1}, merge_x 2: {x2}") + assert x1 == i + 10 or x1 == i + assert x2 == i + 10 or x2 == i + assert x1 != x2 + +join_a = d.ports[esi.AppID("join_a")].write_port("data") +join_a.connect() +join_b = d.ports[esi.AppID("join_b")].write_port("data") +join_b.connect() +join_x = d.ports[esi.AppID("join_x")].read_port("data") +join_x.connect() + +for i in range(15, 27): + join_a.write(i) + join_b.write(i + 10) + x = join_x.read() + print(f"join_a: {i}, join_b: {i + 10}, join_x: {x}") + assert x == (i + i + 10) & 0xFFFF + +fork_a = d.ports[esi.AppID("fork_a")].write_port("data") +fork_a.connect() +fork_x = d.ports[esi.AppID("fork_x")].read_port("data") +fork_x.connect() +fork_y = d.ports[esi.AppID("fork_y")].read_port("data") +fork_y.connect() + +for i in range(27, 33): + fork_a.write(i) + x = fork_x.read() + y = fork_y.read() + print(f"fork_a: {i}, fork_x: {x}, fork_y: {y}") + assert x == y diff --git a/frontends/PyCDE/src/pycde/bsp/__init__.py 
b/frontends/PyCDE/src/pycde/bsp/__init__.py index c4c37e67ab9b..9e2f7f23715b 100644 --- a/frontends/PyCDE/src/pycde/bsp/__init__.py +++ b/frontends/PyCDE/src/pycde/bsp/__init__.py @@ -2,5 +2,19 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional + from .cosim import CosimBSP from .xrt import XrtBSP + + +def get_bsp(name: Optional[str] = None): + if name is None or name == "cosim": + return CosimBSP + elif name == "xrt": + return XrtBSP + elif name == "xrt_cosim": + from .xrt import XrtCosimBSP + return XrtCosimBSP + else: + raise ValueError(f"Unknown bsp type: {name}") diff --git a/frontends/PyCDE/src/pycde/signals.py b/frontends/PyCDE/src/pycde/signals.py index 689b5c816bc5..8cd0e0459d1f 100644 --- a/frontends/PyCDE/src/pycde/signals.py +++ b/frontends/PyCDE/src/pycde/signals.py @@ -155,6 +155,9 @@ def name(self, new: str): else: self._name = new + def get_name(self, default: str = "") -> str: + return self.name if self.name is not None else default + @property def appid(self) -> Optional[object]: # Optional AppID. 
from .module import AppID @@ -752,6 +755,21 @@ def transform(self, transform: Callable[[Signal], Signal]) -> ChannelSignal: ready_wire.assign(ready) return ret_chan + def fork(self, clk, rst) -> Tuple[ChannelSignal, ChannelSignal]: + """Fork the channel into two channels, returning the two new channels.""" + from .constructs import Wire + from .types import Bits + both_ready = Wire(Bits(1)) + both_ready.name = self.get_name() + "_fork_both_ready" + data, valid = self.unwrap(both_ready) + valid_gate = both_ready & valid + a, a_rdy = self.type.wrap(data, valid_gate) + b, b_rdy = self.type.wrap(data, valid_gate) + abuf = a.buffer(clk, rst, 1) + bbuf = b.buffer(clk, rst, 1) + both_ready.assign(a_rdy & b_rdy) + return abuf, bbuf + class BundleSignal(Signal): """Signal for types.Bundle.""" diff --git a/frontends/PyCDE/src/pycde/types.py b/frontends/PyCDE/src/pycde/types.py index 62b26fc7c970..008923717539 100644 --- a/frontends/PyCDE/src/pycde/types.py +++ b/frontends/PyCDE/src/pycde/types.py @@ -598,7 +598,7 @@ def inner(self): return self.inner_type def wrap(self, value, - valueOrEmpty) -> typing.Tuple["ChannelSignal", "BitsSignal"]: + valid_or_empty) -> typing.Tuple["ChannelSignal", "BitsSignal"]: """Wrap a data signal and valid signal into a data channel signal and a ready signal.""" @@ -608,21 +608,75 @@ def wrap(self, value, # one. 
from .dialects import esi + from .signals import Signal signaling = self.signaling if signaling == ChannelSignaling.ValidReady: - value = self.inner_type(value) - valid = types.i1(valueOrEmpty) + if not isinstance(value, Signal): + value = self.inner_type(value) + elif value.type != self.inner_type: + raise TypeError( + f"Expected signal of type {self.inner_type}, got {value.type}") + valid = Bits(1)(valid_or_empty) wrap_op = esi.WrapValidReadyOp(self._type, types.i1, value.value, valid.value) return wrap_op[0], wrap_op[1] elif signaling == ChannelSignaling.FIFO: value = self.inner_type(value) - empty = types.i1(valueOrEmpty) + empty = Bits(1)(valid_or_empty) wrap_op = esi.WrapFIFOOp(self._type, types.i1, value.value, empty.value) return wrap_op[0], wrap_op[1] else: raise TypeError("Unknown signaling standard") + def _join(self, a: "ChannelSignal", b: "ChannelSignal") -> "ChannelSignal": + """Join two channels into a single channel. The resulting type is a struct + with two fields, 'a' and 'b' wherein 'a' is the data from channel a and 'b' + is the data from channel b.""" + + from .constructs import Wire + both_ready = Wire(Bits(1)) + a_data, a_valid = a.unwrap(both_ready) + b_data, b_valid = b.unwrap(both_ready) + both_valid = a_valid & b_valid + result_data = self.inner_type({"a": a_data, "b": b_data}) + result_chan, result_ready = self.wrap(result_data, both_valid) + both_ready.assign(result_ready & both_valid) + return result_chan + + @staticmethod + def join(a: "ChannelSignal", b: "ChannelSignal") -> "ChannelSignal": + """Join two channels into a single channel. 
The resulting type is a struct + with two fields, 'a' and 'b' wherein 'a' is the data from channel a and 'b' + is the data from channel b.""" + + from .types import Channel, StructType + return Channel( + StructType([("a", a.type.inner_type), + ("b", b.type.inner_type)]))._join(a, b) + + def merge(self, a: "ChannelSignal", b: "ChannelSignal") -> "ChannelSignal": + """Merge two channels into a single channel, selecting a message from either + one. May implement any sort of fairness policy. Both channels must be of the + same type. Returns the merged channel.""" + + from .constructs import Mux, Wire + a_ready = Wire(Bits(1)) + b_ready = Wire(Bits(1)) + a_data, a_valid = a.unwrap(a_ready) + b_data, b_valid = b.unwrap(b_ready) + + sel_a = a_valid + sel_b = ~sel_a + out_ready = Wire(Bits(1)) + a_ready.assign(sel_a & out_ready) + b_ready.assign(sel_b & out_ready) + + valid = (sel_a & a_valid) | (sel_b & b_valid) + data = Mux(sel_a, b_data, a_data) + chan, ready = self.wrap(data, valid) + out_ready.assign(ready) + return chan + @dataclass class BundledChannel: diff --git a/frontends/PyCDE/test/test_esi_advanced.py b/frontends/PyCDE/test/test_esi_advanced.py new file mode 100644 index 000000000000..8a5b6e74087d --- /dev/null +++ b/frontends/PyCDE/test/test_esi_advanced.py @@ -0,0 +1,88 @@ +# RUN: %PYTHON% %s | FileCheck %s + +from pycde import generator, Clock, Module, Reset +from pycde.common import InputChannel, OutputChannel +from pycde.testing import unittestmodule +from pycde.types import Bits, UInt + +# CHECK-LABEL: hw.module @Merge(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel, in %b : !esi.channel, out x : !esi.channel) +# CHECK-NEXT: %rawOutput, %valid = esi.unwrap.vr %a, [[R1:%.+]] : i8 +# CHECK-NEXT: %rawOutput_0, %valid_1 = esi.unwrap.vr %b, [[R2:%.+]] : i8 +# CHECK-NEXT: %true = hw.constant true +# CHECK-NEXT: [[R0:%.+]] = comb.xor bin %valid, %true : i1 +# CHECK-NEXT: [[R1]] = comb.and bin %valid, %ready : i1 +# CHECK-NEXT: [[R2]] = comb.and 
bin [[R0]], %ready : i1 +# CHECK-NEXT: [[R3:%.+]] = comb.and bin %valid, %valid : i1 +# CHECK-NEXT: [[R4:%.+]] = comb.and bin [[R0]], %valid_1 : i1 +# CHECK-NEXT: [[R5:%.+]] = comb.or bin [[R3]], [[R4]] : i1 +# CHECK-NEXT: [[R6:%.+]] = comb.mux bin %valid, %rawOutput, %rawOutput_0 +# CHECK-NEXT: %chanOutput, %ready = esi.wrap.vr [[R6]], [[R5]] : i8 +# CHECK-NEXT: hw.output %chanOutput : !esi.channel + + +@unittestmodule() +class Merge(Module): + clk = Clock() + rst = Reset() + a = InputChannel(Bits(8)) + b = InputChannel(Bits(8)) + + x = OutputChannel(Bits(8)) + + @generator + def build(ports): + chan = ports.a.type.merge(ports.a, ports.b) + ports.x = chan + + +# CHECK-LABEL: hw.module @Join(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel, in %b : !esi.channel, out x : !esi.channel) +# CHECK-NEXT: %rawOutput, %valid = esi.unwrap.vr %a, [[R2:%.+]] : ui8 +# CHECK-NEXT: %rawOutput_0, %valid_1 = esi.unwrap.vr %b, [[R2]] : ui8 +# CHECK-NEXT: [[R0:%.+]] = comb.and bin %valid, %valid_1 : i1 +# CHECK-NEXT: [[R1:%.+]] = hw.struct_create (%rawOutput, %rawOutput_0) : !hw.struct +# CHECK-NEXT: %chanOutput, %ready = esi.wrap.vr [[R1]], [[R0]] : !hw.struct +# CHECK-NEXT: [[R2]] = comb.and bin %ready, [[R0]] : i1 +# CHECK-NEXT: %rawOutput_2, %valid_3 = esi.unwrap.vr %chanOutput, %ready_7 : !hw.struct +# CHECK-NEXT: %a_4 = hw.struct_extract %rawOutput_2["a"] : !hw.struct +# CHECK-NEXT: %b_5 = hw.struct_extract %rawOutput_2["b"] : !hw.struct +# CHECK-NEXT: [[R3:%.+]] = hwarith.add %a_4, %b_5 : (ui8, ui8) -> ui9 +# CHECK-NEXT: %chanOutput_6, %ready_7 = esi.wrap.vr [[R3]], %valid_3 : ui9 +# CHECK-NEXT: hw.output %chanOutput_6 : !esi.channel +@unittestmodule(run_passes=True, emit_outputs=True) +class Join(Module): + clk = Clock() + rst = Reset() + a = InputChannel(UInt(8)) + b = InputChannel(UInt(8)) + + x = OutputChannel(UInt(9)) + + @generator + def build(ports): + joined = ports.a.type.join(ports.a, ports.b) + ports.x = joined.transform(lambda x: x.a + x.b) + + +# 
CHECK-LABEL: hw.module @Fork(in %clk : !seq.clock, in %rst : i1, in %a : !esi.channel, out x : !esi.channel, out y : !esi.channel) +# CHECK-NEXT: %rawOutput, %valid = esi.unwrap.vr %a, [[R3:%.+]] : ui8 +# CHECK-NEXT: [[R0:%.+]] = comb.and bin [[R3]], %valid : i1 +# CHECK-NEXT: %chanOutput, %ready = esi.wrap.vr %rawOutput, [[R0]] : ui8 +# CHECK-NEXT: %chanOutput_0, %ready_1 = esi.wrap.vr %rawOutput, [[R0]] : ui8 +# CHECK-NEXT: [[R1:%.+]] = esi.buffer %clk, %rst, %chanOutput {stages = 1 : i64} : ui8 +# CHECK-NEXT: [[R2:%.+]] = esi.buffer %clk, %rst, %chanOutput_0 {stages = 1 : i64} : ui8 +# CHECK-NEXT: [[R3]] = comb.and bin %ready, %ready_1 : i1 +# CHECK-NEXT: hw.output [[R1]], [[R2]] : !esi.channel, !esi.channel +@unittestmodule(run_passes=True, emit_outputs=True) +class Fork(Module): + clk = Clock() + rst = Reset() + a = InputChannel(UInt(8)) + + x = OutputChannel(UInt(8)) + y = OutputChannel(UInt(8)) + + @generator + def build(ports): + x, y = ports.a.fork(ports.clk, ports.rst) + ports.x = x + ports.y = y