Skip to content

Commit 19cf5c6

Browse files
authored
[DLPACK] Enable cython support (#1589)
1 parent ec3f09b commit 19cf5c6

File tree

8 files changed

+113
-47
lines changed

8 files changed

+113
-47
lines changed

HalideIR

include/tvm/runtime/c_runtime_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from,
467467

468468
/*!
469469
* \brief Delete (free) a DLManagedTensor's data.
470-
* \param dltensor Pointer to the DLManagedTensor.
470+
* \param dltensor Pointer to the DLManagedTensor.
471471
*/
472472
TVM_DLL void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
473473

python/tvm/_ffi/_ctypes/ndarray.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,47 @@
1+
# pylint: disable=invalid-name
12
"""Runtime NDArray api"""
23
from __future__ import absolute_import
34

45
import ctypes
5-
from ..base import _LIB, check_call
6+
from ..base import _LIB, check_call, c_str
67
from ..runtime_ctypes import TVMArrayHandle
78
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _return_handle
89

10+
11+
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
12+
_c_str_dltensor = c_str('dltensor')
13+
_c_str_used_dltensor = c_str('used_dltensor')
14+
15+
16+
# used for PyCapsule manipulation
17+
if hasattr(ctypes, 'pythonapi'):
18+
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
19+
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
20+
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
21+
22+
23+
def _from_dlpack(dltensor):
24+
dltensor = ctypes.py_object(dltensor)
25+
if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor):
26+
ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
27+
handle = TVMArrayHandle()
28+
check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
29+
ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
30+
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
31+
return _make_array(handle, False)
32+
raise ValueError("Expect a dltensor field, PyCapsule can only be consumed once")
33+
34+
35+
def _dlpack_deleter(pycapsule):
36+
pycapsule = ctypes.cast(pycapsule, ctypes.py_object)
37+
if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
38+
ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)
39+
_LIB.TVMDLManagedTensorCallDeleter(ptr)
40+
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
41+
42+
_c_dlpack_deleter = TVMPyCapsuleDestructor(_dlpack_deleter)
43+
44+
945
class NDArrayBase(object):
1046
"""A simple Device/CPU Array object in runtime."""
1147
__slots__ = ["handle", "is_view"]
@@ -29,6 +65,17 @@ def __del__(self):
2965
def _tvm_handle(self):
3066
return ctypes.cast(self.handle, ctypes.c_void_p).value
3167

68+
def to_dlpack(self):
69+
"""Produce an array from a DLPack Tensor without copying memory
70+
71+
Returns
72+
-------
73+
dlpack : DLPack tensor view of the array data
74+
"""
75+
handle = ctypes.c_void_p()
76+
check_call(_LIB.TVMArrayToDLPack(self.handle, ctypes.byref(handle)))
77+
return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)
78+
3279

3380
def _make_array(handle, is_view):
3481
handle = ctypes.cast(handle, TVMArrayHandle)

python/tvm/_ffi/_cython/base.pxi

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ..base import TVMError
22
from libcpp.vector cimport vector
33
from cpython.version cimport PY_MAJOR_VERSION
4+
from cpython cimport pycapsule
45
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
56
import ctypes
67

@@ -40,6 +41,11 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
4041
int64_t* strides
4142
uint64_t byte_offset
4243

44+
ctypedef struct DLManagedTensor:
45+
DLTensor dl_tensor
46+
void* manager_ctx
47+
void (*deleter)(DLManagedTensor* self)
48+
4349
ctypedef struct TVMValue:
4450
int64_t v_int64
4551
double v_float64
@@ -49,7 +55,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
4955
DLContext v_ctx
5056

5157
ctypedef int64_t tvm_index_t
52-
ctypedef void* DLTensorHandle
58+
ctypedef DLTensor* DLTensorHandle
5359
ctypedef void* TVMStreamHandle
5460
ctypedef void* TVMRetValueHandle
5561
ctypedef void* TVMFunctionHandle
@@ -92,6 +98,11 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
9298
int TVMArrayCopyFromTo(DLTensorHandle src,
9399
DLTensorHandle to,
94100
TVMStreamHandle stream)
101+
int TVMArrayFromDLPack(DLManagedTensor* arr_from,
102+
DLTensorHandle* out)
103+
int TVMArrayToDLPack(DLTensorHandle arr_from,
104+
DLManagedTensor** out)
105+
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
95106

96107
cdef extern from "tvm/c_dsl_api.h":
97108
int TVMNodeFree(NodeHandle handle)

python/tvm/_ffi/_cython/ndarray.pxi

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,29 @@
11
from ..runtime_ctypes import TVMArrayHandle
22

3+
cdef const char* _c_str_dltensor = "dltensor"
4+
cdef const char* _c_str_used_dltensor = "used_dltensor"
5+
6+
7+
cdef void _c_dlpack_deleter(object pycaps):
8+
cdef DLManagedTensor* dltensor
9+
if pycapsule.PyCapsule_IsValid(pycaps, _c_str_dltensor):
10+
dltensor = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(pycaps, _c_str_dltensor)
11+
TVMDLManagedTensorCallDeleter(dltensor)
12+
13+
14+
def _from_dlpack(object dltensor):
15+
cdef DLManagedTensor* ptr
16+
cdef DLTensorHandle chandle
17+
if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
18+
ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
19+
CALL(TVMArrayFromDLPack(ptr, &chandle))
20+
# set name and destructor to be empty
21+
pycapsule.PyCapsule_SetDestructor(dltensor, NULL)
22+
pycapsule.PyCapsule_SetName(dltensor, _c_str_used_dltensor)
23+
return c_make_array(chandle, 0)
24+
raise ValueError("Expect a dltensor field, pycapsule.PyCapsule can only be consumed once")
25+
26+
327
cdef class NDArrayBase:
428
cdef DLTensor* chandle
529
cdef int c_is_view
@@ -35,12 +59,26 @@ cdef class NDArrayBase:
3559
if self.c_is_view == 0:
3660
CALL(TVMArrayFree(self.chandle))
3761

