Skip to content

Commit

Permalink
[ESI][Runtime] Address MMIO regions symbolically (#7568)
Browse files Browse the repository at this point in the history
Access MMIO address regions by the AppID of the requestor. Also provide
access to the MMIO space descriptors.
  • Loading branch information
teqdruid authored Sep 2, 2024
1 parent d88c535 commit 1df1ef4
Show file tree
Hide file tree
Showing 11 changed files with 364 additions and 75 deletions.
78 changes: 46 additions & 32 deletions frontends/PyCDE/integration_test/test_software/esi_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import esiaccel as esi
from esiaccel.types import MMIORegion

import sys
import time
Expand All @@ -10,64 +11,77 @@
data = mmio.read(8)
assert data == 0x207D98E5E5100E51

assert acc.sysinfo().esi_version() == 0
m = acc.manifest()
assert m.api_version == 0
print(m.type_table)

d = acc.build_accelerator()

mmio_svc: esi.accelerator.MMIO
for svc in d.services:
if isinstance(svc, esi.accelerator.MMIO):
mmio_svc = svc
break

for id, region in mmio_svc.regions.items():
print(f"Region {id}: {region.base} - {region.base + region.size}")

assert len(mmio_svc.regions) == 4

################################################################################
# MMIOClient tests
################################################################################


def read_offset(mmio_offset: int, offset: int, add_amt: int):
data = mmio.read(mmio_offset + offset)
def read_offset(mmio_x: MMIORegion, offset: int, add_amt: int):
data = mmio_x.read(offset)
if data == add_amt + offset:
print(f"PASS: read_offset({mmio_offset}, {offset}, {add_amt}) -> {data}")
print(f"PASS: read_offset({offset}, {add_amt}) -> {data}")
else:
assert False, f"read_offset({mmio_offset}, {offset}, {add_amt}) -> {data}"
assert False, f"read_offset({offset}, {add_amt}) -> {data}"


mmio9 = d.ports[esi.AppID("mmio_client", 9)]
read_offset(mmio9, 0, 9)
read_offset(mmio9, 13, 9)

# MMIO offset into mmio_client[9]. TODO: get this from the manifest. API coming.
mmio_client_9_offset = 131072
read_offset(mmio_client_9_offset, 0, 9)
read_offset(mmio_client_9_offset, 13, 9)
mmio4 = d.ports[esi.AppID("mmio_client", 4)]
read_offset(mmio4, 0, 4)
read_offset(mmio4, 13, 4)

# MMIO offset into mmio_client[4].
mmio_client_4_offset = 65536
read_offset(mmio_client_4_offset, 0, 4)
read_offset(mmio_client_4_offset, 13, 4)
mmio14 = d.ports[esi.AppID("mmio_client", 14)]
read_offset(mmio14, 0, 14)
read_offset(mmio14, 13, 14)

# MMIO offset into mmio_client[14].
mmio_client_14_offset = 196608
read_offset(mmio_client_14_offset, 0, 14)
read_offset(mmio_client_14_offset, 13, 14)
assert mmio14.descriptor.base == 196608
assert mmio14.descriptor.size == 65536

################################################################################
# MMIOReadWriteClient tests
################################################################################

mmio_rw_client_offset = 262144
mmio_rw = d.ports[esi.AppID("mmio_rw_client")]


def read_offset_check(i: int, add_amt: int):
d = mmio.read(mmio_rw_client_offset + i)
if d == i + 9:
print(f"PASS: read_offset_check({mmio_rw_client_offset} + {i}: {d}")
d = mmio_rw.read(i)
if d == i + add_amt:
print(f"PASS: read_offset_check({i}): {d}")
else:
assert False, f": read_offset_check({mmio_rw_client_offset} + {i}: {d}"
assert False, f": read_offset_check({i}): {d}"


mmio.write(mmio_rw_client_offset + 8, 9)
read_offset_check(0, 9)
read_offset_check(12, 9)
read_offset_check(0x1400, 9)
add_amt = 137
mmio_rw.write(8, add_amt)
read_offset_check(0, add_amt)
read_offset_check(12, add_amt)
read_offset_check(0x1400, add_amt)

################################################################################
# Manifest tests
################################################################################

assert acc.sysinfo().esi_version() == 0
m = acc.manifest()
assert m.api_version == 0
print(m.type_table)

d = acc.build_accelerator()
loopback = d.children[esi.AppID("loopback")]
recv = loopback.ports[esi.AppID("add")].read_port("result")
recv.connect()
Expand Down Expand Up @@ -108,7 +122,7 @@ def read_offset_check(i: int, add_amt: int):
if write_succeeded:
break

assert (write_succeeded, "Non-blocking write failed")
assert write_succeeded, "Non-blocking write failed"
resp = recv.read()
print(f"data: {data}")
print(f"resp: {resp}")
Expand Down
12 changes: 10 additions & 2 deletions frontends/PyCDE/src/pycde/bsp/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,19 @@ def build_table(bundles) -> Tuple[Dict[int, AssignableSignal], int]:
for bundle in bundles.to_client_reqs:
if bundle.port == 'read':
table[offset] = bundle
bundle.add_record({"offset": offset, "type": "ro"})
bundle.add_record({
"offset": offset,
"size": ChannelMMIO.RegisterSpace,
"type": "ro"
})
offset += ChannelMMIO.RegisterSpace
elif bundle.port == 'read_write':
table[offset] = bundle
bundle.add_record({"offset": offset, "type": "rw"})
bundle.add_record({
"offset": offset,
"size": ChannelMMIO.RegisterSpace,
"type": "rw"
})
offset += ChannelMMIO.RegisterSpace
else:
assert False, "Unrecognized port name."
Expand Down
6 changes: 3 additions & 3 deletions frontends/PyCDE/src/pycde/esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .module import generator, Module, ModuleLikeBuilderBase, PortProxyBase
from .signals import (BitsSignal, BundleSignal, ChannelSignal, Signal,
_FromCirctValue)
from .support import get_user_loc
from .support import _obj_to_attribute, get_user_loc, obj_to_typed_attribute
from .system import System
from .types import (Bits, Bundle, BundledChannel, Channel, ChannelDirection,
StructType, Type, UInt, types, _FromCirctType)
Expand Down Expand Up @@ -144,9 +144,9 @@ def add_record(self, details: Dict[str, str]):
give the runtime necessary information about how to connect to the client
through the generated service. For instance, offsets into an MMIO space."""

ir_details: Dict[str, ir.StringAttr] = {}
ir_details: Dict[str, ir.Attribute] = {}
for k, v in details.items():
ir_details[k] = ir.StringAttr.get(str(v))
ir_details[k] = _obj_to_attribute(v)
with get_user_loc(), ir.InsertionPoint.at_block_begin(
self.rec.reqDetails.blocks[0]):
raw_esi.ServiceImplClientRecordOp(
Expand Down
4 changes: 4 additions & 0 deletions include/circt/Dialect/ESI/ESIStdServices.td
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,8 @@ def MMIOServiceDeclOp: ESI_Op<"service.std.mmio",
let assemblyFormat = [{
$sym_name attr-dict
}];

let extraClassDeclaration = [{
std::optional<StringRef> getTypeName() { return "esi.service.std.mmio"; }
}];
}
4 changes: 4 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Design.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class HWModule {
const std::map<AppID, const BundlePort &> &getPorts() const {
return portIndex;
}
/// Access the services provided by this module.
const std::vector<services::Service *> &getServices() const {
return services;
}

/// Master poll method. Calls the `poll` method on all locally owned ports and
/// the master `poll` method on all of the children. Returns true if any of
Expand Down
76 changes: 76 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/include/esi/Services.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define ESI_RUNTIME_SERVICES_H

#include "esi/Common.h"
#include "esi/Context.h"
#include "esi/Ports.h"

#include <cstdint>
Expand Down Expand Up @@ -48,6 +49,19 @@ class Service {

virtual std::string getServiceSymbol() const = 0;

/// Create a "child" service of this service. Does not have to be the same
/// service type, but typically is. Used when a service already exists in the
/// active services table, but a new one wants to replace it. Useful for cases
/// where the child service needs to use the parent service. Defaults to
/// calling the `getService` method on `AcceleratorConnection` to get the
/// global service, implying that the child service does not need to use the
/// service it is replacing.
virtual Service *getChildService(AcceleratorConnection *conn,
Service::Type service, AppIDPath id = {},
std::string implName = {},
ServiceImplDetails details = {},
HWClientDetails clients = {});

/// Get specialized port for this service to attach to the given appid path.
/// Null returns mean nothing to attach.
virtual ServicePort *getPort(AppIDPath id, const BundleType *type,
Expand Down Expand Up @@ -94,10 +108,72 @@ class SysInfo : public Service {

class MMIO : public Service {
public:
static constexpr std::string_view StdName = "esi.service.std.mmio";

/// Describe a region (slice) of MMIO space.
struct RegionDescriptor {
uint32_t base;
uint32_t size;
};

MMIO(Context &ctxt, AppIDPath idPath, std::string implName,
const ServiceImplDetails &details, const HWClientDetails &clients);
MMIO() = default;
virtual ~MMIO() = default;

/// Read a 64-bit value from the global MMIO space.
virtual uint64_t read(uint32_t addr) const = 0;
/// Write a 64-bit value to the global MMIO space.
virtual void write(uint32_t addr, uint64_t data) = 0;
/// Get the regions of MMIO space that this service manages. Otherwise known
/// as the base address table.
const std::map<AppIDPath, RegionDescriptor> &getRegions() const {
return regions;
}

/// If the service is a MMIO service, return a region of the MMIO space which
/// peers into ours.
virtual Service *getChildService(AcceleratorConnection *conn,
Service::Type service, AppIDPath id = {},
std::string implName = {},
ServiceImplDetails details = {},
HWClientDetails clients = {}) override;

virtual std::string getServiceSymbol() const override;

/// Get a MMIO region port for a particular region descriptor.
virtual ServicePort *getPort(AppIDPath id, const BundleType *type,
const std::map<std::string, ChannelPort &> &,
AcceleratorConnection &) const override;

private:
/// MMIO base address table.
std::map<AppIDPath, RegionDescriptor> regions;

public:
/// A "slice" of some parent MMIO space.
class MMIORegion : public ServicePort {
friend class MMIO;
MMIORegion(AppID id, MMIO *parent, RegionDescriptor desc);

public:
/// Get the offset (and size) of the region in the parent (usually global)
/// MMIO address space.
virtual RegionDescriptor getDescriptor() const { return desc; };
/// Read a 64-bit value from this region, not the global address space.
virtual uint64_t read(uint32_t addr) const;
/// Write a 64-bit value to this region, not the global address space.
virtual void write(uint32_t addr, uint64_t data);

virtual std::optional<std::string> toString() const override {
return "MMIO region " + toHex(desc.base) + " - " +
toHex(desc.base + desc.size);
}

private:
MMIO *parent;
RegionDescriptor desc;
};
};

/// Implement the SysInfo API for a standard MMIO protocol.
Expand Down
39 changes: 31 additions & 8 deletions lib/Dialect/ESI/runtime/cpp/lib/Manifest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,21 @@ std::any Manifest::Impl::getAny(const nlohmann::json &value) const {
std::map<std::string, std::any> ret;
for (auto &e : json.items())
ret[e.key()] = getAny(e.value());
return ret;

// If this can be converted to a constant, do so.
if (ret.size() != 2 || !ret.contains("type") || !ret.contains("value"))
return ret;
std::any value = ret.at("value");
std::any typeID = ret.at("type");
if (typeID.type() != typeid(std::string))
return ret;
std::optional<const Type *> type =
getType(std::any_cast<std::string>(type));
if (!type)
return ret;
// TODO: Check or guide the conversion of the value to the type based on the
// type.
return Constant{value, type};
};

auto getArray = [this](const nlohmann::json &json) -> std::any {
Expand All @@ -160,10 +174,10 @@ std::any Manifest::Impl::getAny(const nlohmann::json &value) const {
auto getValue = [&](const nlohmann::json &innerValue) -> std::any {
if (innerValue.is_string())
return innerValue.get<std::string>();
else if (innerValue.is_number_integer())
return innerValue.get<int64_t>();
else if (innerValue.is_number_unsigned())
return innerValue.get<uint64_t>();
else if (innerValue.is_number_integer())
return innerValue.get<int64_t>();
else if (innerValue.is_number_float())
return innerValue.get<double>();
else if (innerValue.is_boolean())
Expand Down Expand Up @@ -368,11 +382,20 @@ Manifest::Impl::getService(AppIDPath idPath, AcceleratorConnection &acc,
}

// Create the service.
// TODO: Add support for 'standard' services.
services::Service::Type svcType =
services::ServiceRegistry::lookupServiceType(service);
services::Service *svc =
acc.getService(svcType, idPath, implName, svcDetails, clientDetails);
services::Service *svc = nullptr;
auto activeServiceIter = activeServices.find(service);
if (activeServiceIter != activeServices.end()) {
services::Service::Type svcType =
services::ServiceRegistry::lookupServiceType(
activeServiceIter->second->getServiceSymbol());
svc = activeServiceIter->second->getChildService(
&acc, svcType, idPath, implName, svcDetails, clientDetails);
} else {
services::Service::Type svcType =
services::ServiceRegistry::lookupServiceType(service);
svc = acc.getService(svcType, idPath, implName, svcDetails, clientDetails);
}

if (svc)
// Update the active services table.
activeServices[service] = svc;
Expand Down
Loading

0 comments on commit 1df1ef4

Please sign in to comment.