Skip to content

Commit

Permalink
[L0] Allocate event pools efficiently in multi-device scenarios
Browse files Browse the repository at this point in the history
Signed-off-by: Raiyan Latif <raiyan.latif@intel.com>
  • Loading branch information
raiyanla authored and nrspruit committed Feb 9, 2024
1 parent 47102cb commit 92a0250
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 80 deletions.
30 changes: 19 additions & 11 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ static ur_result_t enqueueCommandBufferMemCopyHelper(
SyncPointWaitList, ZeEventList));

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, false, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, false, &LaunchEvent));
LaunchEvent->CommandType = CommandType;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -358,7 +359,8 @@ static ur_result_t enqueueCommandBufferMemCopyRectHelper(
SyncPointWaitList, ZeEventList));

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, false, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, false, &LaunchEvent));
LaunchEvent->CommandType = CommandType;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -401,7 +403,8 @@ static ur_result_t enqueueCommandBufferFillHelper(
SyncPointWaitList, ZeEventList));

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, true, &LaunchEvent));
LaunchEvent->CommandType = CommandType;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -453,8 +456,10 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
// Create signal & wait events to be used in the command-list for sync
// on command-buffer enqueue.
auto RetCommandBuffer = *CommandBuffer;
UR_CALL(EventCreate(Context, nullptr, false, &RetCommandBuffer->SignalEvent));
UR_CALL(EventCreate(Context, nullptr, false, &RetCommandBuffer->WaitEvent));
UR_CALL(EventCreate(Context, nullptr, false, false,
&RetCommandBuffer->SignalEvent));
UR_CALL(EventCreate(Context, nullptr, false, false,
&RetCommandBuffer->WaitEvent));

// Add prefix commands
ZE2UR_CALL(zeCommandListAppendEventReset,
Expand Down Expand Up @@ -550,7 +555,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
SyncPointWaitList, ZeEventList));
ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, false, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, false, &LaunchEvent));
LaunchEvent->CommandType = UR_COMMAND_KERNEL_LAUNCH;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -732,7 +738,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
}

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, true, &LaunchEvent));
LaunchEvent->CommandType = UR_COMMAND_USM_PREFETCH;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -795,7 +802,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
}

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
UR_CALL(
EventCreate(CommandBuffer->Context, nullptr, false, true, &LaunchEvent));
LaunchEvent->CommandType = UR_COMMAND_USM_ADVISE;

// Get sync point and register the event with it.
Expand Down Expand Up @@ -933,9 +941,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
(SignalCommandList->first, CommandBuffer->WaitEvent->ZeEvent));

