13 | 13 | parser.add_argument('--verbosity', type=int, default=0)
14 | 14 | FLAGS, leftovers = parser.parse_known_args()
15 | 15 | sys.argv = [sys.argv[0]] + leftovers
| 16 | +from absl.testing import absltest, parameterized
16 | 17 |
17 | 18 | # Normal imports section starts here.
18 | 19 | import collections

28 | 29 | import torch.nn as nn
29 | 30 | import torch.nn.functional as F
30 | 31 | import torch.optim as optim
| 32 | +from torch.testing._internal.common_device_type import dtypes
| 33 | +from torch.testing._internal.common_dtype import (
| 34 | +    all_types_and_complex_and,
| 35 | +    all_types_and,
| 36 | +)
31 | 37 | import torch_xla
32 | 38 | import torch_xla.core.xla_builder as xb
33 | 39 | import torch_xla.core.xla_op_registry as xor

40 | 46 | import torch_xla.distributed.spmd as xs
41 | 47 | from torch_xla import runtime as xr
42 | 48 | import torch_xla.test.test_utils as xtu
| 49 | +import torch_xla.utils.dlpack as xdlpack
43 | 50 | import torch_xla.utils.utils as xu
44 | 51 | import torch_xla.utils.serialization as xser
45 | 52 | import torch_xla.core.xla_model as xm
@@ -2464,6 +2471,139 @@ def test_unsafe_buffer_pointer(self):
2464 | 2471 |     self.assertGreaterEqual(buf_ptr_3, 0)
2465 | 2472 |
2466 | 2473 |
| 2474 | +class TestDLPack(parameterized.TestCase):
| 2475 | +
| 2476 | +  def _test_dlpack_capsule_conversion_helper(self, xla_tensor):
| 2477 | +    dlpt = xdlpack.to_dlpack(xla_tensor)  # dlpt has type PyCapsule
| 2478 | +    xla_tensor2 = xdlpack.from_dlpack(dlpt)
| 2479 | +
| 2480 | +    self.assertEqual(xla_tensor.device, xla_tensor2.device)
| 2481 | +    self.assertTrue(torch.allclose(xla_tensor.cpu(), xla_tensor2.cpu()))
| 2482 | +    self.assertRaisesRegex(RuntimeError,
| 2483 | +                           "DLTensor capsule can be consumed only once",
| 2484 | +                           lambda: xdlpack.from_dlpack(dlpt))
| 2485 | +
| 2486 | +    self.assertEqual(
| 2487 | +        torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor),
| 2488 | +        torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor2))
| 2489 | +
| 2490 | +  @onlyIfTorchSupportsCUDA
| 2491 | +  @onlyIfPJRTDeviceIsCUDA
| 2492 | +  @parameterized.parameters(*all_types_and(torch.half, torch.bfloat16))
| 2493 | +  def test_dlpack_roundtrip_tensor(self, dtype):
| 2494 | +    xla_device = xm.xla_device()
| 2495 | +    # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
| 2496 | +    # xla_tensor_2 uses XLANativeFunctions::_to_copy
| 2497 | +    xla_tensor_2 = torch.arange(5, dtype=dtype).to(xla_device)
| 2498 | +    self._test_dlpack_capsule_conversion_helper(xla_tensor_2)
| 2499 | +
| 2500 | +    # xla_tensor_3 uses arange_out IR node.
| 2501 | +    xla_tensor_3 = torch.arange(5, dtype=dtype, device=xm.xla_device())
| 2502 | +    xm.mark_step()
| 2503 | +    self._test_dlpack_capsule_conversion_helper(xla_tensor_3)
| 2504 | +
| 2505 | +  @onlyIfTorchSupportsCUDA
| 2506 | +  @onlyIfPJRTDeviceIsCUDA
| 2507 | +  @parameterized.parameters(*all_types_and_complex_and(torch.half,
| 2508 | +                                                       torch.bfloat16,
| 2509 | +                                                       torch.bool, torch.uint16,
| 2510 | +                                                       torch.uint32,
| 2511 | +                                                       torch.uint64))
| 2512 | +  def test_dlpack_roundtrip_scalar(self, dtype):
| 2513 | +    xla_device = xm.xla_device()
| 2514 | +    xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device)
| 2515 | +    # `mark_step` ensures xtensor->CurrentDataHandle() != nullptr
| 2516 | +    xm.mark_step()
| 2517 | +    self._test_dlpack_capsule_conversion_helper(xla_tensor_0)
| 2518 | +
| 2519 | +    xla_tensor_1 = torch.tensor(42, dtype=dtype).to(xla_device)
| 2520 | +    # xtensor->CurrentDataHandle() == nullptr but xtensor->CurrentIrValue().node != nullptr and device_data != nullptr
| 2521 | +    self._test_dlpack_capsule_conversion_helper(xla_tensor_1)
| 2522 | +
| 2523 | +  @onlyIfTorchSupportsCUDA
| 2524 | +  @onlyIfPJRTDeviceIsCUDA
| 2525 | +  def test_dlpack_roundtrip_bool(self):
| 2526 | +    xla_tensor = torch.ones(1, dtype=torch.bool).to(xm.xla_device())
| 2527 | +    self._test_dlpack_capsule_conversion_helper(xla_tensor)
| 2528 | +
| 2529 | +  @onlyIfTorchSupportsCUDA
| 2530 | +  @onlyIfPJRTDeviceIsCUDA
| 2531 | +  def test_dlpack_pytorch_cuda_to_xla(self):
| 2532 | +    t1_cuda = torch.arange(5).cuda()
| 2533 | +    dlt1 = torch.utils.dlpack.to_dlpack(t1_cuda)
| 2534 | +    xla_t1 = xdlpack.from_dlpack(dlt1)
| 2535 | +    self.assertEqual(xla_t1.device.type, 'xla')
| 2536 | +    self.assertEqual(xla_t1.device.index, t1_cuda.device.index)
| 2537 | +    t1_cuda[0] = t1_cuda[0] + 20
| 2538 | +    self.assertTrue(torch.allclose(xla_t1.cpu(), t1_cuda.cpu()))
| 2539 | +
| 2540 | +    t2_cuda = torch.tensor(5).cuda()
| 2541 | +    dlt2 = torch.utils.dlpack.to_dlpack(t2_cuda)
| 2542 | +    xla_t2 = xdlpack.from_dlpack(dlt2)
| 2543 | +    self.assertEqual(xla_t2.device.type, 'xla')
| 2544 | +    self.assertEqual(xla_t2.device.index, t2_cuda.device.index)
| 2545 | +    t2_cuda.fill_(6)
| 2546 | +    self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu()))
| 2547 | +
| 2548 | +    cuda1 = torch.device('cuda:1')
| 2549 | +    t3_cuda = torch.tensor(5, device=cuda1)
| 2550 | +    dlt3 = torch.utils.dlpack.to_dlpack(t3_cuda)
| 2551 | +    xla_t3 = xdlpack.from_dlpack(dlt3)
| 2552 | +    self.assertEqual(xla_t3.device.type, 'xla')
| 2553 | +    self.assertEqual(
| 2554 | +        xla_t3.device.index,
| 2555 | +        t3_cuda.device.index,
| 2556 | +        msg='both indices should be 1; xla_t3.device should be xla:1.')
| 2557 | +    t3_cuda.fill_(6)
| 2558 | +    self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))
| 2559 | +
| 2560 | +  @onlyIfTorchSupportsCUDA
| 2561 | +  @onlyIfPJRTDeviceIsCUDA
| 2562 | +  def test_dlpack_xla_to_pytorch_cuda(self):
| 2563 | +    xla_t1 = torch.arange(5).to(xm.xla_device())
| 2564 | +    dlt1 = xdlpack.to_dlpack(xla_t1)
| 2565 | +    cuda_t1 = torch.utils.dlpack.from_dlpack(dlt1)
| 2566 | +    self.assertEqual(cuda_t1.device.type, 'cuda')
| 2567 | +    self.assertEqual(cuda_t1.device.index, xla_t1.device.index)
| 2568 | +    cuda_t1[0] = cuda_t1[0] + 20
| 2569 | +    self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))
| 2570 | +
| 2571 | +  @onlyIfTorchSupportsCUDA
| 2572 | +  @onlyIfPJRTDeviceIsCUDA
| 2573 | +  def test_dlpack_non_default_layout(self):
| 2574 | +    cuda_t = torch.arange(25, device=torch.device('cuda')).reshape(5, 5)
| 2575 | +
| 2576 | +    t1 = cuda_t.t()
| 2577 | +    xla_t1 = xdlpack.from_dlpack(t1.__dlpack__())
| 2578 | +    self.assertEqual(xla_t1.device.type, 'xla')
| 2579 | +    self.assertEqual(xla_t1.device.index, 0)
| 2580 | +    self.assertTrue(torch.allclose(t1.cpu(), xla_t1.cpu()))
| 2581 | +
| 2582 | +    t2 = cuda_t[0]
| 2583 | +    xla_t2 = xdlpack.from_dlpack(t2.__dlpack__())
| 2584 | +    self.assertEqual(xla_t2.device.type, 'xla')
| 2585 | +    self.assertEqual(xla_t2.device.index, 0)
| 2586 | +    self.assertTrue(torch.allclose(t2.cpu(), xla_t2.cpu()))
| 2587 | +
| 2588 | +    t3 = cuda_t[:, 0]
| 2589 | +    self.assertRaisesRegex(
| 2590 | +        RuntimeError,
| 2591 | +        r"Only DLPack tensors with trivial \(compact\) striding are supported",
| 2592 | +        lambda: xdlpack.from_dlpack(t3.__dlpack__()))
| 2593 | +
| 2594 | +    t4 = cuda_t[1, :]
| 2595 | +    xla_t4 = xdlpack.from_dlpack(t4.__dlpack__())
| 2596 | +    self.assertEqual(xla_t4.device.type, 'xla')
| 2597 | +    self.assertEqual(xla_t4.device.index, 0)
| 2598 | +    self.assertTrue(torch.allclose(t4.cpu(), xla_t4.cpu()))
| 2599 | +
| 2600 | +    t5 = cuda_t[1]
| 2601 | +    xla_t5 = xdlpack.from_dlpack(t5.__dlpack__())
| 2602 | +    self.assertEqual(xla_t5.device.type, 'xla')
| 2603 | +    self.assertEqual(xla_t5.device.index, 0)
| 2604 | +    self.assertTrue(torch.allclose(t5.cpu(), xla_t5.cpu()))
| 2605 | +
| 2606 | +
2467 | 2607 | class SimpleModelWithDropout(torch.nn.Module):
2468 | 2608 |
2469 | 2609 |   def __init__(self):
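
For reference, the interop pattern these tests exercise looks like this in user code. This is a minimal sketch, not part of the diff, assuming a CUDA-backed PJRT device; it uses only the torch.utils.dlpack and torch_xla.utils.dlpack calls that appear in the tests above.

import torch
import torch.utils.dlpack
import torch_xla.utils.dlpack as xdlpack

# CUDA -> XLA: export the CUDA tensor as a DLPack capsule, then import it
# as an XLA tensor. The capsule hands over the device buffer, so writes
# through t_cuda stay visible through t_xla (as the tests assert).
t_cuda = torch.arange(5, device='cuda')
t_xla = xdlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t_cuda))

# XLA -> CUDA: the reverse direction. A capsule is single-use; consuming
# it a second time raises "DLTensor capsule can be consumed only once".
capsule = xdlpack.to_dlpack(t_xla)
t_cuda_again = torch.utils.dlpack.from_dlpack(capsule)
assert torch.allclose(t_cuda.cpu(), t_cuda_again.cpu())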