62+
def to_dlpack(self):
63+
"""Produce an array from a DLPack Tensor without copying memory
64+
65+
Returns
66+
-------
67+
dlpack : DLPack tensor view of the array data
68+
"""
69+
cdef DLManagedTensor* dltensor
70+
if self.c_is_view != 0:
71+
raise ValueError("to_dlpack do not work with memory views")
72+
CALL(TVMArrayToDLPack(self.chandle, &dltensor))
73+
return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter)
74+
3875

3976
cdef c_make_array(void* chandle, is_view):
4077
ret = _CLASS_NDARRAY(None, is_view)
4178
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
4279
return ret
4380

81+
4482
cdef _TVM_COMPATS = ()
4583

4684
cdef _TVM_EXT_RET = {}

python/tvm/_ffi/ndarray.py

Lines changed: 5 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,17 @@
1717
if _FFI_MODE == "ctypes":
1818
raise ImportError()
1919
if sys.version_info >= (3, 0):
20-
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array
20+
from ._cy3.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
2121
from ._cy3.core import NDArrayBase as _NDArrayBase
2222
else:
23-
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array
23+
from ._cy2.core import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
2424
from ._cy2.core import NDArrayBase as _NDArrayBase
2525
except IMPORT_EXCEPT:
2626
# pylint: disable=wrong-import-position
27-
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array
27+
from ._ctypes.ndarray import _set_class_ndarray, _reg_extension, _make_array, _from_dlpack
2828
from ._ctypes.ndarray import NDArrayBase as _NDArrayBase
2929

3030

31-
TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
32-
_c_str_dltensor = c_str('dltensor')
33-
34-
35-
# used for PyCapsule manipulation
36-
if hasattr(ctypes, 'pythonapi'):
37-
ctypes.pythonapi.PyCapsule_GetName.restype = ctypes.c_char_p
38-
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
39-
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
40-
41-
4231
def context(dev_type, dev_id=0):
4332
"""Construct a TVM context with given device type and id.
4433
@@ -134,30 +123,14 @@ def from_dlpack(dltensor):
134123
Parameters
135124
----------
136125
dltensor : DLPack tensor
126+
Input DLManagedTensor, can only be consumed once.
137127
138128
Returns
139129
-------
140130
arr: tvm.nd.NDArray
141131
The array view of the tensor data.
142132
"""
143-
dltensor = ctypes.py_object(dltensor)
144-
name = ctypes.pythonapi.PyCapsule_GetName(dltensor)
145-
ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, name)
146-
handle = TVMArrayHandle()
147-
check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle)))
148-
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, None)
149-
return _make_array(handle, False)
150-
151-
152-
def _dlpack_deleter(pycapsule):
153-
pycapsule = ctypes.py_object(pycapsule)
154-
if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
155-
ptr = ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor)
156-
_LIB.TVMDLManagedTensorCallDeleter(ptr)
157-
ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0))
158-
159-
160-
_c_dlpack_deleter = TVMPyCapsuleDestructor(_dlpack_deleter)
133+
return _from_dlpack(dltensor)
161134

162135

163136
class NDArrayBase(_NDArrayBase):
@@ -308,17 +281,6 @@ def copyto(self, target):
308281
raise ValueError("Unsupported target type %s" % str(type(target)))
309282
return target
310283

311-
def to_dlpack(self):
312-
"""Produce an array from a DLPack Tensor without copying memory
313-
314-
Returns
315-
-------
316-
dlpack : DLPack tensor view of the array data
317-
"""
318-
handle = ctypes.c_void_p()
319-
check_call(_LIB.TVMArrayToDLPack(self.handle, ctypes.byref(handle)))
320-
return ctypes.pythonapi.PyCapsule_New(handle, _c_str_dltensor, _c_dlpack_deleter)
321-
322284

323285
def free_extension_handle(handle, type_code):
324286
"""Free c++ extension type handle

tests/scripts/task_python_nnvm.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ export PYTHONPATH=nnvm/python:python:topi/python
44
# to avoid openblas threading error
55
export OMP_NUM_THREADS=1
66

7+
# Rebuild cython
8+
make cython || exit -1
9+
make cython3 || exit -1
10+
711
echo "Running unittest..."
812
python -m nose -v nnvm/tests/python/unittest || exit -1
913
python3 -m nose -v nnvm/tests/python/unittest || exit -1

tests/scripts/task_python_topi.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
export PYTHONPATH=python:topi/python
22

3+
# Rebuild cython
4+
make cython || exit -1
5+
make cython3 || exit -1
6+
37
python -m nose -v topi/tests/python || exit -1
48
python3 -m nose -v topi/tests/python || exit -1

0 commit comments

Comments
 (0)