Commit 6023855

Add dlpack support (pytorch#7025)
1 parent 5e1d454 commit 6023855

13 files changed: +673 −6 lines changed
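
This commit adds a `torch_xla.utils.dlpack` module exposing `to_dlpack` and `from_dlpack`, adds `dl_convertor.cpp`/`.h` to the C++ build, and covers both conversion directions with new tests in `test/test_operations.py`. Below is a minimal usage sketch distilled from those tests (it is not part of the commit itself); it assumes a CUDA build of torch and a CUDA-backed PJRT device (e.g. `PJRT_DEVICE=CUDA`), the same preconditions the tests guard with `onlyIfTorchSupportsCUDA` / `onlyIfPJRTDeviceIsCUDA`.

```python
import torch
import torch.utils.dlpack
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as xdlpack

# CUDA tensor -> XLA tensor: the DLPack capsule shares the underlying device
# memory, so updates made through the CUDA tensor are visible through the XLA
# tensor (as asserted in test_dlpack_pytorch_cuda_to_xla below).
t_cuda = torch.arange(5).cuda()
xla_t = xdlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t_cuda))
assert xla_t.device.type == 'xla'
t_cuda[0] = t_cuda[0] + 20
assert torch.allclose(xla_t.cpu(), t_cuda.cpu())

# XLA tensor -> CUDA tensor (as in test_dlpack_xla_to_pytorch_cuda).
# Note that a DLPack capsule can be consumed only once.
xla_u = torch.arange(5).to(xm.xla_device())
cuda_u = torch.utils.dlpack.from_dlpack(xdlpack.to_dlpack(xla_u))
assert cuda_u.device.type == 'cuda'
```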

test/test_operations.py

Lines changed: 140 additions & 0 deletions
@@ -13,6 +13,7 @@
 parser.add_argument('--verbosity', type=int, default=0)
 FLAGS, leftovers = parser.parse_known_args()
 sys.argv = [sys.argv[0]] + leftovers
+from absl.testing import absltest, parameterized
 
 # Normal imports section starts here.
 import collections
@@ -28,6 +29,11 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+from torch.testing._internal.common_device_type import dtypes
+from torch.testing._internal.common_dtype import (
+    all_types_and_complex_and,
+    all_types_and,
+)
 import torch_xla
 import torch_xla.core.xla_builder as xb
 import torch_xla.core.xla_op_registry as xor
@@ -40,6 +46,7 @@
 import torch_xla.distributed.spmd as xs
 from torch_xla import runtime as xr
 import torch_xla.test.test_utils as xtu
+import torch_xla.utils.dlpack as xdlpack
 import torch_xla.utils.utils as xu
 import torch_xla.utils.serialization as xser
 import torch_xla.core.xla_model as xm
@@ -2464,6 +2471,139 @@ def test_unsafe_buffer_pointer(self):
     self.assertGreaterEqual(buf_ptr_3, 0)
 
 
+class TestDLPack(parameterized.TestCase):
+
+  def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
+    dlpt = xdlpack.to_dlpack(xla_tensor)  # dlpt has type PyCapsule
+    xla_tensor2 = xdlpack.from_dlpack(dlpt)
+
+    self.assertEqual(xla_tensor.device, xla_tensor2.device)
+    self.assertTrue(torch.allclose(xla_tensor.cpu(), xla_tensor2.cpu()))
+    self.assertRaisesRegex(RuntimeError,
+                           "DLTensor capsule can be consumed only once",
+                           lambda: xdlpack.from_dlpack(dlpt))
+
+    self.assertEqual(
+        torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),
+        torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor2))
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  @parameterized.parameters(*all_types_and(torch.half, torch.bfloat16))
+  def test_dlpack_roundtrip_tensor(self, dtype):
+    xla_device = xm.xla_device()
+    # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
+    # xla_tensor_2 uses XLANativeFunctions::_to_copy
+    xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device)
+    self._test_dlpack_capsule_conversion_helper(xla_tensor_2)
+
+    # xla_tensor_3 uses arange_out IR node.
+    xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device())
+    xm.mark_step()
+    self._test_dlpack_capsule_conversion_helper(xla_tensor_3)
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  @parameterized.parameters(*all_types_and_complex_and(torch.half,
+                                                       torch.bfloat16,
+                                                       torch.bool, torch.uint16,
+                                                       torch.uint32,
+                                                       torch.uint64))
+  def test_dlpack_roundtrip_scalar(self, dtype):
+    xla_device = xm.xla_device()
+    xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device)
+    # `mark_step` ensures xtensor->CurrentDataHandle() != nullptr
+    xm.mark_step()
+    self._test_dlpack_capsule_conversion_helper(xla_tensor_0)
+
+    xla_tensor_1 = torch.tensor(42, dtype=dtype).to(xla_device)
+    # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
+    self._test_dlpack_capsule_conversion_helper(xla_tensor_1)
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_dlpack_roundtrip_bool(self):
+    xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device())
+    self._test_dlpack_capsule_conversion_helper(xla_tensor)
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_dlpack_pytorch_cuda_to_xla(self):
+    t1_cuda = torch.arange(5).cuda()
+    dlt1 = torch.utils.dlpack.to_dlpack(t1_cuda)
+    xla_t1 = xdlpack.from_dlpack(dlt1)
+    self.assertEqual(xla_t1.device.type, 'xla')
+    self.assertEqual(xla_t1.device.index, t1_cuda.device.index)
+    t1_cuda[0] = t1_cuda[0] + 20
+    self.assertTrue(torch.allclose(xla_t1.cpu(), t1_cuda.cpu()))
+
+    t2_cuda = torch.tensor(5).cuda()
+    dlt2 = torch.utils.dlpack.to_dlpack(t2_cuda)
+    xla_t2 = xdlpack.from_dlpack(dlt2)
+    self.assertEqual(xla_t2.device.type, 'xla')
+    self.assertEqual(xla_t2.device.index, t2_cuda.device.index)
+    t2_cuda.fill_(6)
+    self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu()))
+
+    cuda1 = torch.device('cuda:1')
+    t3_cuda = torch.tensor(5, device=cuda1)
+    dlt3 = torch.utils.dlpack.to_dlpack(t3_cuda)
+    xla_t3 = xdlpack.from_dlpack(dlt3)
+    self.assertEqual(xla_t3.device.type, 'xla')
+    self.assertEqual(
+        xla_t3.device.index,
+        t3_cuda.device.index,
+        msg='both values should be 1; xla_t3.device should be xla:1.')
+    t3_cuda.fill_(6)
+    self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_dlpack_xla_to_pytorch_cuda(self):
+    xla_t1 = torch.arange(5).to(xm.xla_device())
+    dlt1 = xdlpack.to_dlpack(xla_t1)
+    cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1)
+    self.assertEqual(cuda_t1.device.type, 'cuda')
+    self.assertEqual(cuda_t1.device.index, xla_t1.device.index)
+    cuda_t1[0] = cuda_t1[0] + 20
+    self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))
+
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_dlpack_non_default_layout(self):
+    cuda_t = torch.arange(25, device=torch.device('cuda')).reshape(5, 5)
+
+    t1 = cuda_t.t()
+    xla_t1 = xdlpack.from_dlpack(t1.__dlpack__())
+    self.assertEqual(xla_t1.device.type, 'xla')
+    self.assertEqual(xla_t1.device.index, 0)
+    self.assertTrue(torch.allclose(t1.cpu(), xla_t1.cpu()))
+
+    t2 = cuda_t[0]
+    xla_t2 = xdlpack.from_dlpack(t2.__dlpack__())
+    self.assertEqual(xla_t2.device.type, 'xla')
+    self.assertEqual(xla_t2.device.index, 0)
+    self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu()))
+
+    t3 = cuda_t[:, 0]
+    self.assertRaisesRegex(
+        RuntimeError,
+        r"Only DLPack tensors with trivial \(compact\) striding are supported",
+        lambda: xdlpack.from_dlpack(t3.__dlpack__()))
+
+    t4 = cuda_t[1, :]
+    xla_t4 = xdlpack.from_dlpack(t4.__dlpack__())
+    self.assertEqual(xla_t4.device.type, 'xla')
+    self.assertEqual(xla_t4.device.index, 0)
+    self.assertTrue(torch.allclose(t4.cpu(), xla_t4.cpu()))
+
+    t5 = cuda_t[1]
+    xla_t5 = xdlpack.from_dlpack(t5.__dlpack__())
+    self.assertEqual(xla_t5.device.type, 'xla')
+    self.assertEqual(xla_t5.device.index, 0)
+    self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu()))
+
+
 class SimpleModelWithDropout(torch.nn.Module):
 
   def __init__(self):
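
One caveat from `test_dlpack_non_default_layout` above: the XLA-side `from_dlpack` rejects some non-compact views (e.g. the column slice `cuda_t[:, 0]`) with "Only DLPack tensors with trivial (compact) striding are supported". The sketch below is a hypothetical workaround, not part of this commit: it assumes an extra device copy is acceptable and uses `.contiguous()` to produce a compactly strided tensor before exporting the capsule.

```python
import torch
import torch_xla.utils.dlpack as xdlpack

cuda_t = torch.arange(25, device='cuda').reshape(5, 5)
col = cuda_t[:, 0]  # non-compact view; xdlpack.from_dlpack(col.__dlpack__()) raises

# .contiguous() makes a compactly strided copy that the importer accepts.
# Note the copy no longer aliases cuda_t, so later in-place updates to
# cuda_t are not reflected in xla_col.
xla_col = xdlpack.from_dlpack(col.contiguous().__dlpack__())
assert xla_col.device.type == 'xla'
```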

torch_xla/csrc/BUILD

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@ ptxla_cc_library(
         "cross_replica_reduces.cpp",
         "data_ops.cpp",
         "debug_util.cpp",
+        "dl_convertor.cpp",
         "elementwise.cpp",
         "helpers.cpp",
         "ir_dump_util.cpp",
@@ -81,6 +82,7 @@ ptxla_cc_library(
         "cross_replica_reduces.h",
         "data_ops.h",
         "debug_util.h",
+        "dl_convertor.h",
         "elementwise.h",
         "generated_file_include.h",
         "helpers.h",