if (Event) {
UR_CALL(createEventAndAssociateQueue(Queue, &RetEvent,
UR_COMMAND_COMMAND_BUFFER_ENQUEUE_EXP,
SignalCommandList, false, true));
UR_CALL(createEventAndAssociateQueue(
Queue, &RetEvent, UR_COMMAND_COMMAND_BUFFER_ENQUEUE_EXP,
SignalCommandList, false, false, true));

if ((Queue->Properties & UR_QUEUE_FLAG_PROFILING_ENABLE)) {
// Multiple submissions of a command buffer implies that we need to save
Expand Down
47 changes: 34 additions & 13 deletions source/adapters/level_zero/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,12 +471,17 @@ static const uint32_t MaxNumEventsPerPool = [] {

ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool(
ze_event_pool_handle_t &Pool, size_t &Index, bool HostVisible,
bool ProfilingEnabled) {
bool ProfilingEnabled, ur_device_handle_t Device) {
// Lock while updating event pool machinery.
std::scoped_lock<ur_mutex> Lock(ZeEventPoolCacheMutex);

ze_device_handle_t ZeDevice = nullptr;

if (Device) {
ZeDevice = Device->ZeDevice;
}
std::list<ze_event_pool_handle_t> *ZePoolCache =
getZeEventPoolCache(HostVisible, ProfilingEnabled);
getZeEventPoolCache(HostVisible, ProfilingEnabled, ZeDevice);

if (!ZePoolCache->empty()) {
if (NumEventsAvailableInEventPool[ZePoolCache->front()] == 0) {
Expand Down Expand Up @@ -511,9 +516,14 @@ ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool(
urPrint("ze_event_pool_desc_t flags set to: %d\n", ZeEventPoolDesc.flags);

std::vector<ze_device_handle_t> ZeDevices;
std::for_each(
Devices.begin(), Devices.end(),
[&](const ur_device_handle_t &D) { ZeDevices.push_back(D->ZeDevice); });
if (ZeDevice) {
ZeDevices.push_back(ZeDevice);
} else {
std::for_each(Devices.begin(), Devices.end(),
[&](const ur_device_handle_t &D) {
ZeDevices.push_back(D->ZeDevice);
});
}

ZE2UR_CALL(zeEventPoolCreate, (ZeContext, &ZeEventPoolDesc,
ZeDevices.size(), &ZeDevices[0], ZePool));
Expand All @@ -528,11 +538,10 @@ ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool(
return UR_RESULT_SUCCESS;
}

ur_event_handle_t
ur_context_handle_t_::getEventFromContextCache(bool HostVisible,
bool WithProfiling) {
ur_event_handle_t ur_context_handle_t_::getEventFromContextCache(
bool HostVisible, bool WithProfiling, ur_device_handle_t Device) {
std::scoped_lock<ur_mutex> Lock(EventCacheMutex);
auto Cache = getEventCache(HostVisible, WithProfiling);
auto Cache = getEventCache(HostVisible, WithProfiling, Device);
if (Cache->empty())
return nullptr;

Expand All @@ -546,8 +555,14 @@ ur_context_handle_t_::getEventFromContextCache(bool HostVisible,

// Return an event to the context-level event cache for later reuse instead
// of destroying it.
//
// Events that belong to a single device (not multi-device, and with a known
// owning queue) are placed in that device's per-device cache so their reuse
// does not require a multi-device event pool; all other events go into the
// context-wide cache (Device == nullptr).
void ur_context_handle_t_::addEventToContextCache(ur_event_handle_t Event) {
  // Guard the event caches against concurrent mutation.
  std::scoped_lock<ur_mutex> Lock(EventCacheMutex);
  ur_device_handle_t Device = nullptr;

  // Only single-device events with an owning queue can be cached
  // per-device; the queue identifies which device the event belongs to.
  if (!Event->IsMultiDevice && Event->UrQueue) {
    Device = Event->UrQueue->Device;
  }

  auto Cache = getEventCache(Event->isHostVisible(),
                             Event->isProfilingEnabled(), Device);
  Cache->emplace_back(Event);
}

Expand All @@ -562,8 +577,14 @@ ur_context_handle_t_::decrementUnreleasedEventsInPool(ur_event_handle_t Event) {
return UR_RESULT_SUCCESS;
}

std::list<ze_event_pool_handle_t> *ZePoolCache =
getZeEventPoolCache(Event->isHostVisible(), Event->isProfilingEnabled());
ze_device_handle_t ZeDevice = nullptr;

if (!Event->IsMultiDevice && Event->UrQueue) {
ZeDevice = Event->UrQueue->Device->ZeDevice;
}

std::list<ze_event_pool_handle_t> *ZePoolCache = getZeEventPoolCache(
Event->isHostVisible(), Event->isProfilingEnabled(), ZeDevice);

// Put the empty pool to the cache of the pools.
if (NumEventsUnreleasedInEventPool[Event->ZeEventPool] == 0)
Expand Down
78 changes: 66 additions & 12 deletions source/adapters/level_zero/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ struct ur_context_handle_t_ : _ur_object {
//
// Cache of event pools to which host-visible events are added to.
std::vector<std::list<ze_event_pool_handle_t>> ZeEventPoolCache{4};
std::vector<std::unordered_map<ze_device_handle_t,
std::list<ze_event_pool_handle_t> *>>
ZeEventPoolCacheDeviceMap{4};

// This map will be used to determine if a pool is full or not
// by storing number of empty slots available in the pool.
Expand All @@ -163,6 +166,9 @@ struct ur_context_handle_t_ : _ur_object {

// Caches for events.
std::vector<std::list<ur_event_handle_t>> EventCaches{4};
std::vector<
std::unordered_map<ur_device_handle_t, std::list<ur_event_handle_t> *>>
EventCachesDeviceMap{4};

// Initialize the PI context.
ur_result_t initialize();
Expand All @@ -188,20 +194,46 @@ struct ur_context_handle_t_ : _ur_object {
// slot for an event with profiling capabilities.
ur_result_t getFreeSlotInExistingOrNewPool(ze_event_pool_handle_t &, size_t &,
bool HostVisible,
bool ProfilingEnabled);
bool ProfilingEnabled,
ur_device_handle_t Device);

// Get ur_event_handle_t from cache.
ur_event_handle_t getEventFromContextCache(bool HostVisible,
bool WithProfiling);
bool WithProfiling,
ur_device_handle_t Device);

// Add ur_event_handle_t to cache.
void addEventToContextCache(ur_event_handle_t);

auto getZeEventPoolCache(bool HostVisible, bool WithProfiling) {
if (HostVisible)
return WithProfiling ? &ZeEventPoolCache[0] : &ZeEventPoolCache[1];
else
return WithProfiling ? &ZeEventPoolCache[2] : &ZeEventPoolCache[3];
auto getZeEventPoolCache(bool HostVisible, bool WithProfiling,
ze_device_handle_t ZeDevice) {
if (HostVisible) {
if (ZeDevice) {
auto ZeEventPoolCacheMap = WithProfiling
? &ZeEventPoolCacheDeviceMap[0]
: &ZeEventPoolCacheDeviceMap[1];
if (ZeEventPoolCacheMap->find(ZeDevice) == ZeEventPoolCacheMap->end()) {
ZeEventPoolCache.emplace_back();
(*ZeEventPoolCacheMap)[ZeDevice] = &ZeEventPoolCache.back();
}
return (*ZeEventPoolCacheMap)[ZeDevice];
} else {
return WithProfiling ? &ZeEventPoolCache[0] : &ZeEventPoolCache[1];
}
} else {
if (ZeDevice) {
auto ZeEventPoolCacheMap = WithProfiling
? &ZeEventPoolCacheDeviceMap[2]
: &ZeEventPoolCacheDeviceMap[3];
if (ZeEventPoolCacheMap->find(ZeDevice) == ZeEventPoolCacheMap->end()) {
ZeEventPoolCache.emplace_back();
(*ZeEventPoolCacheMap)[ZeDevice] = &ZeEventPoolCache.back();
}
return (*ZeEventPoolCacheMap)[ZeDevice];
} else {
return WithProfiling ? &ZeEventPoolCache[2] : &ZeEventPoolCache[3];
}
}
}

// Decrement number of events living in the pool upon event destroy
Expand Down Expand Up @@ -240,11 +272,33 @@ struct ur_context_handle_t_ : _ur_object {

private:
// Get the cache of events for a provided scope and profiling mode.
auto getEventCache(bool HostVisible, bool WithProfiling) {
if (HostVisible)
return WithProfiling ? &EventCaches[0] : &EventCaches[1];
else
return WithProfiling ? &EventCaches[2] : &EventCaches[3];
auto getEventCache(bool HostVisible, bool WithProfiling,
ur_device_handle_t Device) {
if (HostVisible) {
if (Device) {
auto EventCachesMap =
WithProfiling ? &EventCachesDeviceMap[0] : &EventCachesDeviceMap[1];
if (EventCachesMap->find(Device) == EventCachesMap->end()) {
EventCaches.emplace_back();
(*EventCachesMap)[Device] = &EventCaches.back();
}
return (*EventCachesMap)[Device];
} else {
return WithProfiling ? &EventCaches[0] : &EventCaches[1];
}
} else {
if (Device) {
auto EventCachesMap =
WithProfiling ? &EventCachesDeviceMap[2] : &EventCachesDeviceMap[3];
if (EventCachesMap->find(Device) == EventCachesMap->end()) {
EventCaches.emplace_back();
(*EventCachesMap)[Device] = &EventCaches.back();
}
return (*EventCachesMap)[Device];
} else {
return WithProfiling ? &EventCaches[2] : &EventCaches[3];
}
}
}
};

Expand Down
Loading

0 comments on commit 92a0250

Please sign in to comment.