Skip to content

Commit a88a2c4

Browse files
authored
[PT/XLA] Use on-demand capacity for v4 tests (#170)
Change-Id: I976a2b1c09d577392fb65f940ae03848dfca06a7
1 parent e3ff5fe commit a88a2c4

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

dags/pytorch_xla/nightly.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,14 @@ def torchvision():
5555
).run()
5656
resnet_v2_8 = task.TpuQueuedResourceTask(
5757
test_config.JSonnetTpuVmTest.from_pytorch(
58-
"pt-nightly-resnet50-pjrt-fake-v2-8-1vm"
58+
"pt-nightly-resnet50-pjrt-fake-v2-8-1vm",
59+
reserved=True,
5960
),
6061
US_CENTRAL1_C,
6162
).run()
6263
resnet_v3_8_tests = [
6364
task.TpuQueuedResourceTask(
64-
test_config.JSonnetTpuVmTest.from_pytorch(test),
65+
test_config.JSonnetTpuVmTest.from_pytorch(test, reserved=True),
6566
US_CENTRAL1_C,
6667
).run()
6768
for test in (
@@ -92,6 +93,7 @@ def torchvision():
9293
"pt-nightly-resnet50-pjrt-fake-v5litepod-4-1vm",
9394
network=V5_NETWORKS,
9495
subnetwork=V5E_SUBNETWORKS,
96+
reserved=True,
9597
),
9698
US_EAST1_C,
9799
).run()
@@ -137,7 +139,9 @@ def torchvision():
137139
@task_group(prefix_group_id=False)
138140
def huggingface():
139141
accelerate_v2_8 = task.TpuQueuedResourceTask(
140-
test_config.JSonnetTpuVmTest.from_pytorch("pt-nightly-accelerate-smoke-v2-8-1vm"),
142+
test_config.JSonnetTpuVmTest.from_pytorch(
143+
"pt-nightly-accelerate-smoke-v2-8-1vm", reserved=True
144+
),
141145
US_CENTRAL1_C,
142146
).run()
143147
accelerate_v4_8 = task.TpuQueuedResourceTask(

xlml/apis/test_config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def _from_json_helper(
313313

314314
@staticmethod
315315
def from_jax(
316-
test_name: str, reserved_tpu: bool = True, network='default', subnetwork='default'
316+
test_name: str, reserved: bool = False, network='default', subnetwork='default'
317317
):
318318
"""Parses a compiled legacy JSonnet config test from `tests/jax`."""
319319
test = _load_compiled_jsonnet(test_name)
@@ -323,14 +323,14 @@ def from_jax(
323323
setup=test['setup'],
324324
exports='',
325325
test_command=['bash', '-c', test['runTest']],
326-
reserved=reserved_tpu,
326+
reserved=reserved,
327327
network=network,
328328
subnetwork=subnetwork,
329329
)
330330

331331
@staticmethod
332332
def from_pytorch(
333-
test_name: str, reserved_tpu: bool = True, network='default', subnetwork='default'
333+
test_name: str, reserved: bool = False, network='default', subnetwork='default'
334334
):
335335
"""Parses a compiled legacy JSonnet test config from `tests/pytorch`."""
336336
test = _load_compiled_jsonnet(test_name)
@@ -341,7 +341,7 @@ def from_pytorch(
341341
+ '\ncd ~\n' + test['tpuSettings']['tpuVmExtraSetup'],
342342
exports=test['tpuSettings']['tpuVmExports'],
343343
test_command=test['command'],
344-
reserved=reserved_tpu,
344+
reserved=reserved,
345345
network=network,
346346
subnetwork=subnetwork,
347347
)

0 commit comments

Comments
 (0)