Skip to content

Commit

Permalink
integration test!
Browse files Browse the repository at this point in the history
  • Loading branch information
teqdruid committed Nov 22, 2024
1 parent b62bc9b commit 9640109
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
26 changes: 25 additions & 1 deletion frontends/PyCDE/integration_test/esi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import pycde
from pycde import (AppID, Clock, Module, Reset, modparams, generator)
from pycde.bsp import cosim
from pycde.common import Constant
from pycde.common import Constant, Input, Output
from pycde.constructs import ControlReg, Reg, Wire
from pycde.esi import ChannelService, FuncService, MMIO, MMIOReadWriteCmdType
from pycde.types import (Bits, Channel, UInt)
from pycde.behavioral import If, Else, EndIf
from pycde.handshake import Func

import sys

Expand Down Expand Up @@ -107,6 +108,28 @@ def construct(ports):
ChannelService.to_host(AppID("const_producer"), ch)


class JoinFunc(Func):
a = Input(UInt(32))
b = Input(UInt(32))
x = Output(UInt(32))

@generator
def construct(ports):
ports.x = (ports.a + ports.b).as_uint(32)


class Join(Module):
clk = Clock()
rst = Reset()

@generator
def construct(ports):
a = ChannelService.from_host(AppID("join_a"), UInt(32))
b = ChannelService.from_host(AppID("join_b"), UInt(32))
f = JoinFunc(clk=ports.clk, rst=ports.rst, a=a, b=b)
ChannelService.to_host(AppID("join_x"), f.x)


class Top(Module):
clk = Clock()
rst = Reset()
Expand All @@ -118,6 +141,7 @@ def construct(ports):
MMIOClient(i)()
MMIOReadWriteClient(clk=ports.clk, rst=ports.rst)
ConstProducer(clk=ports.clk, rst=ports.rst)
Join(clk=ports.clk, rst=ports.rst)


if __name__ == "__main__":
Expand Down
17 changes: 17 additions & 0 deletions frontends/PyCDE/integration_test/test_software/esi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,20 @@ def read_offset_check(i: int, add_amt: int):
producer.disconnect()
print(f"data: {data}")
assert data == 42

################################################################################
# Handshake Join
################################################################################

a = d.ports[esi.AppID("join_a")].write_port("data")
a.connect()
b = d.ports[esi.AppID("join_b")].write_port("data")
b.connect()
x = d.ports[esi.AppID("join_x")].read_port("data")
x.connect()

a.write(15)
b.write(24)
xdata = x.read()
print(f"join: {xdata}")
assert xdata == 15 + 24
6 changes: 1 addition & 5 deletions lib/Conversion/HandshakeToDC/HandshakeToDC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,10 +762,6 @@ class HandshakeToDCPass
public:
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
auto targetModifier = [](mlir::ConversionTarget &target) {
// target.addLegalDialect<hw::HWDialect, func::FuncDialect>();
};

auto patternBuilder = [&](TypeConverter &typeConverter,
handshaketodc::ConvertedOps &convertedOps,
RewritePatternSet &patterns) {
Expand All @@ -774,7 +770,7 @@ class HandshakeToDCPass
patterns.add<ReturnOpConversion>(typeConverter, mod.getContext());
};

LogicalResult res = runHandshakeToDC(mod, patternBuilder, targetModifier);
LogicalResult res = runHandshakeToDC(mod, patternBuilder, nullptr);
if (failed(res))
signalPassFailure();
}
Expand Down

0 comments on commit 9640109

Please sign in to comment.