From 1df1ef4802ec5acf0ebfd80de6cd625763cb5c86 Mon Sep 17 00:00:00 2001 From: John Demme Date: Mon, 2 Sep 2024 16:40:16 -0700 Subject: [PATCH] [ESI][Runtime] Address MMIO regions symbolically (#7568) Access MMIO address regions by the AppID of the requestor. Also provide access to the MMIO space descriptors. --- .../test_software/esi_test.py | 78 ++++++++------- frontends/PyCDE/src/pycde/bsp/common.py | 12 ++- frontends/PyCDE/src/pycde/esi.py | 6 +- include/circt/Dialect/ESI/ESIStdServices.td | 4 + .../ESI/runtime/cpp/include/esi/Design.h | 4 + .../ESI/runtime/cpp/include/esi/Services.h | 76 +++++++++++++++ lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp | 39 ++++++-- lib/Dialect/ESI/runtime/cpp/lib/Services.cpp | 95 ++++++++++++++++++- .../runtime/python/esiaccel/accelerator.py | 9 +- .../runtime/python/esiaccel/esiCppAccel.cpp | 92 ++++++++++++------ .../ESI/runtime/python/esiaccel/types.py | 24 +++++ 11 files changed, 364 insertions(+), 75 deletions(-) diff --git a/frontends/PyCDE/integration_test/test_software/esi_test.py b/frontends/PyCDE/integration_test/test_software/esi_test.py index 352968fc56bf..605d27a1dec1 100644 --- a/frontends/PyCDE/integration_test/test_software/esi_test.py +++ b/frontends/PyCDE/integration_test/test_software/esi_test.py @@ -1,4 +1,5 @@ import esiaccel as esi +from esiaccel.types import MMIORegion import sys import time @@ -10,64 +11,77 @@ data = mmio.read(8) assert data == 0x207D98E5E5100E51 +assert acc.sysinfo().esi_version() == 0 +m = acc.manifest() +assert m.api_version == 0 +print(m.type_table) + +d = acc.build_accelerator() + +mmio_svc: esi.accelerator.MMIO +for svc in d.services: + if isinstance(svc, esi.accelerator.MMIO): + mmio_svc = svc + break + +for id, region in mmio_svc.regions.items(): + print(f"Region {id}: {region.base} - {region.base + region.size}") + +assert len(mmio_svc.regions) == 4 + ################################################################################ # MMIOClient tests ################################################################################ -def read_offset(mmio_offset: int, offset: int, add_amt: int): - data = mmio.read(mmio_offset + offset) +def read_offset(mmio_x: MMIORegion, offset: int, add_amt: int): + data = mmio_x.read(offset) if data == add_amt + offset: - print(f"PASS: read_offset({mmio_offset}, {offset}, {add_amt}) -> {data}") + print(f"PASS: read_offset({offset}, {add_amt}) -> {data}") else: - assert False, f"read_offset({mmio_offset}, {offset}, {add_amt}) -> {data}" + assert False, f"read_offset({offset}, {add_amt}) -> {data}" + +mmio9 = d.ports[esi.AppID("mmio_client", 9)] +read_offset(mmio9, 0, 9) +read_offset(mmio9, 13, 9) -# MMIO offset into mmio_client[9]. TODO: get this from the manifest. API coming. -mmio_client_9_offset = 131072 -read_offset(mmio_client_9_offset, 0, 9) -read_offset(mmio_client_9_offset, 13, 9) +mmio4 = d.ports[esi.AppID("mmio_client", 4)] +read_offset(mmio4, 0, 4) +read_offset(mmio4, 13, 4) -# MMIO offset into mmio_client[4]. -mmio_client_4_offset = 65536 -read_offset(mmio_client_4_offset, 0, 4) -read_offset(mmio_client_4_offset, 13, 4) +mmio14 = d.ports[esi.AppID("mmio_client", 14)] +read_offset(mmio14, 0, 14) +read_offset(mmio14, 13, 14) -# MMIO offset into mmio_client[14]. -mmio_client_14_offset = 196608 -read_offset(mmio_client_14_offset, 0, 14) -read_offset(mmio_client_14_offset, 13, 14) +assert mmio14.descriptor.base == 196608 +assert mmio14.descriptor.size == 65536 ################################################################################ # MMIOReadWriteClient tests ################################################################################ -mmio_rw_client_offset = 262144 +mmio_rw = d.ports[esi.AppID("mmio_rw_client")] def read_offset_check(i: int, add_amt: int): - d = mmio.read(mmio_rw_client_offset + i) - if d == i + 9: - print(f"PASS: read_offset_check({mmio_rw_client_offset} + {i}: {d}") + d = mmio_rw.read(i) + if d == i + add_amt: + print(f"PASS: read_offset_check({i}): {d}") else: - assert False, f": read_offset_check({mmio_rw_client_offset} + {i}: {d}" + assert False, f": read_offset_check({i}): {d}" -mmio.write(mmio_rw_client_offset + 8, 9) -read_offset_check(0, 9) -read_offset_check(12, 9) -read_offset_check(0x1400, 9) +add_amt = 137 +mmio_rw.write(8, add_amt) +read_offset_check(0, add_amt) +read_offset_check(12, add_amt) +read_offset_check(0x1400, add_amt) ################################################################################ # Manifest tests ################################################################################ -assert acc.sysinfo().esi_version() == 0 -m = acc.manifest() -assert m.api_version == 0 -print(m.type_table) - -d = acc.build_accelerator() loopback = d.children[esi.AppID("loopback")] recv = loopback.ports[esi.AppID("add")].read_port("result") recv.connect() @@ -108,7 +122,7 @@ def read_offset_check(i: int, add_amt: int): if write_succeeded: break -assert (write_succeeded, "Non-blocking write failed") +assert write_succeeded, "Non-blocking write failed" resp = recv.read() print(f"data: {data}") print(f"resp: {resp}") diff --git a/frontends/PyCDE/src/pycde/bsp/common.py b/frontends/PyCDE/src/pycde/bsp/common.py index 9373bee5738f..25309a37b372 100644 --- a/frontends/PyCDE/src/pycde/bsp/common.py +++ b/frontends/PyCDE/src/pycde/bsp/common.py @@ -152,11 +152,19 @@ def build_table(bundles) -> Tuple[Dict[int, AssignableSignal], int]: for bundle in bundles.to_client_reqs: if bundle.port == 'read': table[offset] = bundle - bundle.add_record({"offset": offset, "type": "ro"}) + bundle.add_record({ + "offset": offset, + "size": ChannelMMIO.RegisterSpace, + "type": "ro" + }) offset += ChannelMMIO.RegisterSpace elif bundle.port == 'read_write': table[offset] = bundle - bundle.add_record({"offset": offset, "type": "rw"}) + bundle.add_record({ + "offset": offset, + "size": ChannelMMIO.RegisterSpace, + "type": "rw" + }) offset += ChannelMMIO.RegisterSpace else: assert False, "Unrecognized port name." diff --git a/frontends/PyCDE/src/pycde/esi.py b/frontends/PyCDE/src/pycde/esi.py index 170a7b4b8c8e..effb25ae3d14 100644 --- a/frontends/PyCDE/src/pycde/esi.py +++ b/frontends/PyCDE/src/pycde/esi.py @@ -7,7 +7,7 @@ from .module import generator, Module, ModuleLikeBuilderBase, PortProxyBase from .signals import (BitsSignal, BundleSignal, ChannelSignal, Signal, _FromCirctValue) -from .support import get_user_loc +from .support import _obj_to_attribute, get_user_loc, obj_to_typed_attribute from .system import System from .types import (Bits, Bundle, BundledChannel, Channel, ChannelDirection, StructType, Type, UInt, types, _FromCirctType) @@ -144,9 +144,9 @@ def add_record(self, details: Dict[str, str]): give the runtime necessary information about how to connect to the client through the generated service. For instance, offsets into an MMIO space.""" - ir_details: Dict[str, ir.StringAttr] = {} + ir_details: Dict[str, ir.Attribute] = {} for k, v in details.items(): - ir_details[k] = ir.StringAttr.get(str(v)) + ir_details[k] = _obj_to_attribute(v) with get_user_loc(), ir.InsertionPoint.at_block_begin( self.rec.reqDetails.blocks[0]): raw_esi.ServiceImplClientRecordOp( diff --git a/include/circt/Dialect/ESI/ESIStdServices.td b/include/circt/Dialect/ESI/ESIStdServices.td index 4d4ed6c6629e..1bd034ba1c18 100644 --- a/include/circt/Dialect/ESI/ESIStdServices.td +++ b/include/circt/Dialect/ESI/ESIStdServices.td @@ -97,4 +97,8 @@ def MMIOServiceDeclOp: ESI_Op<"service.std.mmio", let assemblyFormat = [{ $sym_name attr-dict }]; + + let extraClassDeclaration = [{ + std::optional getTypeName() { return "esi.service.std.mmio"; } + }]; } diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h index b95421e0cd01..b676de993305 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Design.h @@ -76,6 +76,10 @@ class HWModule { const std::map &getPorts() const { return portIndex; } + /// Access the services provided by this module. + const std::vector &getServices() const { + return services; + } /// Master poll method. Calls the `poll` method on all locally owned ports and /// the master `poll` method on all of the children. Returns true if any of diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Services.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Services.h index c55b69f79d5f..67b7570f8ba8 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Services.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Services.h @@ -21,6 +21,7 @@ #define ESI_RUNTIME_SERVICES_H #include "esi/Common.h" +#include "esi/Context.h" #include "esi/Ports.h" #include @@ -48,6 +49,19 @@ class Service { virtual std::string getServiceSymbol() const = 0; + /// Create a "child" service of this service. Does not have to be the same + /// service type, but typically is. Used when a service already exists in the + /// active services table, but a new one wants to replace it. Useful for cases + /// where the child service needs to use the parent service. Defaults to + /// calling the `getService` method on `AcceleratorConnection` to get the + /// global service, implying that the child service does not need to use the + /// service it is replacing. + virtual Service *getChildService(AcceleratorConnection *conn, + Service::Type service, AppIDPath id = {}, + std::string implName = {}, + ServiceImplDetails details = {}, + HWClientDetails clients = {}); + /// Get specialized port for this service to attach to the given appid path. /// Null returns mean nothing to attach. virtual ServicePort *getPort(AppIDPath id, const BundleType *type, @@ -94,10 +108,72 @@ class SysInfo : public Service { class MMIO : public Service { public: + static constexpr std::string_view StdName = "esi.service.std.mmio"; + + /// Describe a region (slice) of MMIO space. + struct RegionDescriptor { + uint32_t base; + uint32_t size; + }; + + MMIO(Context &ctxt, AppIDPath idPath, std::string implName, + const ServiceImplDetails &details, const HWClientDetails &clients); + MMIO() = default; virtual ~MMIO() = default; + + /// Read a 64-bit value from the global MMIO space. virtual uint64_t read(uint32_t addr) const = 0; + /// Write a 64-bit value to the global MMIO space. virtual void write(uint32_t addr, uint64_t data) = 0; + /// Get the regions of MMIO space that this service manages. Otherwise known + /// as the base address table. + const std::map &getRegions() const { + return regions; + } + + /// If the service is a MMIO service, return a region of the MMIO space which + /// peers into ours. + virtual Service *getChildService(AcceleratorConnection *conn, + Service::Type service, AppIDPath id = {}, + std::string implName = {}, + ServiceImplDetails details = {}, + HWClientDetails clients = {}) override; + virtual std::string getServiceSymbol() const override; + + /// Get a MMIO region port for a particular region descriptor. + virtual ServicePort *getPort(AppIDPath id, const BundleType *type, + const std::map &, + AcceleratorConnection &) const override; + +private: + /// MMIO base address table. + std::map regions; + +public: + /// A "slice" of some parent MMIO space. + class MMIORegion : public ServicePort { + friend class MMIO; + MMIORegion(AppID id, MMIO *parent, RegionDescriptor desc); + + public: + /// Get the offset (and size) of the region in the parent (usually global) + /// MMIO address space. + virtual RegionDescriptor getDescriptor() const { return desc; }; + /// Read a 64-bit value from this region, not the global address space. + virtual uint64_t read(uint32_t addr) const; + /// Write a 64-bit value to this region, not the global address space. + virtual void write(uint32_t addr, uint64_t data); + + virtual std::optional toString() const override { + return "MMIO region " + toHex(desc.base) + " - " + + toHex(desc.base + desc.size); + } + + private: + MMIO *parent; + RegionDescriptor desc; + }; }; /// Implement the SysInfo API for a standard MMIO protocol. diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp index 56e7986179e4..bc096f5faab2 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp @@ -147,7 +147,21 @@ std::any Manifest::Impl::getAny(const nlohmann::json &value) const { std::map ret; for (auto &e : json.items()) ret[e.key()] = getAny(e.value()); - return ret; + + // If this can be converted to a constant, do so. + if (ret.size() != 2 || !ret.contains("type") || !ret.contains("value")) + return ret; + std::any value = ret.at("value"); + std::any typeID = ret.at("type"); + if (typeID.type() != typeid(std::string)) + return ret; + std::optional type = + getType(std::any_cast(type)); + if (!type) + return ret; + // TODO: Check or guide the conversion of the value to the type based on the + // type. + return Constant{value, type}; }; auto getArray = [this](const nlohmann::json &json) -> std::any { @@ -160,10 +174,10 @@ std::any Manifest::Impl::getAny(const nlohmann::json &value) const { auto getValue = [&](const nlohmann::json &innerValue) -> std::any { if (innerValue.is_string()) return innerValue.get(); - else if (innerValue.is_number_integer()) - return innerValue.get(); else if (innerValue.is_number_unsigned()) return innerValue.get(); + else if (innerValue.is_number_integer()) + return innerValue.get(); else if (innerValue.is_number_float()) return innerValue.get(); else if (innerValue.is_boolean()) @@ -368,11 +382,20 @@ Manifest::Impl::getService(AppIDPath idPath, AcceleratorConnection &acc, } // Create the service. - // TODO: Add support for 'standard' services. - services::Service::Type svcType = - services::ServiceRegistry::lookupServiceType(service); - services::Service *svc = - acc.getService(svcType, idPath, implName, svcDetails, clientDetails); + services::Service *svc = nullptr; + auto activeServiceIter = activeServices.find(service); + if (activeServiceIter != activeServices.end()) { + services::Service::Type svcType = + services::ServiceRegistry::lookupServiceType( + activeServiceIter->second->getServiceSymbol()); + svc = activeServiceIter->second->getChildService( + &acc, svcType, idPath, implName, svcDetails, clientDetails); + } else { + services::Service::Type svcType = + services::ServiceRegistry::lookupServiceType(service); + svc = acc.getService(svcType, idPath, implName, svcDetails, clientDetails); + } + if (svc) // Update the active services table. activeServices[service] = svc; diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Services.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Services.cpp index fadd3ace8037..2288189bd050 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Services.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Services.cpp @@ -24,6 +24,14 @@ using namespace esi; using namespace esi::services; +Service *Service::getChildService(AcceleratorConnection *conn, + Service::Type service, AppIDPath id, + std::string implName, + ServiceImplDetails details, + HWClientDetails clients) { + return conn->getService(service, id, implName, details, clients); +} + std::string SysInfo::getServiceSymbol() const { return "__builtin_SysInfo"; } // Allocate 10MB for the uncompressed manifest. This should be plenty. @@ -41,7 +49,90 @@ std::string SysInfo::getJsonManifest() const { return std::string(reinterpret_cast(dst.data()), dstSize); } -std::string MMIO::getServiceSymbol() const { return "__builtin_MMIO"; } +//===----------------------------------------------------------------------===// +// MMIO class implementations. +//===----------------------------------------------------------------------===// + +MMIO::MMIO(Context &ctxt, AppIDPath idPath, std::string implName, + const ServiceImplDetails &details, const HWClientDetails &clients) { + for (const HWClientDetail &client : clients) { + auto offsetIter = client.implOptions.find("offset"); + if (offsetIter == client.implOptions.end()) + throw std::runtime_error("MMIO client missing 'offset' option"); + Constant offset = std::any_cast(offsetIter->second); + uint64_t offsetVal = std::any_cast(offset.value); + if (offsetVal >= 1ul << 32) + throw std::runtime_error("MMIO client offset mustn't exceed 32 bits"); + + auto sizeIter = client.implOptions.find("size"); + if (sizeIter == client.implOptions.end()) + throw std::runtime_error("MMIO client missing 'size' option"); + Constant size = std::any_cast(sizeIter->second); + uint64_t sizeVal = std::any_cast(size.value); + if (sizeVal >= 1ul << 32) + throw std::runtime_error("MMIO client size mustn't exceed 32 bits"); + regions[client.relPath] = + RegionDescriptor{(uint32_t)offsetVal, (uint32_t)sizeVal}; + } +} + +std::string MMIO::getServiceSymbol() const { + return std::string(MMIO::StdName); +} +ServicePort *MMIO::getPort(AppIDPath id, const BundleType *type, + const std::map &, + AcceleratorConnection &conn) const { + auto regionIter = regions.find(id); + if (regionIter == regions.end()) + return nullptr; + return new MMIORegion(id.back(), const_cast(this), + regionIter->second); +} + +namespace { +class MMIOPassThrough : public MMIO { +public: + MMIOPassThrough(Context &ctxt, AppIDPath idPath, std::string implName, + const ServiceImplDetails &details, + const HWClientDetails &clients, MMIO *parent) + : MMIO(ctxt, idPath, implName, details, clients), parent(parent) {} + uint64_t read(uint32_t addr) const override { return parent->read(addr); } + void write(uint32_t addr, uint64_t data) override { + parent->write(addr, data); + } + +private: + MMIO *parent; +}; +} // namespace + +Service *MMIO::getChildService(AcceleratorConnection *conn, + Service::Type service, AppIDPath id, + std::string implName, ServiceImplDetails details, + HWClientDetails clients) { + if (service != typeid(MMIO)) + return Service::getChildService(conn, service, id, implName, details, + clients); + return new MMIOPassThrough(conn->getCtxt(), id, implName, details, clients, + this); +} + +//===----------------------------------------------------------------------===// +// MMIO Region service port class implementations. +//===----------------------------------------------------------------------===// + +MMIO::MMIORegion::MMIORegion(AppID id, MMIO *parent, RegionDescriptor desc) + : ServicePort(id, {}), parent(parent), desc(desc) {} +uint64_t MMIO::MMIORegion::read(uint32_t addr) const { + if (addr >= desc.size) + throw std::runtime_error("MMIO read out of bounds: " + toHex(addr)); + return parent->read(desc.base + addr); +} +void MMIO::MMIORegion::write(uint32_t addr, uint64_t data) { + if (addr >= desc.size) + throw std::runtime_error("MMIO write out of bounds: " + toHex(addr)); + parent->write(desc.base + addr, data); +} MMIOSysInfo::MMIOSysInfo(const MMIO *mmio) : mmio(mmio) {} @@ -211,5 +302,7 @@ Service::Type ServiceRegistry::lookupServiceType(const std::string &svcName) { return typeid(FuncService); if (svcName == "esi.service.std.call") return typeid(CallService); + if (svcName == MMIO::StdName) + return typeid(MMIO); return typeid(CustomService); } diff --git a/lib/Dialect/ESI/runtime/python/esiaccel/accelerator.py b/lib/Dialect/ESI/runtime/python/esiaccel/accelerator.py index 7ade803b67db..ed5c335b88e4 100644 --- a/lib/Dialect/ESI/runtime/python/esiaccel/accelerator.py +++ b/lib/Dialect/ESI/runtime/python/esiaccel/accelerator.py @@ -10,7 +10,7 @@ # # ===-----------------------------------------------------------------------===# -from typing import Dict, Optional +from typing import Dict, List, Optional from .types import BundlePort from . import esiCppAccel as cpp @@ -66,6 +66,13 @@ def ports(self) -> Dict[cpp.AppID, BundlePort]: for name, port in self.cpp_hwmodule.ports.items() } + @property + def services(self) -> List[cpp.AppID]: + return self.cpp_hwmodule.services + + +MMIO = cpp.MMIO + class Instance(HWModule): """Subclass of `HWModule` which represents a submodule instance. Adds an diff --git a/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp b/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp index d23eb4581c9e..145e468b2899 100644 --- a/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp +++ b/lib/Dialect/ESI/runtime/python/esiaccel/esiCppAccel.cpp @@ -42,6 +42,24 @@ struct polymorphic_type_hook { return port; } }; +template <> +struct polymorphic_type_hook { + static const void *get(const Service *svc, const std::type_info *&type) { + if (auto p = dynamic_cast(svc)) { + type = &typeid(MMIO); + return p; + } + if (auto p = dynamic_cast(svc)) { + type = &typeid(SysInfo); + return p; + } + if (auto p = dynamic_cast(svc)) { + type = &typeid(HostMem); + return p; + } + return svc; + } +}; namespace detail { /// Pybind11 doesn't have a built-in type caster for std::any @@ -118,6 +136,31 @@ PYBIND11_MODULE(esiCppAccel, m) { "type", [](Constant &c) { return getPyType(*c.type); }, py::return_value_policy::reference); + py::class_(m, "AppID") + .def(py::init>(), py::arg("name"), + py::arg("idx") = std::nullopt) + .def_property_readonly("name", [](AppID &id) { return id.name; }) + .def_property_readonly("idx", + [](AppID &id) -> py::object { + if (id.idx) + return py::cast(id.idx); + return py::none(); + }) + .def("__repr__", + [](AppID &id) { + std::string ret = "<" + id.name; + if (id.idx) + ret = ret + "[" + std::to_string(*id.idx) + "]"; + ret = ret + ">"; + return ret; + }) + .def("__eq__", [](AppID &a, AppID &b) { return a == b; }) + .def("__hash__", [](AppID &id) { + return utils::hash_combine(std::hash{}(id.name), + std::hash{}(id.idx.value_or(-1))); + }); + py::class_(m, "AppIDPath").def("__repr__", &AppIDPath::toStr); + py::class_(m, "ModuleInfo") .def_property_readonly("name", [](ModuleInfo &info) { return info.name; }) .def_property_readonly("summary", @@ -137,13 +180,22 @@ PYBIND11_MODULE(esiCppAccel, m) { return os.str(); }); - py::class_(m, "SysInfo") + py::class_(m, "Service"); + + py::class_(m, "SysInfo") .def("esi_version", &SysInfo::getEsiVersion) .def("json_manifest", &SysInfo::getJsonManifest); - py::class_(m, "MMIO") + py::class_(m, "MMIORegionDescriptor") + .def_property_readonly("base", + [](MMIO::RegionDescriptor &r) { return r.base; }) + .def_property_readonly("size", + [](MMIO::RegionDescriptor &r) { return r.size; }); + py::class_(m, "MMIO") .def("read", &services::MMIO::read) - .def("write", &services::MMIO::write); + .def("write", &services::MMIO::write) + .def_property_readonly("regions", &services::MMIO::getRegions, + py::return_value_policy::reference); py::class_(m, "HostMemRegion") .def_property_readonly("ptr", @@ -168,7 +220,7 @@ PYBIND11_MODULE(esiCppAccel, m) { return ret; }); - py::class_(m, "HostMem") + py::class_(m, "HostMem") .def("allocate", &services::HostMem::allocate, py::arg("size"), py::arg("options") = services::HostMem::Options(), py::return_value_policy::take_ownership) @@ -186,30 +238,6 @@ PYBIND11_MODULE(esiCppAccel, m) { }, py::arg("ptr")); - py::class_(m, "AppID") - .def(py::init>(), py::arg("name"), - py::arg("idx") = std::nullopt) - .def_property_readonly("name", [](AppID &id) { return id.name; }) - .def_property_readonly("idx", - [](AppID &id) -> py::object { - if (id.idx) - return py::cast(id.idx); - return py::none(); - }) - .def("__repr__", - [](AppID &id) { - std::string ret = "<" + id.name; - if (id.idx) - ret = ret + "[" + std::to_string(*id.idx) + "]"; - ret = ret + ">"; - return ret; - }) - .def("__eq__", [](AppID &a, AppID &b) { return a == b; }) - .def("__hash__", [](AppID &id) { - return utils::hash_combine(std::hash{}(id.name), - std::hash{}(id.idx.value_or(-1))); - }); - // py::class_>(m, "MessageDataFuture"); py::class_>(m, "MessageDataFuture") .def("valid", @@ -266,6 +294,12 @@ PYBIND11_MODULE(esiCppAccel, m) { py::return_value_policy::reference); py::class_(m, "ServicePort"); + + py::class_(m, "MMIORegion") + .def_property_readonly("descriptor", &MMIO::MMIORegion::getDescriptor) + .def("read", &MMIO::MMIORegion::read) + .def("write", &MMIO::MMIORegion::write); + py::class_(m, "Function") .def( "call", @@ -286,6 +320,8 @@ PYBIND11_MODULE(esiCppAccel, m) { py::class_(m, "HWModule") .def_property_readonly("info", &HWModule::getInfo) .def_property_readonly("ports", &HWModule::getPorts, + py::return_value_policy::reference) + .def_property_readonly("services", &HWModule::getServices, py::return_value_policy::reference); // In order to inherit methods from "HWModule", it needs to be defined first. diff --git a/lib/Dialect/ESI/runtime/python/esiaccel/types.py b/lib/Dialect/ESI/runtime/python/esiaccel/types.py index 8ee716259bdc..c66bc3f57bc5 100644 --- a/lib/Dialect/ESI/runtime/python/esiaccel/types.py +++ b/lib/Dialect/ESI/runtime/python/esiaccel/types.py @@ -362,6 +362,8 @@ def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort): # TODO: add a proper registration mechanism for service ports. if isinstance(cpp_port, cpp.Function): return super().__new__(FunctionPort) + if isinstance(cpp_port, cpp.MMIORegion): + return super().__new__(MMIORegion) return super().__new__(cls) def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort): @@ -403,6 +405,28 @@ def add_done_callback(self, fn: Callable[[Future], object]) -> None: raise NotImplementedError("add_done_callback is not implemented") +class MMIORegion(BundlePort): + """A region of memory-mapped I/O space. This is a collection of named + channels, which are either read or read-write. The channels are accessed + by name, and can be connected to the host.""" + + def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion): + super().__init__(owner, cpp_port) + self.region = cpp_port + + @property + def descriptor(self) -> cpp.MMIORegionDesc: + return self.region.descriptor + + def read(self, offset: int) -> bytearray: + """Read a value from the MMIO region at the given offset.""" + return self.region.read(offset) + + def write(self, offset: int, data: bytearray) -> None: + """Write a value to the MMIO region at the given offset.""" + self.region.write(offset, data) + + class FunctionPort(BundlePort): """A pair of channels which carry the input and output of a function."""