Skip to content

Commit fed32a8

Browse files
[SYCL] Fix device handling during interop L0 context creation (#12138)
The Level Zero make_context accepts a vector of SYCL devices. Those devices were not being propagated to the created context, which led to an error during queue creation. This patch fixes the problem by forwarding those devices to the context constructor.
1 parent 7908298 commit fed32a8

File tree

4 files changed

+92
-25
lines changed

4 files changed

+92
-25
lines changed

sycl/source/backend/level_zero.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ __SYCL_EXPORT context make_context(const std::vector<device> &DeviceList,
5555
NativeHandle, DeviceHandles.size(), DeviceHandles.data(), !KeepOwnership,
5656
&PiContext);
5757
// Construct the SYCL context from PI context.
58-
return detail::createSyclObjFromImpl<context>(std::make_shared<context_impl>(
59-
PiContext, detail::defaultAsyncHandler, Plugin, !KeepOwnership));
58+
return detail::createSyclObjFromImpl<context>(
59+
std::make_shared<context_impl>(PiContext, detail::defaultAsyncHandler,
60+
Plugin, DeviceList, !KeepOwnership));
6061
}
6162

6263
//----------------------------------------------------------------------------

sycl/source/detail/context_impl.cpp

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -71,31 +71,36 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
7171

7272
context_impl::context_impl(sycl::detail::pi::PiContext PiContext,
7373
async_handler AsyncHandler, const PluginPtr &Plugin,
74+
const std::vector<sycl::device> &DeviceList,
7475
bool OwnedByRuntime)
75-
: MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(AsyncHandler), MDevices(),
76-
MContext(PiContext), MPlatform(), MHostContext(false),
77-
MSupportBufferLocationByDevices(NotChecked) {
78-
79-
std::vector<sycl::detail::pi::PiDevice> DeviceIds;
80-
uint32_t DevicesNum = 0;
81-
// TODO catch an exception and put it to list of asynchronous exceptions
82-
Plugin->call<PiApiKind::piContextGetInfo>(
83-
MContext, PI_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum), &DevicesNum,
84-
nullptr);
85-
DeviceIds.resize(DevicesNum);
86-
// TODO catch an exception and put it to list of asynchronous exceptions
87-
Plugin->call<PiApiKind::piContextGetInfo>(
88-
MContext, PI_CONTEXT_INFO_DEVICES,
89-
sizeof(sycl::detail::pi::PiDevice) * DevicesNum, &DeviceIds[0], nullptr);
90-
91-
if (!DeviceIds.empty()) {
92-
std::shared_ptr<detail::platform_impl> Platform =
93-
platform_impl::getPlatformFromPiDevice(DeviceIds[0], Plugin);
94-
for (sycl::detail::pi::PiDevice Dev : DeviceIds) {
95-
MDevices.emplace_back(createSyclObjFromImpl<device>(
96-
Platform->getOrMakeDeviceImpl(Dev, Platform)));
76+
: MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(AsyncHandler),
77+
MDevices(DeviceList), MContext(PiContext), MPlatform(),
78+
MHostContext(false), MSupportBufferLocationByDevices(NotChecked) {
79+
if (!MDevices.empty()) {
80+
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
81+
} else {
82+
std::vector<sycl::detail::pi::PiDevice> DeviceIds;
83+
uint32_t DevicesNum = 0;
84+
// TODO catch an exception and put it to list of asynchronous exceptions
85+
Plugin->call<PiApiKind::piContextGetInfo>(
86+
MContext, PI_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum), &DevicesNum,
87+
nullptr);
88+
DeviceIds.resize(DevicesNum);
89+
// TODO catch an exception and put it to list of asynchronous exceptions
90+
Plugin->call<PiApiKind::piContextGetInfo>(
91+
MContext, PI_CONTEXT_INFO_DEVICES,
92+
sizeof(sycl::detail::pi::PiDevice) * DevicesNum, &DeviceIds[0],
93+
nullptr);
94+
95+
if (!DeviceIds.empty()) {
96+
std::shared_ptr<detail::platform_impl> Platform =
97+
platform_impl::getPlatformFromPiDevice(DeviceIds[0], Plugin);
98+
for (sycl::detail::pi::PiDevice Dev : DeviceIds) {
99+
MDevices.emplace_back(createSyclObjFromImpl<device>(
100+
Platform->getOrMakeDeviceImpl(Dev, Platform)));
101+
}
102+
MPlatform = Platform;
97103
}
98-
MPlatform = Platform;
99104
}
100105
// TODO catch an exception and put it to list of asynchronous exceptions
101106
// getPlugin() will be the same as the Plugin passed. This should be taken

sycl/source/detail/context_impl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class context_impl {
7171
/// transferred to runtime
7272
context_impl(sycl::detail::pi::PiContext PiContext,
7373
async_handler AsyncHandler, const PluginPtr &Plugin,
74+
const std::vector<sycl::device> &DeviceList = {},
7475
bool OwnedByRuntime = true);
7576

7677
~context_impl();
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// REQUIRES: level_zero, level_zero_dev_kit
2+
// RUN: %{build} -o %t.out %level_zero_options
3+
// RUN: %{run} %t.out
4+
5+
// This test checks that an interop Level Zero device is properly handled during
6+
// interop context construction.
7+
#include <sycl/ext/oneapi/backend/level_zero.hpp>
8+
#include <sycl/sycl.hpp>
9+
10+
#include <level_zero/ze_api.h>
11+
12+
#include <cassert>
13+
#include <iostream>
14+
#include <vector>
15+
16+
int main(int argc, char *argv[]) {
17+
int level0DriverIndex = 0;
18+
int level0DeviceIndex = 0;
19+
20+
zeInit(0);
21+
uint32_t level0NumDrivers = 0;
22+
zeDriverGet(&level0NumDrivers, nullptr);
23+
24+
assert(level0NumDrivers > 0);
25+
26+
std::vector<ze_driver_handle_t> level0Drivers(level0NumDrivers);
27+
zeDriverGet(&level0NumDrivers, level0Drivers.data());
28+
29+
ze_driver_handle_t level0Driver = level0Drivers[level0DriverIndex];
30+
uint32_t level0NumDevices = 0;
31+
zeDeviceGet(level0Driver, &level0NumDevices, nullptr);
32+
33+
assert(level0NumDevices > 0);
34+
35+
std::vector<ze_device_handle_t> level0Devices(level0NumDevices);
36+
zeDeviceGet(level0Driver, &level0NumDevices, level0Devices.data());
37+
38+
ze_device_handle_t level0Device = level0Devices[level0DeviceIndex];
39+
ze_context_handle_t level0Context = nullptr;
40+
ze_context_desc_t level0ContextDesc = {};
41+
level0ContextDesc.stype = ZE_STRUCTURE_TYPE_CONTEXT_DESC;
42+
zeContextCreateEx(level0Driver, &level0ContextDesc, 1, &level0Device,
43+
&level0Context);
44+
45+
sycl::device dev;
46+
sycl::device interopDev =
47+
sycl::make_device<sycl::backend::ext_oneapi_level_zero>(level0Device);
48+
sycl::context interopCtx =
49+
sycl::make_context<sycl::backend::ext_oneapi_level_zero>(
50+
{level0Context,
51+
{interopDev},
52+
sycl::ext::oneapi::level_zero::ownership::keep});
53+
54+
assert(interopCtx.get_devices().size() == 1);
55+
assert(interopCtx.get_devices()[0] == interopDev);
56+
sycl::queue q{interopCtx, interopDev};
57+
58+
zeContextDestroy(level0Context);
59+
return 0;
60+
}

0 commit comments

Comments
 (0)