Skip to content

Commit

Permalink
Use get_device_id to find DLpack device ID
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Jan 15, 2025
1 parent 7a9c6a8 commit afce3b3
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 49 deletions.
25 changes: 13 additions & 12 deletions dpctl/_sycl_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2043,21 +2043,23 @@ cdef class SyclDevice(_SyclDevice):
return str(relId)

def get_unpartitioned_parent_device(self):
""" get_unpartitioned_parent_device(self)
""" get_unpartitioned_parent_device()
Returns the unpartitioned parent device of this device, or None for a
root device.
Returns the unpartitioned parent device of this device.
If this device is already an unpartitioned, root device,
the same device is returned.
Returns:
dpctl.SyclDevice:
A parent, unpartitioned :class:`dpctl.SyclDevice` instance if
the device is a sub-device, ``None`` otherwise.
A parent, unpartitioned :class:`dpctl.SyclDevice` instance, or
``self`` if already a root device.
"""
cdef DPCTLSyclDeviceRef pDRef = NULL
cdef DPCTLSyclDeviceRef tDRef = NULL
pDRef = DPCTLDevice_GetParentDevice(self._device_ref)
if pDRef is NULL:
return None
return self
else:
tDRef = DPCTLDevice_GetParentDevice(pDRef)
while tDRef is not NULL:
Expand All @@ -2077,7 +2079,7 @@ cdef class SyclDevice(_SyclDevice):
Raises:
ValueError:
If the device is a sub-device.
If the device could not be found.
:Example:
.. code-block:: python
Expand All @@ -2089,13 +2091,12 @@ cdef class SyclDevice(_SyclDevice):
assert devs[i] == gpu_dev
"""
cdef int dev_id = -1
cdef SyclDevice dev

if self.parent_device:
raise ValueError("This SyclDevice is not a root device")

dev_id = self.get_overall_ordinal()
dev = self.get_unpartitioned_parent_device()
dev_id = dev.get_overall_ordinal()
if dev_id < 0:
raise ValueError
raise ValueError("device could not be found")
return dev_id


Expand Down
2 changes: 0 additions & 2 deletions dpctl/tensor/_dlpack.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ cpdef object to_dlpack_versioned_capsule(usm_ndarray array, bint copied) except
cpdef object numpy_to_dlpack_versioned_capsule(ndarray array, bint copied) except +
cpdef object from_dlpack_capsule(object dltensor) except +

cdef int get_parent_device_ordinal_id(SyclDevice dev) except *

cdef class DLPackCreationError(Exception):
"""
A DLPackCreateError exception is raised when constructing
Expand Down
27 changes: 2 additions & 25 deletions dpctl/tensor/_dlpack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -177,28 +177,6 @@ cdef object _get_default_context(c_dpctl.SyclDevice dev):

return default_context


cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except -1:
cdef DPCTLSyclDeviceRef pDRef = NULL
cdef DPCTLSyclDeviceRef tDRef = NULL
cdef c_dpctl.SyclDevice p_dev

pDRef = DPCTLDevice_GetParentDevice(dev.get_device_ref())
if pDRef is not NULL:
# if dev is a sub-device, find its parent
# and return its overall ordinal id
tDRef = DPCTLDevice_GetParentDevice(pDRef)
while tDRef is not NULL:
DPCTLDevice_Delete(pDRef)
pDRef = tDRef
tDRef = DPCTLDevice_GetParentDevice(pDRef)
p_dev = c_dpctl.SyclDevice._create(pDRef)
return p_dev.get_overall_ordinal()

# return overall ordinal id of argument device
return dev.get_overall_ordinal()


cdef int get_array_dlpack_device_id(
usm_ndarray usm_ary
) except -1:
Expand All @@ -224,14 +202,13 @@ cdef int get_array_dlpack_device_id(
"on non-partitioned SYCL devices on platforms where "
"default_context oneAPI extension is not supported."
)
device_id = ary_sycl_device.get_overall_ordinal()
else:
if not usm_ary.sycl_context == default_context:
raise DLPackCreationError(
"to_dlpack_capsule: DLPack can only export arrays based on USM "
"allocations bound to a default platform SYCL context"
)
device_id = get_parent_device_ordinal_id(ary_sycl_device)
device_id = ary_sycl_device.get_device_id()

if device_id < 0:
raise DLPackCreationError(
Expand Down Expand Up @@ -1086,7 +1063,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
d = device.sycl_device
else:
d = device
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
dl_device = (device_OneAPI, d.get_device_id())
if dl_device is not None:
if (dl_device[0] not in [device_OneAPI, device_CPU]):
raise ValueError(
Expand Down
14 changes: 7 additions & 7 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1304,16 +1304,16 @@ cdef class usm_ndarray:
DLPackCreationError:
when the ``device_id`` could not be determined.
"""
cdef int dev_id = c_dlpack.get_parent_device_ordinal_id(<c_dpctl.SyclDevice>self.sycl_device)
if dev_id < 0:
try:
dev_id = self.sycl_device.get_device_id()
except ValueError as e:
raise c_dlpack.DLPackCreationError(
"Could not determine id of the device where array was allocated."
)
else:
return (
DLDeviceType.kDLOneAPI,
dev_id,
)
return (
DLDeviceType.kDLOneAPI,
dev_id,
)

def __eq__(self, other):
return dpctl.tensor.equal(self, other)
Expand Down
19 changes: 16 additions & 3 deletions dpctl/tests/test_sycl_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,21 @@ def test_get_device_id_method():
assert hash(d) == hash(d_r)


def test_sub_devices_disallow_device_id():
def test_get_unpartitioned_parent_device_method():
"""
Test that the get_unpartitioned_parent method returns self for root
devices.
"""
devices = dpctl.get_devices()
for d in devices:
assert d == d.get_unpartitioned_parent_device()


def test_get_unpartitioned_parent_device_from_sub_device():
"""
Test that the get_unpartitioned_parent method returns the parent device
from the sub-device.
"""
try:
dev = dpctl.SyclDevice()
except dpctl.SyclDeviceCreationError:
Expand All @@ -295,5 +309,4 @@ def test_sub_devices_disallow_device_id():
except dpctl.SyclSubDeviceCreationError:
pytest.skip("Default device can not be partitioned")
assert isinstance(sdevs, list) and len(sdevs) > 0
with pytest.raises(ValueError):
sdevs[0].get_device_id()
assert dev == sdevs[0].get_unpartitioned_parent_device()

0 comments on commit afce3b3

Please sign in to comment.