Skip to content

Commit

Permalink
Adds factory functions to convert between dlpack devices and `dpctl.S…
Browse files Browse the repository at this point in the history
…yclDevice`
  • Loading branch information
ndgrigorian committed Jan 7, 2025
1 parent 23fcd62 commit f125be4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
6 changes: 6 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@
uint64,
)
from dpctl.tensor._device import Device
from dpctl.tensor._dldevice_conversions import (
dldevice_to_sycldevice,
sycldevice_to_dldevice,
)
from dpctl.tensor._dlpack import from_dlpack
from dpctl.tensor._indexing_functions import (
extract,
Expand Down Expand Up @@ -388,4 +392,6 @@
"take_along_axis",
"put_along_axis",
"top_k",
"dldevice_to_sycldevice",
"sycldevice_to_dldevice",
]
40 changes: 40 additions & 0 deletions dpctl/tensor/_dldevice_conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dpctl

from ._usmarray import DLDeviceType


def dldevice_to_sycldevice(dl_dev: tuple):
if isinstance(dl_dev, tuple):
if len(dl_dev) != 2:
raise ValueError("dldevice tuple must have length 2")
else:
raise TypeError(
f"dl_dev is expected to be a 2-tuple, got " f"{type(dl_dev)}"
)
if dl_dev[0] != DLDeviceType.kDLOneAPI:
raise ValueError("dldevice type must be kDLOneAPI")
return dpctl.SyclDevice(str(dl_dev[1]))


def sycldevice_to_dldevice(dev: dpctl.SyclDevice):
if not isinstance(dev, dpctl.SyclDevice):
raise TypeError(
"dev is expected to be a dpctl.SyclDevice, got " f"{type(dev)}"
)
return (DLDeviceType.kDLOneAPI, dev.get_device_id())

0 comments on commit f125be4

Please sign in to comment.